Update runners.py

ClumsyLulz 2024-08-07 19:54:48 -04:00 committed by GitHub
parent 58690bf060
commit 92f7bbf073
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

--- a/runners.py
+++ b/runners.py

@@ -17,21 +17,15 @@ import bisect
 import functools
 import logging
 import math
-import re
+import numpy as np
+import jax
+import jax.numpy as jnp
+import haiku as hk
+import sentencepiece
 from dataclasses import dataclass
 from typing import Any, Callable, NamedTuple, Optional, Tuple
-
-import haiku as hk
-import jax
-import jax.experimental.pjit as pjit
-import jax.numpy as jnp
-import numpy as np
-import sentencepiece
-from jax.experimental import mesh_utils
+from jax.experimental import pjit
 from jax.sharding import PartitionSpec as P
-from jax.typing import ArrayLike
-
+import checkpoint as xai_checkpoint
 from model import (
     LanguageModelConfig,
     LanguageModelOutput,
@@ -40,49 +34,43 @@ from model import (
     Memory,
     KVMemory,
 )
-import checkpoint as xai_checkpoint
-
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 rank_logger = logging.getLogger("rank")
+rank_logger.setLevel(logging.INFO)
 
 TOP_K = 8
 
 
 class SampleSettings(NamedTuple):
-    temperature: ArrayLike
-    nucleus_p: ArrayLike
-    mask: ArrayLike
-    # Whether a given batch element is actively used. [B]
-    active: ArrayLike
+    temperature: jax.Array
+    nucleus_p: jax.Array
+    mask: jax.Array
+    active: jax.Array
 
-
 class SampleOutput(NamedTuple):
-    token_id: ArrayLike
-    prob: ArrayLike
-    top_k_token_ids: ArrayLike
-    top_k_probs: ArrayLike
+    token_id: jax.Array
+    prob: jax.Array
+    top_k_token_ids: jax.Array
+    top_k_probs: jax.Array
 
-
-def insert_slice(memory: Memory, slice, length, i):
+def insert_slice(memory: Memory, slice: Memory, length: int, i: int) -> Memory:
     slice = Memory(
         layers=[
             KVMemory(layer.k, layer.v, step=jnp.array([length]))
             for layer in slice.layers
         ],
     )
-
     return jax.tree_map(lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0),
                         memory, slice)
 
-
-def pad_to_size(x, size):
+def pad_to_size(x: jnp.ndarray, size: int) -> jnp.ndarray:
     if x.shape[0] > size:
-        # Left truncate if the context is too long.
         x = x[-size:]
     return np.pad(x, [0, size - x.shape[0]], mode="constant", constant_values=0)
 
 
 def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:
-    """Performs nucleus filtering on logits."""
     assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}"
     sorted_logits = jax.lax.sort(logits, is_stable=False)
     sorted_probs = jax.nn.softmax(sorted_logits)
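Note on this hunk: pad_to_size both truncates and pads, which is easy to misread in the flattened diff. Below is a minimal runnable sketch of its two cases; the function body is copied from the file, while the example inputs are made up for illustration.

import numpy as np

def pad_to_size(x, size):
    if x.shape[0] > size:
        # Keep only the most recent `size` entries (left-truncate).
        x = x[-size:]
    # Right-pad with zeros up to `size`.
    return np.pad(x, [0, size - x.shape[0]], mode="constant", constant_values=0)

print(pad_to_size(np.array([1, 2, 3]), 5))           # [1 2 3 0 0]
print(pad_to_size(np.array([1, 2, 3, 4, 5, 6]), 4))  # [3 4 5 6]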
@@ -92,36 +80,27 @@ def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:
     )
     assert threshold_largest_logits.shape == logits.shape[:-1] + (1,)
     mask = logits >= threshold_largest_logits
-    # Set unused logits to -inf.
     logits = jnp.where(mask, logits, -1e10)
     return logits
 
 
 def sample_token(
     rngs: jax.random.PRNGKey,
     lm_outputs: LanguageModelOutput,
     settings: SampleSettings,
 ) -> SampleOutput:
-    # Expand the settings shape to match the logit shape.
     settings = SampleSettings(
-        temperature=jnp.expand_dims(settings.temperature, (1, 2)),  # Input [B], output [B, 1, 1].
-        nucleus_p=jnp.expand_dims(settings.nucleus_p, (1, 2)),  # Input [B], output [B, 1, 1].
-        mask=jnp.expand_dims(settings.mask, 1),  # Input [B, V], output [B, 1, V].
-        active=settings.active,  # [B].
+        temperature=jnp.expand_dims(settings.temperature, (1, 2)),
+        nucleus_p=jnp.expand_dims(settings.nucleus_p, (1, 2)),
+        mask=jnp.expand_dims(settings.mask, 1),
+        active=settings.active,
     )
     logits = lm_outputs.logits / settings.temperature.astype(lm_outputs.logits.dtype)
-    # Mask out all disallowed tokens by assigning them a near-zero probability.
     logits = jnp.where(settings.mask, logits, -1e10)
-    # Mask out all tokens that don't fall into the p-th percentile.
     logits = top_p_filter(logits, settings.nucleus_p.astype(logits.dtype))
     new_token = jax.vmap(jax.random.categorical)(rngs, logits)
     probabilities = jax.nn.softmax(logits)
     token_prob = jnp.take_along_axis(probabilities, jnp.expand_dims(new_token, 1), axis=2)
     token_prob = jnp.squeeze(token_prob, 1)
-    # Gather the top-k tokens and probabilities.
     top_k_probs, top_k_token_ids = jax.lax.top_k(probabilities, TOP_K)
     top_k_probs = jnp.squeeze(top_k_probs, 1)
     top_k_token_ids = jnp.squeeze(top_k_token_ids, 1)
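Note: the middle of top_p_filter falls outside this hunk. For readers following along, here is an independent, self-contained sketch of nucleus (top-p) filtering written from the standard definition; it is not copied from this file, so names and edge-case handling are illustrative only.

import jax
import jax.numpy as jnp

def nucleus_filter(logits: jax.Array, top_p: float) -> jax.Array:
    # Sort ascending and convert to probabilities.
    sorted_logits = jnp.sort(logits, axis=-1)
    sorted_probs = jax.nn.softmax(sorted_logits, axis=-1)
    # Probability mass of each token plus all more-likely tokens.
    mass_from_top = jnp.cumsum(sorted_probs[..., ::-1], axis=-1)[..., ::-1]
    keep = mass_from_top <= top_p
    # Smallest surviving logit; always keep at least the argmax token.
    threshold = jnp.min(jnp.where(keep, sorted_logits, jnp.inf), axis=-1, keepdims=True)
    threshold = jnp.minimum(threshold, sorted_logits[..., -1:])
    return jnp.where(logits >= threshold, logits, -1e10)

logits = jnp.log(jnp.array([0.5, 0.3, 0.1, 0.1]))
print(jax.nn.softmax(nucleus_filter(logits, 0.8)))  # keeps only the 0.5 and 0.3 tokens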
@@ -132,19 +111,14 @@ def sample_token(
         top_k_probs,
     )
 
 
 @dataclass
 class ModelRunner:
     model: LanguageModelConfig
-
     bs_per_device: float = 2.0
-
     load_rename_rules: Optional[list[tuple[str, str]]] = None
     load_exclude_rules: Optional[list[str]] = None
-
-    rng_seed: int = 42  # Initial rng seed.
+    rng_seed: int = 42
     transform_forward: bool = False
-
     checkpoint_path: str = ""
-
     def make_forward_fn(self, mesh: Any):
@@ -159,20 +133,15 @@ class ModelRunner:
     def initialize(
         self,
         init_data,
-        local_mesh_config: tuple[int, int],
-        between_hosts_config: tuple[int, int],
+        local_mesh_config: Tuple[int, int],
+        between_hosts_config: Tuple[int, int],
     ):
         num_replicas = math.prod(between_hosts_config)
         self.model.initialize()
         self.model.fprop_dtype = jnp.bfloat16
         num_local_gpus = len(jax.local_devices())
-
-        # Calculate the global batch size from the local batch size.
         self.batch_size = int(self.bs_per_device * num_local_gpus * num_replicas)
-
-        # Calculate the batch size per host from the global batch size.
         self.local_batch_size = self.batch_size // jax.process_count()
-
         self.local_mesh_config = local_mesh_config
         self.between_hosts_config = between_hosts_config
         rank_logger.info(
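Note: the two batch-size lines kept above are clearer with concrete numbers. A worked example of the arithmetic; the device and host counts below are assumed purely for illustration.

bs_per_device = 2.0
num_local_gpus = 8     # len(jax.local_devices()) on one host, assumed
num_replicas = 2       # math.prod(between_hosts_config) for, e.g., (1, 2)
process_count = 2      # jax.process_count(): one process per host, assumed

batch_size = int(bs_per_device * num_local_gpus * num_replicas)  # 32, global
local_batch_size = batch_size // process_count                   # 16, per host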
@@ -181,7 +150,6 @@ class ModelRunner:
         self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
         self.forward = self.make_forward_fn(mesh=self.mesh)
         self.logits_fn = hk.transform(lambda tokens: self.forward(tokens)[0])
-
         self.eval_forward = self.make_forward_fn(mesh=self.mesh)
         self.logits_eval_fn = hk.transform(lambda tokens: self.eval_forward(tokens)[0])
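Note: logits_fn and logits_eval_fn use the standard Haiku transform pattern. A toy, self-contained illustration of that pattern; the embedding model below is a stand-in, not the repo's network.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(tokens):
    # Stand-in forward pass: a small embedding lookup.
    return hk.Embed(vocab_size=16, embed_dim=4)(tokens)

logits_fn = hk.transform(forward)
rng = jax.random.PRNGKey(0)
tokens = jnp.zeros((1, 8), dtype=jnp.int32)
params = logits_fn.init(rng, tokens)        # create parameters
out = logits_fn.apply(params, rng, tokens)  # run the transformed function
print(out.shape)                            # (1, 8, 4)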
@@ -234,7 +202,6 @@ class ModelRunner:
         assert self.transform_forward
         state_shapes = jax.eval_shape(self.init_fn, rng, init_data)
         init_state = None
-
         state = xai_checkpoint.restore(
             checkpoint_path=self.checkpoint_path,
             state_shapes=state_shapes,
@@ -244,11 +211,9 @@ class ModelRunner:
             init_state=init_state,
             params_only=True,
         )
-
         del init_state
         return state
 
-
 @dataclass
 class Request:
     prompt: str
@@ -257,18 +222,17 @@ class Request:
     rng_seed: int
     max_len: int
 
 
 @dataclass
 class InferenceRunner:
     name: str
-    runner: Any
+    runner: ModelRunner
     load: str
     tokenizer_path: str = "/tmp/xai_data/tokenizer.model"
     local_mesh_config: Tuple[int, int] = (1, 1)
     between_hosts_config: Tuple[int, int] = (1, 1)
-    pad_sizes: tuple[int] = (1024,)
-
-    def get_pad_bucket(self, size):
+    pad_sizes: Tuple[int] = (1024,)
+    def get_pad_bucket(self, size: int) -> int:
         i = bisect.bisect_left(self.pad_sizes, size)
         return self.pad_sizes[min(i, len(self.pad_sizes) - 1)]
 
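Note: get_pad_bucket rounds a prompt length up to the next pad bucket via bisect. A standalone sketch, with a multi-bucket pad_sizes assumed for illustration (the field's default above is the single bucket (1024,)).

import bisect

pad_sizes = (128, 256, 512, 1024)

def get_pad_bucket(size):
    # First bucket >= size; clamp to the largest bucket.
    i = bisect.bisect_left(pad_sizes, size)
    return pad_sizes[min(i, len(pad_sizes) - 1)]

print(get_pad_bucket(100))   # 128
print(get_pad_bucket(256))   # 256
print(get_pad_bucket(2000))  # 1024 (clamped; pad_to_size then left-truncates)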
@@ -276,330 +240,70 @@ class InferenceRunner:
         runner = self.runner
         self.runner.transform_forward = True
         dummy_data = dict(
-            inputs=np.zeros((1, 256), dtype=np.int32),
-            targets=np.zeros((1, 256), dtype=np.int32),
+            inputs=np.zeros((1, self.get_pad_bucket(512)), dtype=np.int32),
         )
-        runner.initialize(
+        state = runner.load_or_init(
             dummy_data,
-            local_mesh_config=self.local_mesh_config,
-            between_hosts_config=self.between_hosts_config,
+            from_checkpoint=False,
         )
+        runner.params = state.params
         self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=self.tokenizer_path)
-
-        max_len = runner.model.sequence_len
-
-        self.vocab_size = self.runner.model.vocab_size
-
-        params = runner.load_or_init(dummy_data)
-        self.params = params
-
-        def pad_to_max_len(x):
-            if len(x.shape) > 1:
-                pad_width = max_len - x.shape[1]
-                return jnp.pad(x, [(0, 0), (0, pad_width), (0, 0), (0, 0)])
-            else:
-                return x
-
-        @functools.lru_cache
-        def lm():
-            return runner.model.make(mesh=runner.mesh)
-
-        def hk_forward(
-            tokens,
-            memory=None,
-            length=None,
-            active=None,
-        ) -> LanguageModelOutput:
-            if memory is not None:
-                assert active is not None
-                layers = []
-                for l in memory.layers:
-                    # Reset steps to 0 for inactive requests to avoid unnecessary computations.
-                    step = jnp.where(active, l.step, jnp.zeros_like(l.step))
-                    layers.append(l._replace(step=step))
-                memory = memory._replace(layers=layers)
-            return lm()(tokens, memory, length=length)
-
-        def hk_sample_step(rngs, last_output: SampleOutput, memory, settings):
-            rngs, rngs_ = jax.vmap(jax.random.split, out_axes=1)(rngs)
-            lm_outputs = hk_forward(last_output.token_id, memory=memory, active=settings.active)
-            sample_result = sample_token(rngs_, lm_outputs, settings)
-            return rngs, sample_result, lm_outputs.model_state
-
-        def hk_new_memory(batch_size, sequence_len):
-            return lm().init_memory(batch_size, sequence_len)
-
-        def hk_prefill_memory(
-            rngs,
-            memory,
-            settings,
-            last_output,
-            prompt,
-            length,
-            rng_seed,
-            new_settings,
-            i,
-        ):
-            rng = jax.random.PRNGKey(seed=rng_seed)
-            rng, rng_ = jax.random.split(rng)
-
-            # Allocate new memory for this sample. The memory length is equal to the length of the
-            # prompt.
-            slice = hk_new_memory(1, prompt.shape[0])
-
-            # Move the settings for this individual batch entry into the joint settings tensor.
-            settings = jax.tree_map(
-                lambda o, v: jax.lax.dynamic_update_index_in_dim(o, v, i, axis=0),
-                settings,
-                new_settings,
-            )
-
-            # Get the settings for the batch entry from the joint settings tensor.
-            settings_slice = jax.tree_map(lambda t: jnp.expand_dims(t[i], axis=0), settings)
-
-            # Process the first n-1 tokens of the prompt.
-            lm_outputs = hk_forward(
-                jnp.expand_dims(prompt, 0),
-                memory=slice,
-                length=jnp.expand_dims(length, 0),
-                active=settings_slice.active,
-            )
-
-            # The forward pass doesn't correctly set the `step` counter inside the memory. Manually
-            # override it so `hk_forward` uses the correct context length in the next call.
-            slice = lm_outputs.model_state
-            slice = slice._replace(
-                layers=[l._replace(step=jnp.array([length])) for l in slice.layers]
-            )
-
-            # Sample the actual output token.
-            rng_ = jnp.expand_dims(rng_, 0)
-            new_output = sample_token(rng_, lm_outputs, settings_slice)
-
-            # Update the KV cache/memory.
-            slice = jax.tree_map(pad_to_max_len, slice)
-            memory = insert_slice(memory, slice, length, i)
-
-            rng = jnp.expand_dims(rng, 0)
-            rngs = jax.lax.dynamic_update_index_in_dim(rngs, rng, i, axis=0)
-
-            # Move the network outputs for this batch entry into the joint output tensor.
-            last_output = jax.tree_util.tree_map(
-                lambda last, new: jax.lax.dynamic_update_index_in_dim(last, new, i, axis=0),
-                last_output,
-                new_output,
-            )
-            return rngs, last_output, memory, settings
-
-        sample_step_ = hk.without_apply_rng(hk.transform(hk_sample_step))
-        prefill_memory_ = hk.without_apply_rng(hk.transform(hk_prefill_memory))
-        new_memory_ = hk.without_apply_rng(hk.transform(hk_new_memory))
-        forward_ = hk.without_apply_rng(hk.transform(hk_forward))
-
-        rng = jax.random.PRNGKey(42)
-        dummy_tokens = jnp.zeros((1, max_len), jnp.int32)
-
-        with runner.mesh:
-            shapes = jax.eval_shape(forward_.init, rng, dummy_tokens)
-
-        self.params_sharding = jax.tree_util.tree_map_with_path(
-            apply_rules(runner.model.partition_rules()),
-            shapes,
-        )
-
-        ds = P("data")
-        ms = runner.model.model.get_memory_sharding()
-        self.sample_step = pjit.pjit(
-            sample_step_.apply,
-            in_shardings=(self.params_sharding, None, ds, ms, None),
-            out_shardings=(None, ds, ms),
-            donate_argnums=3,
-        )
-        self.prefill_memory = pjit.pjit(
-            functools.partial(prefill_memory_.apply),
-            in_shardings=(
-                self.params_sharding,
-                None,
-                ms,
-                None,
-                ds,
-                None,
-                None,
-                None,
-                None,
-                None,
-            ),
-            out_shardings=(None, ds, ms, None),
-            donate_argnums=(2,),
-        )
-        self.new_memory = pjit.pjit(
-            new_memory_.apply,
-            static_argnums=(1, 2),
-            out_shardings=ms,
-        )
-
-    def run(self):
-        """Generator that accepts prompts."""
-        runner = self.runner
-        mesh = runner.mesh
-        max_len = runner.model.sequence_len
-        batch_size = runner.batch_size
-        params = self.params
-        rngs = jax.random.split(jax.random.PRNGKey(1), batch_size)
-        with mesh:
-            memory = self.new_memory(params, batch_size, max_len)
-            settings = SampleSettings(
-                temperature=np.zeros((batch_size,), dtype=np.float32),
-                nucleus_p=np.zeros((batch_size,), dtype=np.float32),
-                mask=np.ones((batch_size, self.vocab_size), dtype=np.int32),
-                active=np.zeros((batch_size), dtype=np.int32),
-            )
-            last_output = SampleOutput(
-                token_id=np.zeros((batch_size, 1), dtype=np.int32),
-                prob=np.zeros((batch_size, 1), dtype=jnp.bfloat16),
-                top_k_token_ids=np.zeros((batch_size, TOP_K), dtype=np.int32),
-                top_k_probs=np.zeros((batch_size, TOP_K), dtype=jnp.bfloat16),
-            )
-
-        prompt = np.array([300, 400, 500, 600, 600, 700, 800])
-        new_settings = SampleSettings(
-            temperature=np.float32(1),
-            nucleus_p=np.float32(1),
-            mask=np.ones((self.vocab_size,), dtype=np.int32),
-            active=np.zeros((), dtype=np.int32),
-        )
-        rng_seed = np.uint64(1)
-
-        for size in self.pad_sizes:
-            if size > runner.model.sequence_len:
-                break
-            logger.info("Precompile {}".format(size))
-            prompt_len = len(prompt)
-            prompt = pad_to_size(prompt, size)
-            rngs, last_output, memory, settings = self.prefill_memory(
-                params,
-                rngs,
-                memory,
-                settings,
-                last_output,
-                prompt,
-                prompt_len,
-                rng_seed,
-                new_settings,
-                0,
-            )
-        with runner.mesh:
-            logger.info("Compiling...")
-            rngs, last_output, memory = self.sample_step(
-                params, rngs, last_output, memory, settings
-            )
-            logger.info("Done compiling.")
-
-        all_tokens = []
-        free_slots = list(range(batch_size))
-        requests = [None] * batch_size
-        first_output = [None] * batch_size
-        jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
-        prev_token = last_output
-        step = 0
-        total_num_tokens = 0
-        total_num_sequences = 0
-        with mesh:
-            while True:
-                while free_slots:
-                    request: Optional[Request] = yield
-                    tokens = self.tokenizer.encode(request.prompt)
-                    temperature = request.temperature
-                    nucleus_p = request.nucleus_p
-                    rng_seed = request.rng_seed
-
-                    i = free_slots.pop()
-                    prompt = np.array(tokens, dtype=np.int32)
-                    prompt_len = len(prompt)
-                    prompt = pad_to_size(prompt, self.get_pad_bucket(prompt.shape[0]))
-                    # All tokens are allowed.
-                    mask = np.ones((self.vocab_size,), dtype=np.int32)
-
-                    new_settings = SampleSettings(
-                        temperature=np.float32(temperature),
-                        nucleus_p=np.float32(nucleus_p),
-                        mask=mask,
-                        active=np.ones((), dtype=np.int32),
-                    )
-                    rng_seed = np.uint64(rng_seed)
-                    rngs, last_output, memory, settings = self.prefill_memory(
-                        params,
-                        rngs,
-                        memory,
-                        settings,
-                        last_output,
-                        prompt,
-                        prompt_len,
-                        rng_seed,
-                        new_settings,
-                        i,
-                    )
-                    jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
-                    first_output[i] = last_output
-                    requests[i] = request
-                    total_num_sequences += 1
-
-                rngs, last_output, memory = self.sample_step(
-                    params, rngs, last_output, memory, settings
-                )
-                total_num_tokens += batch_size - len(free_slots)
-
-                # prev_token should already be on the host.
-                prev_token = jax.tree_map(np.array, prev_token)
-                for i in range(batch_size):
-                    if requests[i] is not None:
-                        if first_output[i] is not None:
-                            first_output_i = jax.tree_map(np.array, first_output[i])
-                            all_tokens.append(int(first_output_i.token_id[i][0]))
-                            first_output[i] = None
-                            continue
-
-                        all_tokens.append(int(prev_token.token_id[i][0]))
-                        cont = len(all_tokens) < requests[i].max_len
-                        if not cont:
-                            output_str = self.tokenizer.decode(all_tokens)
-                            requests[i] = None
-                            free_slots.append(i)
-                            all_tokens = []
-                            settings = settings._replace(active=settings.active.at[i].set(0))
-                            yield output_str
-
-                jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
-                prev_token = last_output
-                step += 1
-
-
-def make_mesh(
-    local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...]
-) -> jax.sharding.Mesh:
-    assert len(local_mesh_config) == 2
-    assert len(between_hosts_config) == 2
-    rank_logger.info("Detected %s devices in mesh", jax.device_count())
-    device_mesh = mesh_utils.create_hybrid_device_mesh(
-        local_mesh_config,
-        between_hosts_config,
-        devices=jax.devices(),
-        process_is_granule=True,
-    )
-    rank_logger.debug(re.sub("\n+", "\n", f"Job device mesh is:\n{device_mesh}"))
-    return jax.sharding.Mesh(device_mesh, ("data", "model"))
-
-
-def sample_from_model(server, prompt, max_len, temperature):
-    next(server)
-    inp = Request(
-        prompt=prompt,
-        temperature=temperature,
-        nucleus_p=1.0,
-        rng_seed=42,
-        max_len=max_len,
-    )
-    return server.send(inp)
+
+        def text_to_token_ids(text):
+            ids = self.tokenizer.encode(text, out_type=int)
+            return ids
+
+        self.text_to_token_ids = text_to_token_ids
+
+    def predict(self, request: Request) -> str:
+        rng = jax.random.PRNGKey(request.rng_seed)
+        token_ids = self.text_to_token_ids(request.prompt)
+        rng, gen_rng = jax.random.split(rng)
+        inputs = np.array(token_ids, dtype=np.int32)[np.newaxis, :]
+        token_ids = jnp.array(inputs)
+        state = self.runner.params
+        settings = SampleSettings(
+            temperature=jnp.array([request.temperature]),
+            nucleus_p=jnp.array([request.nucleus_p]),
+            mask=jnp.ones(token_ids.shape, dtype=bool),
+            active=jnp.ones(token_ids.shape, dtype=bool),
+        )
+        for _ in range(request.max_len):
+            lm_outputs = self.runner.eval_forward(token_ids)
+            sample_output = sample_token(gen_rng, lm_outputs, settings)
+            new_token = sample_output.token_id
+            token_ids = jnp.concatenate([token_ids, new_token], axis=-1)
+            if jnp.argmax(new_token) == 0:
+                break
+        return self.tokenizer.decode(token_ids.squeeze())
+
+
+def main():
+    runner = ModelRunner(
+        model=LanguageModelConfig(),
+        checkpoint_path="path_to_checkpoint",
+    )
+    inference_runner = InferenceRunner(
+        name="inference",
+        runner=runner,
+        load="path_to_load",
+        tokenizer_path="path_to_tokenizer_model",
+        local_mesh_config=(1, 1),
+        between_hosts_config=(1, 1),
+    )
+    inference_runner.initialize()
+    request = Request(
+        prompt="Sample text",
+        temperature=0.7,
+        nucleus_p=0.9,
+        rng_seed=42,
+        max_len=100,
+    )
+    result = inference_runner.predict(request)
+    print(result)
+
+
+if __name__ == "__main__":
+    main()
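Note: the new predict() replaces the old generator-based run() with a plain append-and-stop decode loop. Below is a model-free sketch of that loop pattern; the toy "language model" is invented purely to make the loop runnable and stands in for runner.eval_forward plus sample_token.

import jax
import jax.numpy as jnp

def toy_lm(token_ids):
    # Stand-in for runner.eval_forward: put all mass on (last_token - 1).
    next_tok = jnp.maximum(token_ids[:, -1] - 1, 0)
    return jax.nn.one_hot(next_tok, 8) * 10.0  # [batch, vocab] logits

token_ids = jnp.array([[5, 4]])
for _ in range(10):
    logits = toy_lm(token_ids)
    new_token = jnp.argmax(logits, axis=-1)[:, None]  # greedy stand-in for sample_token
    token_ids = jnp.concatenate([token_ids, new_token], axis=-1)
    if int(new_token[0, 0]) == 0:  # stop token, as in predict()
        break
print(token_ids)  # [[5 4 3 2 1 0]]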