# grok-1/runners.py
# Copyright 2024 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import bisect
import functools
import logging
import math
import re
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple, Optional, Tuple
import haiku as hk
import jax
import jax.experimental.pjit as pjit
import jax.numpy as jnp
import numpy as np
import sentencepiece
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec as P
from jax.typing import ArrayLike
import checkpoint as xai_checkpoint
from model import (
LanguageModelConfig,
LanguageModelOutput,
TrainingState,
apply_rules,
Memory,
KVMemory,
)
logger = logging.getLogger(__name__)
rank_logger = logging.getLogger("rank")
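# Number of highest-probability tokens reported alongside each sampled token
# (see SampleOutput.top_k_token_ids and top_k_probs below).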
TOP_K = 8
class SampleSettings(NamedTuple):
"""
A NamedTuple for storing settings used during the sampling process.
Attributes:
- temperature (ArrayLike): The temperature controls the randomness of the output. Lower values make
the model outputs more deterministic, while higher values increase diversity.
- nucleus_p (ArrayLike): The nucleus probability controls the cumulative distribution function of
token probabilities, effectively truncating low-probability tokens to focus sampling on a subset
of plausible tokens.
- mask (ArrayLike): A binary mask indicating which tokens are allowed (1) and which are disallowed (0)
for sampling in each batch element.
- active (ArrayLike): A binary indicator for whether a given batch element is actively being used for
generation. Elements not in use are ignored in computations to save resources.
"""
temperature: ArrayLike
nucleus_p: ArrayLike
mask: ArrayLike
# Whether a given batch element is actively used. [B]
active: ArrayLike
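# Illustrative construction of per-batch settings (the values and the
# vocab_size variable below are assumptions, not defaults used elsewhere in
# this file); shapes follow the attribute docs above for a batch of B = 2:
#
#   settings = SampleSettings(
#       temperature=np.array([0.7, 1.0], dtype=np.float32),  # [B]
#       nucleus_p=np.array([0.95, 1.0], dtype=np.float32),   # [B]
#       mask=np.ones((2, vocab_size), dtype=np.int32),       # [B, V]
#       active=np.array([1, 0], dtype=np.int32),             # [B]
#   )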
class SampleOutput(NamedTuple):
"""
A NamedTuple for storing the output from the sampling process.
Attributes:
- token_id (ArrayLike): The sampled token ID for each batch element.
- prob (ArrayLike): The probability associated with the sampled token for each batch element.
- top_k_token_ids (ArrayLike): The IDs of the top-k most probable tokens for each batch element.
- top_k_probs (ArrayLike): The probabilities associated with the top-k token IDs for each batch element.
"""
token_id: ArrayLike
prob: ArrayLike
top_k_token_ids: ArrayLike
top_k_probs: ArrayLike
def insert_slice(memory: Memory, slice, length, i):
"""
Inserts a slice of KVMemory into a Memory object at a specified index. This function updates the Memory
object with a new slice that includes updated steps for each layer based on the provided length.
Parameters:
- memory (Memory): The original memory object to be updated.
- slice: A slice of the memory to be inserted, typically representing a new or modified piece of information.
- length (int): The step length to set for each KVMemory layer in the inserted slice.
- i (int): The index at which the slice is to be inserted into the memory.
Returns:
- Memory: The updated memory object with the new slice inserted at the specified index.
"""
slice = Memory(
layers=[
KVMemory(layer.k, layer.v, step=jnp.array([length]))
for layer in slice.layers
],
)
return jax.tree_map(lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0),
memory, slice)
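# A minimal sketch of the per-leaf update performed above (shapes are
# illustrative, not the real cache shapes):
#
#   batched = jnp.zeros((4, 8, 2))   # e.g. [batch, seq, head_dim] cache leaf
#   entry = jnp.ones((1, 8, 2))      # single-request slice for that leaf
#   batched = jax.lax.dynamic_update_index_in_dim(batched, entry[0], 2, axis=0)
#   # Row 2 of `batched` now holds `entry[0]`; every other row is unchanged.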
def pad_to_size(x, size):
"""
Pads or truncates an array to a specified size. If the array is longer than the target size, it will be
truncated from the left. If it is shorter, it will be right-padded with zeros.
Parameters:
- x (ArrayLike): The array to be padded or truncated.
- size (int): The target size for the array.
Returns:
- ArrayLike: The array adjusted to the specified size.
"""
if x.shape[0] > size:
# Left truncate if the context is too long.
x = x[-size:]
return np.pad(x, [0, size - x.shape[0]], mode="constant", constant_values=0)
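# Illustrative behaviour (inputs are made up):
#
#   pad_to_size(np.array([1, 2, 3]), 5)        -> array([1, 2, 3, 0, 0])
#   pad_to_size(np.array([1, 2, 3, 4, 5]), 3)  -> array([3, 4, 5])  # left-truncated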
def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:
"""
Performs nucleus filtering on logits.
Filters logits to retain only a subset that corresponds to the top-p cumulative probability. This
nucleus filtering method is used to focus the sampling process on a subset of plausible tokens, improving
the quality of generated text.
Parameters:
- logits (jax.Array): The logits to filter.
- top_p (jax.Array): The cumulative probability threshold for filtering logits.
Returns:
- jax.Array: The filtered logits with low-probability tokens set to -inf, effectively removing them
from consideration during sampling.
"""
assert logits.ndim == top_p.ndim, f"Expected logits rank {logits.ndim} to equal top_p rank {top_p.ndim}"
sorted_logits = jax.lax.sort(logits, is_stable=False)
sorted_probs = jax.nn.softmax(sorted_logits)
threshold_idx = jnp.argmax(jnp.cumsum(sorted_probs, -1) >= 1 - top_p, axis=-1)
threshold_largest_logits = jnp.take_along_axis(
sorted_logits, threshold_idx[..., jnp.newaxis], axis=-1
)
assert threshold_largest_logits.shape == logits.shape[:-1] + (1,)
mask = logits >= threshold_largest_logits
# Set unused logits to -inf.
logits = jnp.where(mask, logits, -1e10)
return logits
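# A worked example (probabilities are illustrative): with token probabilities
# [0.4, 0.3, 0.2, 0.1] and top_p = 0.75, tokens are discarded from the
# least-probable end for as long as the discarded mass stays below
# 1 - top_p = 0.25, so only the 0.1 token is pushed to -1e10 here; with
# top_p = 1.0 every token is kept and sampling is unaffected.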
def sample_token(
rngs: jax.random.PRNGKey,
lm_outputs: LanguageModelOutput,
settings: SampleSettings,
) -> SampleOutput:
"""
Samples a token from the language model outputs using specified settings, including temperature and
nucleus probability filtering. This function also computes the probability of the sampled token and
identifies the top-k tokens and their probabilities.
Parameters:
- rngs (jax.random.PRNGKey): The random number generator state for sampling.
- lm_outputs (LanguageModelOutput): The outputs from the language model, including logits.
- settings (SampleSettings): The settings controlling the sampling process, including temperature,
nucleus probability, token masking, and active batch elements indicator.
Returns:
- SampleOutput: The results of the sampling process, including the sampled token ID, its probability,
and the top-k tokens and their probabilities for each batch element.
"""
# Expand the settings shape to match the logit shape.
settings = SampleSettings(
temperature=jnp.expand_dims(settings.temperature, (1, 2)), # Input [B], output [B, 1, 1].
nucleus_p=jnp.expand_dims(settings.nucleus_p, (1, 2)), # Input [B], output [B, 1, 1].
mask=jnp.expand_dims(settings.mask, 1), # Input [B, V], output [B, 1, V].
active=settings.active, # [B].
)
logits = lm_outputs.logits / settings.temperature.astype(lm_outputs.logits.dtype)
# Mask out all disallowed tokens by assigning them a large negative logit
# (near-zero probability after the softmax).
logits = jnp.where(settings.mask, logits, -1e10)
# Keep only tokens inside the top-p (nucleus) cumulative probability mass.
logits = top_p_filter(logits, settings.nucleus_p.astype(logits.dtype))
new_token = jax.vmap(jax.random.categorical)(rngs, logits)
probabilities = jax.nn.softmax(logits)
token_prob = jnp.take_along_axis(probabilities, jnp.expand_dims(new_token, 1), axis=2)
token_prob = jnp.squeeze(token_prob, 1)
# Gather the top-k tokens and probabilities.
top_k_probs, top_k_token_ids = jax.lax.top_k(probabilities, TOP_K)
top_k_probs = jnp.squeeze(top_k_probs, 1)
top_k_token_ids = jnp.squeeze(top_k_token_ids, 1)
return SampleOutput(
new_token,
token_prob,
top_k_token_ids,
top_k_probs,
)
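# Shape sketch for the function above (B = batch size, V = vocab size,
# K = TOP_K); these follow from the code rather than an external contract:
#   rngs:                   [B, 2]    one PRNG key per batch element
#   lm_outputs.logits:      [B, 1, V]
#   SampleOutput.token_id:  [B, 1]
#   SampleOutput.prob:      [B, 1]
#   top_k_token_ids/probs:  [B, K]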
@dataclass
class ModelRunner:
"""
Manages the execution of a language model, including initialization, forward passes, and state management.
Attributes:
- model (LanguageModelConfig): The configuration for the language model.
- bs_per_device (float): The batch size per device.
- load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming parameters during model loading.
- load_exclude_rules (Optional[list[str]]): Rules for excluding parameters during model loading.
- rng_seed (int): The initial random number generator seed.
- transform_forward (bool): Flag to determine whether to transform the forward function with Haiku.
- checkpoint_path (str): The path to the directory containing model checkpoints for loading.
Methods:
- make_forward_fn: Creates the forward function for the model, potentially transforming it with Haiku.
- initialize: Initializes the model runner with data and mesh configuration, preparing it for execution.
- init: Initializes the model parameters using provided data.
- get_state_sharding: Determines the sharding configuration for model parameters based on the initialization data.
- load_or_init: Loads the model from a checkpoint or initializes it if the checkpoint is not available or not specified.
"""
model: LanguageModelConfig
bs_per_device: float = 2.0
load_rename_rules: Optional[list[tuple[str, str]]] = None
load_exclude_rules: Optional[list[str]] = None
rng_seed: int = 42 # Initial rng seed.
transform_forward: bool = False
checkpoint_path: str = ""
def make_forward_fn(self, mesh: Any):
"""
Creates and optionally transforms the forward function of the model. This method constructs a forward
function that takes input tokens and produces model outputs. If `transform_forward` is set to True, the
method also transforms this function using Haiku's transform method to prepare it for JIT compilation
and execution on the mesh.
Parameters:
mesh (Any): The mesh configuration to be used for distributing the computation across devices.
Returns:
Callable: The forward function, potentially transformed by Haiku, ready for execution.
"""
def forward(tokens):
out = self.model.make(mesh=mesh)(tokens)
return out, None
if self.transform_forward:
forward = hk.transform(forward)
return forward
def initialize(
self,
init_data,
local_mesh_config: tuple[int, int],
between_hosts_config: tuple[int, int],
):
"""
Initializes the model runner with necessary configurations and data. This includes setting up the mesh
for distributed computation, calculating batch sizes based on the provided configuration, and preparing
the forward function for execution.
Parameters:
init_data: Initial data used for model initialization. This is typically dummy data matching the
expected input format and dimensions of the model.
local_mesh_config (tuple[int, int]): The configuration of the local mesh, specifying how devices
are organized locally for parallel computation.
between_hosts_config (tuple[int, int]): The configuration for communication between hosts in a
distributed setup, important for multi-host environments.
"""
num_replicas = math.prod(between_hosts_config)
self.model.initialize()
self.model.fprop_dtype = jnp.bfloat16
num_local_gpus = len(jax.local_devices())
# Calculate the global batch size from the local batch size.
self.batch_size = int(self.bs_per_device * num_local_gpus * num_replicas)
# Calculate the batch size per host from the global batch size.
self.local_batch_size = self.batch_size // jax.process_count()
self.local_mesh_config = local_mesh_config
self.between_hosts_config = between_hosts_config
rank_logger.info(
f"Initializing mesh for {self.local_mesh_config=} {self.between_hosts_config=}..."
)
self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
self.forward = self.make_forward_fn(mesh=self.mesh)
self.logits_fn = hk.transform(lambda tokens: self.forward(tokens)[0])
self.eval_forward = self.make_forward_fn(mesh=self.mesh)
self.logits_eval_fn = hk.transform(lambda tokens: self.eval_forward(tokens)[0])
if self.transform_forward:
self.state_sharding = self.get_state_sharding(init_data)
rank_logger.info(f"State sharding type: {type(self.state_sharding)}")
self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding)
def init(self, rng: jax.Array, data) -> TrainingState:
"""
Initializes the model parameters using provided data and a random number generator state. This method
is only called when `transform_forward` is True, indicating that the forward function has been transformed
by Haiku and requires explicit initialization.
Parameters:
rng (jax.Array): The random number generator state for initializing model parameters.
data: The initial data used for parameter initialization, usually including input tokens and
possibly target tokens for supervised learning tasks.
Returns:
TrainingState: An instance of TrainingState containing the initialized model parameters.
"""
assert self.transform_forward
rng, init_rng = jax.random.split(rng)
params = self.forward.init(init_rng, data["inputs"])
return TrainingState(params=params)
def get_state_sharding(self, init_data):
"""
Determines the sharding configuration for model parameters based on initialization data. This method
evaluates the shape of model parameters during initialization to apply the partition rules specified
in the model's configuration, facilitating efficient distributed computation.
Parameters:
init_data: The initial data used for model initialization, which helps determine the appropriate
sharding configuration based on model parameter shapes.
Returns:
A sharding configuration that specifies how model parameters should be partitioned across the mesh.
"""
assert self.transform_forward
rng = jax.random.PRNGKey(self.rng_seed)
rank_logger.info(f"partition rules: {self.model.partition_rules}")
with self.mesh:
shapes = jax.eval_shape(self.init, rng, init_data)
sharding = jax.tree_util.tree_map_with_path(
apply_rules(self.model.partition_rules()),
shapes,
)
return sharding
def load_or_init(
self,
init_data: Any,
from_checkpoint: bool = True,
init_fn: Optional[Callable] = None,
):
"""
Loads model parameters from a checkpoint or initializes them if the checkpoint is not available. This
method provides flexibility in starting the model from a known state or freshly initializing parameters
based on the model configuration and provided initial data.
Parameters:
init_data: Initial data used for model parameter initialization if needed.
from_checkpoint (bool): Flag indicating whether to attempt loading parameters from a checkpoint.
init_fn (Optional[Callable]): An optional initialization function that overrides the default
initialization method if provided.
Returns:
The model state, either loaded from a checkpoint or newly initialized.
"""
rng = jax.random.PRNGKey(self.rng_seed)
if not self.checkpoint_path or not from_checkpoint:
rank_logger.info("Initializing model...")
with self.mesh:
if init_fn is not None:
state = init_fn(rng, init_data)
else:
assert self.transform_forward
state = self.init_fn(rng, init_data)
rank_logger.info("Model state is newly initialized.")
else:
with self.mesh:
if init_fn:
state_shapes = jax.eval_shape(init_fn, rng, init_data)
else:
assert self.transform_forward
state_shapes = jax.eval_shape(self.init_fn, rng, init_data)
init_state = None
state = xai_checkpoint.restore(
checkpoint_path=self.checkpoint_path,
state_shapes=state_shapes,
mesh=self.mesh,
between_hosts_config=self.between_hosts_config,
state_sharding=self.state_sharding,
init_state=init_state,
params_only=True,
)
del init_state
return state
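# A minimal usage sketch for ModelRunner (configuration values below are
# illustrative assumptions, not the settings used for Grok-1):
#
#   runner = ModelRunner(model=my_model_config, bs_per_device=0.125,
#                        checkpoint_path="./checkpoints/")
#   runner.transform_forward = True
#   dummy = dict(inputs=np.zeros((1, 256), np.int32),
#                targets=np.zeros((1, 256), np.int32))
#   runner.initialize(dummy, local_mesh_config=(1, 8), between_hosts_config=(1, 1))
#   state = runner.load_or_init(dummy)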
@dataclass
class Request:
"""
Data class for storing information about a single sampling request to the model.
Attributes:
prompt (str): The text prompt to generate text from.
temperature (float): Controls the randomness of the output. Lower values result in more deterministic outputs.
nucleus_p (float): The cumulative probability threshold for nucleus sampling, focusing generation on high-probability tokens.
rng_seed (int): Seed for the random number generator to ensure reproducibility of the results.
max_len (int): The maximum length of the generated text sequence.
"""
prompt: str
temperature: float
nucleus_p: float
rng_seed: int
max_len: int
@dataclass
class InferenceRunner:
"""
Manages inference operations for generating text based on prompts, including initializing the model and tokenizer,
prefilling the model memory, and running the actual text generation.
Attributes:
name (str): A name identifier for the inference runner.
runner (Any): An instance of ModelRunner that manages the underlying language model.
load (str): A path to the model checkpoint for loading.
tokenizer_path (str): Path to the SentencePiece tokenizer model file.
local_mesh_config (Tuple[int, int]): Configuration for the local device mesh.
between_hosts_config (Tuple[int, int]): Configuration for communication between hosts in a distributed setup.
pad_sizes (tuple[int, ...]): A tuple of padding sizes for processing different lengths of prompts.
Methods:
get_pad_bucket: Determines the appropriate padding size for a given input length.
initialize: Initializes the model runner, tokenizer, and pre-compiles model functions for different input sizes.
run: A generator method that accepts prompts and returns generated text, managing batch slots and sampling steps.
"""
name: str
runner: Any
load: str
tokenizer_path: str = "/tmp/xai_data/tokenizer.model"
local_mesh_config: Tuple[int, int] = (1, 1)
between_hosts_config: Tuple[int, int] = (1, 1)
pad_sizes: tuple[int, ...] = (1024,)
def get_pad_bucket(self, size):
"""
Determines the appropriate padding size for a given input length. This method uses the predefined
pad sizes to find the smallest pad size that is larger than or equal to the given size.
Parameters:
size (int): The length of the input sequence to be padded.
Returns:
int: The selected pad size from the predefined pad sizes that best fits the input length.
"""
i = bisect.bisect_left(self.pad_sizes, size)
return self.pad_sizes[min(i, len(self.pad_sizes) - 1)]
def initialize(self):
"""
Initializes the inference runner by setting up the model runner, loading the tokenizer, pre-compiling
model functions for different input sizes, and loading or initializing model parameters. This method
ensures that the inference runner is ready to generate text based on prompts.
"""
runner = self.runner
self.runner.transform_forward = True
dummy_data = dict(
inputs=np.zeros((1, 256), dtype=np.int32),
targets=np.zeros((1, 256), dtype=np.int32),
)
runner.initialize(
dummy_data,
local_mesh_config=self.local_mesh_config,
between_hosts_config=self.between_hosts_config,
)
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=self.tokenizer_path)
max_len = runner.model.sequence_len
self.vocab_size = self.runner.model.vocab_size
params = runner.load_or_init(dummy_data)
self.params = params
def pad_to_max_len(x):
if len(x.shape) > 1:
pad_width = max_len - x.shape[1]
return jnp.pad(x, [(0, 0), (0, pad_width), (0, 0), (0, 0)])
else:
return x
@functools.lru_cache
def lm():
return runner.model.make(mesh=runner.mesh)
def hk_forward(
tokens,
memory=None,
length=None,
active=None,
) -> LanguageModelOutput:
if memory is not None:
assert active is not None
layers = []
for l in memory.layers:
# Reset steps to 0 for inactive requests to avoid unnecessary computations.
step = jnp.where(active, l.step, jnp.zeros_like(l.step))
layers.append(l._replace(step=step))
memory = memory._replace(layers=layers)
return lm()(tokens, memory, length=length)
def hk_sample_step(rngs, last_output: SampleOutput, memory, settings):
rngs, rngs_ = jax.vmap(jax.random.split, out_axes=1)(rngs)
lm_outputs = hk_forward(last_output.token_id, memory=memory, active=settings.active)
sample_result = sample_token(rngs_, lm_outputs, settings)
return rngs, sample_result, lm_outputs.model_state
def hk_new_memory(batch_size, sequence_len):
return lm().init_memory(batch_size, sequence_len)
def hk_prefill_memory(
rngs,
memory,
settings,
last_output,
prompt,
length,
rng_seed,
new_settings,
i,
):
rng = jax.random.PRNGKey(seed=rng_seed)
rng, rng_ = jax.random.split(rng)
# Allocate new memory for this sample. The memory length is equal to the length of the
# prompt.
slice = hk_new_memory(1, prompt.shape[0])
# Move the settings for this individual batch entry into the joint settings tensor.
settings = jax.tree_map(
lambda o, v: jax.lax.dynamic_update_index_in_dim(o, v, i, axis=0),
settings,
new_settings,
)
# Get the settings for the batch entry from the joint settings tensor.
settings_slice = jax.tree_map(lambda t: jnp.expand_dims(t[i], axis=0), settings)
# Process the first n-1 tokens of the prompt.
lm_outputs = hk_forward(
jnp.expand_dims(prompt, 0),
memory=slice,
length=jnp.expand_dims(length, 0),
active=settings_slice.active,
)
# The forward pass doesn't correctly set the `step` counter inside the memory. Manually
# override it so `hk_forward` uses the correct context length in the next call.
slice = lm_outputs.model_state
slice = slice._replace(
layers=[l._replace(step=jnp.array([length])) for l in slice.layers]
)
# Sample the actual output token.
rng_ = jnp.expand_dims(rng_, 0)
new_output = sample_token(rng_, lm_outputs, settings_slice)
# Update the KV cache/memory.
slice = jax.tree_map(pad_to_max_len, slice)
memory = insert_slice(memory, slice, length, i)
rng = jnp.expand_dims(rng, 0)
rngs = jax.lax.dynamic_update_index_in_dim(rngs, rng, i, axis=0)
# Move the network outputs for this batch entry into the joint output tensor.
last_output = jax.tree_util.tree_map(
lambda last, new: jax.lax.dynamic_update_index_in_dim(last, new, i, axis=0),
last_output,
new_output,
)
return rngs, last_output, memory, settings
sample_step_ = hk.without_apply_rng(hk.transform(hk_sample_step))
prefill_memory_ = hk.without_apply_rng(hk.transform(hk_prefill_memory))
new_memory_ = hk.without_apply_rng(hk.transform(hk_new_memory))
forward_ = hk.without_apply_rng(hk.transform(hk_forward))
rng = jax.random.PRNGKey(42)
dummy_tokens = jnp.zeros((1, max_len), jnp.int32)
with runner.mesh:
shapes = jax.eval_shape(forward_.init, rng, dummy_tokens)
self.params_sharding = jax.tree_util.tree_map_with_path(
apply_rules(runner.model.partition_rules()),
shapes,
)
ds = P("data")
ms = runner.model.model.get_memory_sharding()
self.sample_step = pjit.pjit(
sample_step_.apply,
in_shardings=(self.params_sharding, None, ds, ms, None),
out_shardings=(None, ds, ms),
donate_argnums=3,
)
self.prefill_memory = pjit.pjit(
functools.partial(prefill_memory_.apply),
in_shardings=(
self.params_sharding,
None,
ms,
None,
ds,
None,
None,
None,
None,
None,
),
out_shardings=(None, ds, ms, None),
donate_argnums=(2,),
)
self.new_memory = pjit.pjit(
new_memory_.apply,
static_argnums=(1, 2),
out_shardings=ms,
)
def run(self):
"""
Generator method that accepts prompts and manages the text generation process. This method handles
tokenization of prompts, setup of sampling settings, and orchestration of the sampling steps that
produce the output text.
Yields:
str: The generated text for each prompt received. This method operates as a coroutine, yielding
generated text back to the caller and awaiting new prompts for further generation.
"""
runner = self.runner
mesh = runner.mesh
max_len = runner.model.sequence_len
batch_size = runner.batch_size
params = self.params
rngs = jax.random.split(jax.random.PRNGKey(1), batch_size)
with mesh:
memory = self.new_memory(params, batch_size, max_len)
settings = SampleSettings(
temperature=np.zeros((batch_size,), dtype=np.float32),
nucleus_p=np.zeros((batch_size,), dtype=np.float32),
mask=np.ones((batch_size, self.vocab_size), dtype=np.int32),
active=np.zeros((batch_size), dtype=np.int32),
)
last_output = SampleOutput(
token_id=np.zeros((batch_size, 1), dtype=np.int32),
prob=np.zeros((batch_size, 1), dtype=jnp.bfloat16),
top_k_token_ids=np.zeros((batch_size, TOP_K), dtype=np.int32),
top_k_probs=np.zeros((batch_size, TOP_K), dtype=jnp.bfloat16),
)
prompt = np.array([300, 400, 500, 600, 600, 700, 800])
new_settings = SampleSettings(
temperature=np.float32(1),
nucleus_p=np.float32(1),
mask=np.ones((self.vocab_size,), dtype=np.int32),
active=np.zeros((), dtype=np.int32),
)
rng_seed = np.uint64(1)
for size in self.pad_sizes:
if size > runner.model.sequence_len:
break
logger.info("Precompile {}".format(size))
prompt_len = len(prompt)
prompt = pad_to_size(prompt, size)
rngs, last_output, memory, settings = self.prefill_memory(
params,
rngs,
memory,
settings,
last_output,
prompt,
prompt_len,
rng_seed,
new_settings,
0,
)
with runner.mesh:
logger.info("Compiling...")
rngs, last_output, memory = self.sample_step(
params, rngs, last_output, memory, settings
)
logger.info("Done compiling.")
all_tokens = []
free_slots = list(range(batch_size))
requests = [None] * batch_size
first_output = [None] * batch_size
jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
prev_token = last_output
step = 0
total_num_tokens = 0
total_num_sequences = 0
with mesh:
while True:
while free_slots:
request: Optional[Request] = yield
tokens = self.tokenizer.encode(request.prompt)
temperature = request.temperature
nucleus_p = request.nucleus_p
rng_seed = request.rng_seed
i = free_slots.pop()
prompt = np.array(tokens, dtype=np.int32)
prompt_len = len(prompt)
prompt = pad_to_size(prompt, self.get_pad_bucket(prompt.shape[0]))
# All tokens are allowed.
mask = np.ones((self.vocab_size,), dtype=np.int32)
new_settings = SampleSettings(
temperature=np.float32(temperature),
nucleus_p=np.float32(nucleus_p),
mask=mask,
active=np.ones((), dtype=np.int32),
)
rng_seed = np.uint64(rng_seed)
rngs, last_output, memory, settings = self.prefill_memory(
params,
rngs,
memory,
settings,
last_output,
prompt,
prompt_len,
rng_seed,
new_settings,
i,
)
jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
first_output[i] = last_output
requests[i] = request
total_num_sequences += 1
rngs, last_output, memory = self.sample_step(
params, rngs, last_output, memory, settings
)
total_num_tokens += batch_size - len(free_slots)
# prev_token should already be on the host.
prev_token = jax.tree_map(np.array, prev_token)
for i in range(batch_size):
if requests[i] is not None:
if first_output[i] is not None:
first_output_i = jax.tree_map(np.array, first_output[i])
all_tokens.append(int(first_output_i.token_id[i][0]))
first_output[i] = None
continue
all_tokens.append(int(prev_token.token_id[i][0]))
cont = len(all_tokens) < requests[i].max_len
if not cont:
output_str = self.tokenizer.decode(all_tokens)
requests[i] = None
free_slots.append(i)
all_tokens = []
settings = settings._replace(active=settings.active.at[i].set(0))
yield output_str
jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
prev_token = last_output
step += 1
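# `run()` is driven as a coroutine: the caller first advances it with
# `next(generator)`, then repeatedly sends Request objects and receives the
# decoded completion for each one (see `sample_from_model` below for the
# minimal driver).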
def make_mesh(
local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...]
) -> jax.sharding.Mesh:
"""
Creates a JAX mesh from the provided configuration, which is used for parallel computation across devices.
Parameters:
local_mesh_config (tuple[int, ...]): A tuple specifying the local mesh configuration. This usually corresponds
to how local devices are arranged and partitioned for computation.
between_hosts_config (tuple[int, ...]): A tuple specifying the mesh configuration for communication between
different hosts, relevant in multi-host distributed setups.
Returns:
jax.sharding.Mesh: A mesh object that encapsulates the device topology and partitioning for parallel computation.
"""
assert len(local_mesh_config) == 2
assert len(between_hosts_config) == 2
rank_logger.info("Detected %s devices in mesh", jax.device_count())
device_mesh = mesh_utils.create_hybrid_device_mesh(
local_mesh_config,
between_hosts_config,
devices=jax.devices(),
process_is_granule=True,
)
rank_logger.debug(re.sub("\n+", "\n", f"Job device mesh is:\n{device_mesh}"))
return jax.sharding.Mesh(device_mesh, ("data", "model"))
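# Illustrative example (device counts are assumptions): on a single host with
# 8 accelerators, local_mesh_config=(1, 8) and between_hosts_config=(1, 1)
# produce a mesh with a 1-way "data" axis and an 8-way "model" axis; the
# hybrid mesh arranges jax.devices() into that (data, model) grid.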
def sample_from_model(server, prompt, max_len, temperature):
"""
Samples tokens from a trained model given a prompt and sampling settings. This function is designed
to interact with a generator-based server that manages the state of the model and performs the actual
sampling. The server is expected to yield control back to this function, which then sends a request
object containing the prompt and settings, and waits for the generated output.
The function first advances the server generator to its next yield point, then constructs a request
from the provided parameters and sends it to the server. The sampled output, typically a continuation
of the prompt, is then received from the server.
Parameters:
- server (generator): The generator-based server that handles model inference. This server is expected
to manage the lifecycle of the model, including loading parameters, setting up the environment, and
performing the token sampling.
- prompt (str): The initial text prompt to which the model generates a continuation. This prompt is
encoded and fed into the model as part of the request.
- max_len (int): The maximum length of the generated sequence, including the length of the prompt.
This controls how long the model's output can be.
- temperature (float): A parameter controlling the randomness of the output. Lower values make the
model's outputs more deterministic (and potentially more repetitive), while higher values increase
diversity at the cost of coherence.
Returns:
- str: The generated text as a continuation of the provided prompt, up to a maximum length specified
by `max_len`.
"""
next(server)
inp = Request(
prompt=prompt,
temperature=temperature,
nucleus_p=1.0,
rng_seed=42,
max_len=max_len,
)
return server.send(inp)
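# End-to-end usage sketch (paths, mesh sizes, and the model config are
# illustrative assumptions, not the real Grok-1 settings):
#
#   inference_runner = InferenceRunner(
#       name="local",
#       runner=ModelRunner(model=my_model_config, bs_per_device=0.125,
#                          checkpoint_path="./checkpoints/"),
#       load="./checkpoints/",
#       tokenizer_path="./tokenizer.model",
#       local_mesh_config=(1, 8),
#       between_hosts_config=(1, 1),
#   )
#   inference_runner.initialize()
#   gen = inference_runner.run()
#   print(sample_from_model(gen, "The answer to life is", max_len=128, temperature=0.01))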