mirror of
https://github.com/xai-org/grok-1.git
synced 2024-11-25 21:19:52 +03:00
606 lines
21 KiB
Python
606 lines
21 KiB
Python
|
# Copyright 2024 X.AI Corp.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
|
||
|
import bisect
|
||
|
import functools
|
||
|
import logging
|
||
|
import math
|
||
|
import re
|
||
|
from dataclasses import dataclass
|
||
|
from typing import Any, Callable, NamedTuple, Optional, Tuple
|
||
|
|
||
|
import haiku as hk
|
||
|
import jax
|
||
|
import jax.experimental.pjit as pjit
|
||
|
import jax.numpy as jnp
|
||
|
import numpy as np
|
||
|
import sentencepiece
|
||
|
from jax.experimental import mesh_utils
|
||
|
from jax.sharding import PartitionSpec as P
|
||
|
from jax.typing import ArrayLike
|
||
|
|
||
|
import checkpoint as xai_checkpoint
|
||
|
from model import (
|
||
|
LanguageModelConfig,
|
||
|
LanguageModelOutput,
|
||
|
TrainingState,
|
||
|
apply_rules,
|
||
|
Memory,
|
||
|
KVMemory,
|
||
|
)
|
||
|
|
||
|
logger = logging.getLogger(__name__)
|
||
|
rank_logger = logging.getLogger("rank")
|
||
|
|
||
|
TOP_K = 8
|
||
|
|
||
|
|
||
|
class SampleSettings(NamedTuple):
|
||
|
temperature: ArrayLike
|
||
|
nucleus_p: ArrayLike
|
||
|
mask: ArrayLike
|
||
|
# Whether a given batch element is actively used. [B]
|
||
|
active: ArrayLike
|
||
|
|
||
|
|
||
|
class SampleOutput(NamedTuple):
|
||
|
token_id: ArrayLike
|
||
|
prob: ArrayLike
|
||
|
top_k_token_ids: ArrayLike
|
||
|
top_k_probs: ArrayLike
|
||
|
|
||
|
|
||
|
def insert_slice(memory: Memory, slice, length, i):
|
||
|
slice = Memory(
|
||
|
layers=[
|
||
|
KVMemory(layer.k, layer.v, step=jnp.array([length]))
|
||
|
for layer in slice.layers
|
||
|
],
|
||
|
)
|
||
|
|
||
|
return jax.tree_map(lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0),
|
||
|
memory, slice)
|
||
|
|
||
|
|
||
|
def pad_to_size(x, size):
|
||
|
if x.shape[0] > size:
|
||
|
# Left truncate if the context is too long.
|
||
|
x = x[-size:]
|
||
|
return np.pad(x, [0, size - x.shape[0]], mode="constant", constant_values=0)
|
||
|
|
||
|
|
||
|
def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:
|
||
|
"""Performs nucleus filtering on logits."""
|
||
|
assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}"
|
||
|
sorted_logits = jax.lax.sort(logits, is_stable=False)
|
||
|
sorted_probs = jax.nn.softmax(sorted_logits)
|
||
|
threshold_idx = jnp.argmax(jnp.cumsum(sorted_probs, -1) >= 1 - top_p, axis=-1)
|
||
|
threshold_largest_logits = jnp.take_along_axis(
|
||
|
sorted_logits, threshold_idx[..., jnp.newaxis], axis=-1
|
||
|
)
|
||
|
assert threshold_largest_logits.shape == logits.shape[:-1] + (1,)
|
||
|
mask = logits >= threshold_largest_logits
|
||
|
# Set unused logits to -inf.
|
||
|
logits = jnp.where(mask, logits, -1e10)
|
||
|
return logits
|
||
|
|
||
|
|
||
|
def sample_token(
|
||
|
rngs: jax.random.PRNGKey,
|
||
|
lm_outputs: LanguageModelOutput,
|
||
|
settings: SampleSettings,
|
||
|
) -> SampleOutput:
|
||
|
# Expand the settings shape to match the logit shape.
|
||
|
settings = SampleSettings(
|
||
|
temperature=jnp.expand_dims(settings.temperature, (1, 2)), # Input [B], output [B, 1, 1].
|
||
|
nucleus_p=jnp.expand_dims(settings.nucleus_p, (1, 2)), # Input [B], output [B, 1, 1].
|
||
|
mask=jnp.expand_dims(settings.mask, 1), # Input [B, V], output [B, 1, V].
|
||
|
active=settings.active, # [B].
|
||
|
)
|
||
|
logits = lm_outputs.logits / settings.temperature.astype(lm_outputs.logits.dtype)
|
||
|
# Mask out all disallowed tokens by assigning them a near-zero probability.
|
||
|
logits = jnp.where(settings.mask, logits, -1e10)
|
||
|
# Mask out all tokens that don't fall into the p-th percentile.
|
||
|
logits = top_p_filter(logits, settings.nucleus_p.astype(logits.dtype))
|
||
|
|
||
|
new_token = jax.vmap(jax.random.categorical)(rngs, logits)
|
||
|
|
||
|
probabilities = jax.nn.softmax(logits)
|
||
|
token_prob = jnp.take_along_axis(probabilities, jnp.expand_dims(new_token, 1), axis=2)
|
||
|
token_prob = jnp.squeeze(token_prob, 1)
|
||
|
|
||
|
# Gather the top-k tokens and probabilities.
|
||
|
top_k_probs, top_k_token_ids = jax.lax.top_k(probabilities, TOP_K)
|
||
|
top_k_probs = jnp.squeeze(top_k_probs, 1)
|
||
|
top_k_token_ids = jnp.squeeze(top_k_token_ids, 1)
|
||
|
return SampleOutput(
|
||
|
new_token,
|
||
|
token_prob,
|
||
|
top_k_token_ids,
|
||
|
top_k_probs,
|
||
|
)
|
||
|
|
||
|
|
||
|
@dataclass
|
||
|
class ModelRunner:
|
||
|
model: LanguageModelConfig
|
||
|
|
||
|
bs_per_device: float = 2.0
|
||
|
|
||
|
load_rename_rules: Optional[list[tuple[str, str]]] = None
|
||
|
load_exclude_rules: Optional[list[str]] = None
|
||
|
|
||
|
rng_seed: int = 42 # Initial rng seed.
|
||
|
transform_forward: bool = False
|
||
|
|
||
|
checkpoint_path: str = ""
|
||
|
|
||
|
def make_forward_fn(self, mesh: Any):
|
||
|
def forward(tokens):
|
||
|
out = self.model.make(mesh=mesh)(tokens)
|
||
|
return out, None
|
||
|
|
||
|
if self.transform_forward:
|
||
|
forward = hk.transform(forward)
|
||
|
return forward
|
||
|
|
||
|
def initialize(
|
||
|
self,
|
||
|
init_data,
|
||
|
local_mesh_config: tuple[int, int],
|
||
|
between_hosts_config: tuple[int, int],
|
||
|
):
|
||
|
num_replicas = math.prod(between_hosts_config)
|
||
|
self.model.initialize()
|
||
|
self.model.fprop_dtype = jnp.bfloat16
|
||
|
num_local_gpus = len(jax.local_devices())
|
||
|
|
||
|
# Calculate the global batch size from the local batch size.
|
||
|
self.batch_size = int(self.bs_per_device * num_local_gpus * num_replicas)
|
||
|
|
||
|
# Calculate the batch size per host from the global batch size.
|
||
|
self.local_batch_size = self.batch_size // jax.process_count()
|
||
|
|
||
|
self.local_mesh_config = local_mesh_config
|
||
|
self.between_hosts_config = between_hosts_config
|
||
|
rank_logger.info(
|
||
|
f"Initializing mesh for {self.local_mesh_config=} {self.between_hosts_config=}..."
|
||
|
)
|
||
|
self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
|
||
|
self.forward = self.make_forward_fn(mesh=self.mesh)
|
||
|
self.logits_fn = hk.transform(lambda tokens: self.forward(tokens)[0])
|
||
|
|
||
|
self.eval_forward = self.make_forward_fn(mesh=self.mesh)
|
||
|
self.logits_eval_fn = hk.transform(lambda tokens: self.eval_forward(tokens)[0])
|
||
|
|
||
|
if self.transform_forward:
|
||
|
self.state_sharding = self.get_state_sharding(init_data)
|
||
|
rank_logger.info(f"State sharding type: {type(self.state_sharding)}")
|
||
|
self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding)
|
||
|
|
||
|
def init(self, rng: jax.Array, data) -> TrainingState:
|
||
|
assert self.transform_forward
|
||
|
rng, init_rng = jax.random.split(rng)
|
||
|
params = self.forward.init(init_rng, data["inputs"])
|
||
|
return TrainingState(params=params)
|
||
|
|
||
|
def get_state_sharding(self, init_data):
|
||
|
assert self.transform_forward
|
||
|
rng = jax.random.PRNGKey(self.rng_seed)
|
||
|
rank_logger.info(f"partition rules: {self.model.partition_rules}")
|
||
|
|
||
|
with self.mesh:
|
||
|
shapes = jax.eval_shape(self.init, rng, init_data)
|
||
|
sharding = jax.tree_util.tree_map_with_path(
|
||
|
apply_rules(self.model.partition_rules()),
|
||
|
shapes,
|
||
|
)
|
||
|
return sharding
|
||
|
|
||
|
def load_or_init(
|
||
|
self,
|
||
|
init_data: Any,
|
||
|
from_checkpoint: bool = True,
|
||
|
init_fn: Optional[Callable] = None,
|
||
|
):
|
||
|
rng = jax.random.PRNGKey(self.rng_seed)
|
||
|
|
||
|
if not self.checkpoint_path or not from_checkpoint:
|
||
|
rank_logger.info("Initializing model...")
|
||
|
with self.mesh:
|
||
|
if init_fn is not None:
|
||
|
state = init_fn(rng, init_data)
|
||
|
else:
|
||
|
assert self.transform_forward
|
||
|
state = self.init_fn(rng, init_data)
|
||
|
rank_logger.info("Model state is newly initialized.")
|
||
|
else:
|
||
|
with self.mesh:
|
||
|
if init_fn:
|
||
|
state_shapes = jax.eval_shape(init_fn, rng, init_data)
|
||
|
else:
|
||
|
assert self.transform_forward
|
||
|
state_shapes = jax.eval_shape(self.init_fn, rng, init_data)
|
||
|
init_state = None
|
||
|
|
||
|
state = xai_checkpoint.restore(
|
||
|
checkpoint_path=self.checkpoint_path,
|
||
|
state_shapes=state_shapes,
|
||
|
mesh=self.mesh,
|
||
|
between_hosts_config=self.between_hosts_config,
|
||
|
state_sharding=self.state_sharding,
|
||
|
init_state=init_state,
|
||
|
params_only=True,
|
||
|
)
|
||
|
|
||
|
del init_state
|
||
|
return state
|
||
|
|
||
|
|
||
|
@dataclass
|
||
|
class Request:
|
||
|
prompt: str
|
||
|
temperature: float
|
||
|
nucleus_p: float
|
||
|
rng_seed: int
|
||
|
max_len: int
|
||
|
|
||
|
|
||
|
@dataclass
|
||
|
class InferenceRunner:
|
||
|
name: str
|
||
|
runner: Any
|
||
|
load: str
|
||
|
tokenizer_path: str = "/tmp/xai_data/tokenizer.model"
|
||
|
local_mesh_config: Tuple[int, int] = (1, 1)
|
||
|
between_hosts_config: Tuple[int, int] = (1, 1)
|
||
|
pad_sizes: tuple[int] = (1024,)
|
||
|
|
||
|
def get_pad_bucket(self, size):
|
||
|
i = bisect.bisect_left(self.pad_sizes, size)
|
||
|
return self.pad_sizes[min(i, len(self.pad_sizes) - 1)]
|
||
|
|
||
|
def initialize(self):
|
||
|
runner = self.runner
|
||
|
self.runner.transform_forward = True
|
||
|
dummy_data = dict(
|
||
|
inputs=np.zeros((1, 256), dtype=np.int32),
|
||
|
targets=np.zeros((1, 256), dtype=np.int32),
|
||
|
)
|
||
|
runner.initialize(
|
||
|
dummy_data,
|
||
|
local_mesh_config=self.local_mesh_config,
|
||
|
between_hosts_config=self.between_hosts_config,
|
||
|
)
|
||
|
|
||
|
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=self.tokenizer_path)
|
||
|
|
||
|
max_len = runner.model.sequence_len
|
||
|
|
||
|
self.vocab_size = self.runner.model.vocab_size
|
||
|
|
||
|
params = runner.load_or_init(dummy_data)
|
||
|
self.params = params
|
||
|
|
||
|
def pad_to_max_len(x):
|
||
|
if len(x.shape) > 1:
|
||
|
pad_width = max_len - x.shape[1]
|
||
|
return jnp.pad(x, [(0, 0), (0, pad_width), (0, 0), (0, 0)])
|
||
|
else:
|
||
|
return x
|
||
|
|
||
|
@functools.lru_cache
|
||
|
def lm():
|
||
|
return runner.model.make(mesh=runner.mesh)
|
||
|
|
||
|
def hk_forward(
|
||
|
tokens,
|
||
|
memory=None,
|
||
|
length=None,
|
||
|
active=None,
|
||
|
) -> LanguageModelOutput:
|
||
|
if memory is not None:
|
||
|
assert active is not None
|
||
|
layers = []
|
||
|
for l in memory.layers:
|
||
|
# Reset steps to 0 for inactive requests to avoid unnecessary computations.
|
||
|
step = jnp.where(active, l.step, jnp.zeros_like(l.step))
|
||
|
layers.append(l._replace(step=step))
|
||
|
memory = memory._replace(layers=layers)
|
||
|
return lm()(tokens, memory, length=length)
|
||
|
|
||
|
def hk_sample_step(rngs, last_output: SampleOutput, memory, settings):
|
||
|
rngs, rngs_ = jax.vmap(jax.random.split, out_axes=1)(rngs)
|
||
|
lm_outputs = hk_forward(last_output.token_id, memory=memory, active=settings.active)
|
||
|
sample_result = sample_token(rngs_, lm_outputs, settings)
|
||
|
return rngs, sample_result, lm_outputs.model_state
|
||
|
|
||
|
def hk_new_memory(batch_size, sequence_len):
|
||
|
return lm().init_memory(batch_size, sequence_len)
|
||
|
|
||
|
def hk_prefill_memory(
|
||
|
rngs,
|
||
|
memory,
|
||
|
settings,
|
||
|
last_output,
|
||
|
prompt,
|
||
|
length,
|
||
|
rng_seed,
|
||
|
new_settings,
|
||
|
i,
|
||
|
):
|
||
|
rng = jax.random.PRNGKey(seed=rng_seed)
|
||
|
rng, rng_ = jax.random.split(rng)
|
||
|
|
||
|
# Allocate new memory for this sample. The memory length is equal to the length of the
|
||
|
# prompt.
|
||
|
slice = hk_new_memory(1, prompt.shape[0])
|
||
|
|
||
|
# Move the settings for this individual batch entry into the joint settings tensor.
|
||
|
settings = jax.tree_map(
|
||
|
lambda o, v: jax.lax.dynamic_update_index_in_dim(o, v, i, axis=0),
|
||
|
settings,
|
||
|
new_settings,
|
||
|
)
|
||
|
|
||
|
# Get the settings for the batch entry from the joint settings tensor.
|
||
|
settings_slice = jax.tree_map(lambda t: jnp.expand_dims(t[i], axis=0), settings)
|
||
|
|
||
|
# Process the first n-1 tokens of the prompt.
|
||
|
lm_outputs = hk_forward(
|
||
|
jnp.expand_dims(prompt, 0),
|
||
|
memory=slice,
|
||
|
length=jnp.expand_dims(length, 0),
|
||
|
active=settings_slice.active,
|
||
|
)
|
||
|
|
||
|
# The forward pass doesn't correctly set the `step` counter inside the memory. Manually
|
||
|
# override it so `hk_forward` uses the correct context length in the next call.
|
||
|
slice = lm_outputs.model_state
|
||
|
slice = slice._replace(
|
||
|
layers=[l._replace(step=jnp.array([length])) for l in slice.layers]
|
||
|
)
|
||
|
|
||
|
# Sample the actual output token.
|
||
|
rng_ = jnp.expand_dims(rng_, 0)
|
||
|
new_output = sample_token(rng_, lm_outputs, settings_slice)
|
||
|
|
||
|
# Update the KV cache/memory.
|
||
|
slice = jax.tree_map(pad_to_max_len, slice)
|
||
|
memory = insert_slice(memory, slice, length, i)
|
||
|
|
||
|
rng = jnp.expand_dims(rng, 0)
|
||
|
rngs = jax.lax.dynamic_update_index_in_dim(rngs, rng, i, axis=0)
|
||
|
|
||
|
# Move the network outputs for this batch entry into the joint output tensor.
|
||
|
last_output = jax.tree_util.tree_map(
|
||
|
lambda last, new: jax.lax.dynamic_update_index_in_dim(last, new, i, axis=0),
|
||
|
last_output,
|
||
|
new_output,
|
||
|
)
|
||
|
return rngs, last_output, memory, settings
|
||
|
|
||
|
sample_step_ = hk.without_apply_rng(hk.transform(hk_sample_step))
|
||
|
prefill_memory_ = hk.without_apply_rng(hk.transform(hk_prefill_memory))
|
||
|
new_memory_ = hk.without_apply_rng(hk.transform(hk_new_memory))
|
||
|
forward_ = hk.without_apply_rng(hk.transform(hk_forward))
|
||
|
|
||
|
rng = jax.random.PRNGKey(42)
|
||
|
dummy_tokens = jnp.zeros((1, max_len), jnp.int32)
|
||
|
|
||
|
with runner.mesh:
|
||
|
shapes = jax.eval_shape(forward_.init, rng, dummy_tokens)
|
||
|
|
||
|
self.params_sharding = jax.tree_util.tree_map_with_path(
|
||
|
apply_rules(runner.model.partition_rules()),
|
||
|
shapes,
|
||
|
)
|
||
|
|
||
|
ds = P("data")
|
||
|
ms = runner.model.model.get_memory_sharding()
|
||
|
self.sample_step = pjit.pjit(
|
||
|
sample_step_.apply,
|
||
|
in_shardings=(self.params_sharding, None, ds, ms, None),
|
||
|
out_shardings=(None, ds, ms),
|
||
|
donate_argnums=3,
|
||
|
)
|
||
|
self.prefill_memory = pjit.pjit(
|
||
|
functools.partial(prefill_memory_.apply),
|
||
|
in_shardings=(
|
||
|
self.params_sharding,
|
||
|
None,
|
||
|
ms,
|
||
|
None,
|
||
|
ds,
|
||
|
None,
|
||
|
None,
|
||
|
None,
|
||
|
None,
|
||
|
None,
|
||
|
),
|
||
|
out_shardings=(None, ds, ms, None),
|
||
|
donate_argnums=(2,),
|
||
|
)
|
||
|
self.new_memory = pjit.pjit(
|
||
|
new_memory_.apply,
|
||
|
static_argnums=(1, 2),
|
||
|
out_shardings=ms,
|
||
|
)
|
||
|
|
||
|
def run(self):
|
||
|
"""Generator that accepts prompts."""
|
||
|
runner = self.runner
|
||
|
mesh = runner.mesh
|
||
|
max_len = runner.model.sequence_len
|
||
|
batch_size = runner.batch_size
|
||
|
params = self.params
|
||
|
rngs = jax.random.split(jax.random.PRNGKey(1), batch_size)
|
||
|
with mesh:
|
||
|
memory = self.new_memory(params, batch_size, max_len)
|
||
|
settings = SampleSettings(
|
||
|
temperature=np.zeros((batch_size,), dtype=np.float32),
|
||
|
nucleus_p=np.zeros((batch_size,), dtype=np.float32),
|
||
|
mask=np.ones((batch_size, self.vocab_size), dtype=np.int32),
|
||
|
active=np.zeros((batch_size), dtype=np.int32),
|
||
|
)
|
||
|
last_output = SampleOutput(
|
||
|
token_id=np.zeros((batch_size, 1), dtype=np.int32),
|
||
|
prob=np.zeros((batch_size, 1), dtype=jnp.bfloat16),
|
||
|
top_k_token_ids=np.zeros((batch_size, TOP_K), dtype=np.int32),
|
||
|
top_k_probs=np.zeros((batch_size, TOP_K), dtype=jnp.bfloat16),
|
||
|
)
|
||
|
|
||
|
prompt = np.array([300, 400, 500, 600, 600, 700, 800])
|
||
|
|
||
|
new_settings = SampleSettings(
|
||
|
temperature=np.float32(1),
|
||
|
nucleus_p=np.float32(1),
|
||
|
mask=np.ones((self.vocab_size,), dtype=np.int32),
|
||
|
active=np.zeros((), dtype=np.int32),
|
||
|
)
|
||
|
rng_seed = np.uint64(1)
|
||
|
|
||
|
for size in self.pad_sizes:
|
||
|
if size > runner.model.sequence_len:
|
||
|
break
|
||
|
logger.info("Precompile {}".format(size))
|
||
|
prompt_len = len(prompt)
|
||
|
prompt = pad_to_size(prompt, size)
|
||
|
rngs, last_output, memory, settings = self.prefill_memory(
|
||
|
params,
|
||
|
rngs,
|
||
|
memory,
|
||
|
settings,
|
||
|
last_output,
|
||
|
prompt,
|
||
|
prompt_len,
|
||
|
rng_seed,
|
||
|
new_settings,
|
||
|
0,
|
||
|
)
|
||
|
with runner.mesh:
|
||
|
logger.info("Compiling...")
|
||
|
rngs, last_output, memory = self.sample_step(
|
||
|
params, rngs, last_output, memory, settings
|
||
|
)
|
||
|
logger.info("Done compiling.")
|
||
|
|
||
|
all_tokens = []
|
||
|
free_slots = list(range(batch_size))
|
||
|
requests = [None] * batch_size
|
||
|
first_output = [None] * batch_size
|
||
|
jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
|
||
|
prev_token = last_output
|
||
|
step = 0
|
||
|
total_num_tokens = 0
|
||
|
total_num_sequences = 0
|
||
|
with mesh:
|
||
|
while True:
|
||
|
while free_slots:
|
||
|
request: Optional[Request] = yield
|
||
|
tokens = self.tokenizer.encode(request.prompt)
|
||
|
temperature = request.temperature
|
||
|
nucleus_p = request.nucleus_p
|
||
|
rng_seed = request.rng_seed
|
||
|
|
||
|
i = free_slots.pop()
|
||
|
prompt = np.array(tokens, dtype=np.int32)
|
||
|
prompt_len = len(prompt)
|
||
|
prompt = pad_to_size(prompt, self.get_pad_bucket(prompt.shape[0]))
|
||
|
# All tokens are allowed.
|
||
|
mask = np.ones((self.vocab_size,), dtype=np.int32)
|
||
|
|
||
|
new_settings = SampleSettings(
|
||
|
temperature=np.float32(temperature),
|
||
|
nucleus_p=np.float32(nucleus_p),
|
||
|
mask=mask,
|
||
|
active=np.ones((), dtype=np.int32),
|
||
|
)
|
||
|
rng_seed = np.uint64(rng_seed)
|
||
|
rngs, last_output, memory, settings = self.prefill_memory(
|
||
|
params,
|
||
|
rngs,
|
||
|
memory,
|
||
|
settings,
|
||
|
last_output,
|
||
|
prompt,
|
||
|
prompt_len,
|
||
|
rng_seed,
|
||
|
new_settings,
|
||
|
i,
|
||
|
)
|
||
|
jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
|
||
|
first_output[i] = last_output
|
||
|
requests[i] = request
|
||
|
total_num_sequences += 1
|
||
|
|
||
|
rngs, last_output, memory = self.sample_step(
|
||
|
params, rngs, last_output, memory, settings
|
||
|
)
|
||
|
total_num_tokens += batch_size - len(free_slots)
|
||
|
|
||
|
# prev_token should already be on the host.
|
||
|
prev_token = jax.tree_map(np.array, prev_token)
|
||
|
for i in range(batch_size):
|
||
|
if requests[i] is not None:
|
||
|
if first_output[i] is not None:
|
||
|
first_output_i = jax.tree_map(np.array, first_output[i])
|
||
|
all_tokens.append(int(first_output_i.token_id[i][0]))
|
||
|
first_output[i] = None
|
||
|
continue
|
||
|
|
||
|
all_tokens.append(int(prev_token.token_id[i][0]))
|
||
|
cont = len(all_tokens) < requests[i].max_len
|
||
|
|
||
|
if not cont:
|
||
|
output_str = self.tokenizer.decode(all_tokens)
|
||
|
requests[i] = None
|
||
|
free_slots.append(i)
|
||
|
all_tokens = []
|
||
|
settings = settings._replace(active=settings.active.at[i].set(0))
|
||
|
yield output_str
|
||
|
|
||
|
jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
|
||
|
prev_token = last_output
|
||
|
step += 1
|
||
|
|
||
|
|
||
|
def make_mesh(
|
||
|
local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...]
|
||
|
) -> jax.sharding.Mesh:
|
||
|
assert len(local_mesh_config) == 2
|
||
|
assert len(between_hosts_config) == 2
|
||
|
rank_logger.info("Detected %s devices in mesh", jax.device_count())
|
||
|
device_mesh = mesh_utils.create_hybrid_device_mesh(
|
||
|
local_mesh_config,
|
||
|
between_hosts_config,
|
||
|
devices=jax.devices(),
|
||
|
process_is_granule=True,
|
||
|
)
|
||
|
rank_logger.debug(re.sub("\n+", "\n", f"Job device mesh is:\n{device_mesh}"))
|
||
|
return jax.sharding.Mesh(device_mesh, ("data", "model"))
|
||
|
|
||
|
|
||
|
def sample_from_model(server, prompt, max_len, temperature):
|
||
|
next(server)
|
||
|
inp = Request(
|
||
|
prompt=prompt,
|
||
|
temperature=temperature,
|
||
|
nucleus_p=1.0,
|
||
|
rng_seed=42,
|
||
|
max_len=max_len,
|
||
|
)
|
||
|
return server.send(inp)
|