mirror of https://github.com/xai-org/grok-1.git
synced 2025-02-21 13:59:59 +03:00

Updated docstrings for runners.py

parent 3f6fb5f4aa
commit 2d92966a1f

runners.py: 252 changed lines
@@ -48,6 +48,20 @@ TOP_K = 8


class SampleSettings(NamedTuple):
    """
    A NamedTuple for storing settings used during the sampling process.

    Attributes:
    - temperature (ArrayLike): The temperature controls the randomness of the output. Lower values make
      the model outputs more deterministic, while higher values increase diversity.
    - nucleus_p (ArrayLike): The cumulative-probability threshold for nucleus (top-p) sampling. Tokens
      outside the smallest set whose probabilities sum to this threshold are truncated, focusing
      sampling on a subset of plausible tokens.
    - mask (ArrayLike): A binary mask indicating which tokens are allowed (1) and which are disallowed (0)
      for sampling in each batch element.
    - active (ArrayLike): A binary indicator for whether a given batch element is actively being used for
      generation. Elements not in use are ignored in computations to save resources.
    """
    temperature: ArrayLike
    nucleus_p: ArrayLike
    mask: ArrayLike
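
For illustration (this example is not part of the commit), the per-batch settings described above might be populated like so; the batch size, vocabulary size, and values are assumed:

    import jax.numpy as jnp

    # Hypothetical values: batch of 2 prompts, vocabulary of 8 tokens.
    settings = SampleSettings(
        temperature=jnp.array([0.7, 1.0]),   # per-element randomness
        nucleus_p=jnp.array([0.9, 0.95]),    # per-element top-p threshold
        mask=jnp.ones((2, 8)),               # 1 = token allowed, 0 = disallowed
        active=jnp.array([1, 1]),            # both batch slots generating
    )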
@@ -56,6 +70,15 @@ class SampleSettings(NamedTuple):


class SampleOutput(NamedTuple):
    """
    A NamedTuple for storing the output from the sampling process.

    Attributes:
    - token_id (ArrayLike): The sampled token ID for each batch element.
    - prob (ArrayLike): The probability associated with the sampled token for each batch element.
    - top_k_token_ids (ArrayLike): The IDs of the top-k most probable tokens for each batch element.
    - top_k_probs (ArrayLike): The probabilities associated with the top-k token IDs for each batch element.
    """
    token_id: ArrayLike
    prob: ArrayLike
    top_k_token_ids: ArrayLike
@@ -63,6 +86,19 @@ class SampleOutput(NamedTuple):


def insert_slice(memory: Memory, slice, length, i):
    """
    Inserts a slice of KVMemory into a Memory object at a specified index. This function updates the Memory
    object with a new slice whose per-layer step counters are set to the provided length.

    Parameters:
    - memory (Memory): The original memory object to be updated.
    - slice: A slice of the memory to be inserted, typically representing a new or modified piece of information.
    - length (int): The step length to set for each KVMemory layer in the inserted slice.
    - i (int): The index at which the slice is to be inserted into the memory.

    Returns:
    - Memory: The updated memory object with the new slice inserted at the specified index.
    """
    slice = Memory(
        layers=[
            KVMemory(layer.k, layer.v, step=jnp.array([length]))
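
The hunk ends before the insertion itself, so as an assumed sketch (not the hidden remainder of this function), writing one row of a new KV slice into an existing pytree of buffers at index i could look like:

    import jax

    def insert_slice_sketch(memory, new_slice, i):
        # Walk both pytrees leaf by leaf, writing row 0 of each slice
        # buffer into row i of the matching full buffer (batch axis
        # assumed to come first).
        return jax.tree_util.tree_map(
            lambda dst, src: dst.at[i].set(src[0]),
            memory, new_slice,
        )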
@@ -75,6 +111,17 @@ def insert_slice(memory: Memory, slice, length, i):


def pad_to_size(x, size):
    """
    Pads or truncates an array to a specified size. If the array is longer than the target size, it will be
    truncated from the left. If it is shorter, it will be right-padded with zeros.

    Parameters:
    - x (ArrayLike): The array to be padded or truncated.
    - size (int): The target size for the array.

    Returns:
    - ArrayLike: The array adjusted to the specified size.
    """
    if x.shape[0] > size:
        # Left truncate if the context is too long.
        x = x[-size:]
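
Since the hunk cuts off before the padding branch, here is a self-contained sketch of the pad-or-truncate behaviour the docstring describes; the helper name is illustrative:

    import jax.numpy as jnp

    def pad_to_size_sketch(x, size):
        # Left-truncate when the context is too long ...
        if x.shape[0] > size:
            x = x[-size:]
        # ... and right-pad with zeros when it is too short.
        return jnp.pad(x, [(0, size - x.shape[0])])

    print(pad_to_size_sketch(jnp.arange(5), 3))  # [2 3 4]
    print(pad_to_size_sketch(jnp.arange(2), 4))  # [0 1 0 0]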
@@ -82,7 +129,20 @@ def pad_to_size(x, size):


def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:
    """Performs nucleus filtering on logits."""
    """
    Performs nucleus filtering on logits.
    Filters logits to retain only the subset that falls within the top-p cumulative probability. This
    nucleus filtering method focuses the sampling process on a set of plausible tokens, improving the
    quality of generated text.

    Parameters:
    - logits (jax.Array): The logits to filter.
    - top_p (jax.Array): The cumulative probability threshold for filtering logits.

    Returns:
    - jax.Array: The filtered logits with low-probability tokens set to -inf, effectively removing them
      from consideration during sampling.
    """
    assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}"
    sorted_logits = jax.lax.sort(logits, is_stable=False)
    sorted_probs = jax.nn.softmax(sorted_logits)
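
The hunk ends mid-function; a self-contained sketch of the nucleus cutoff the docstring describes (assuming a scalar top_p rather than the per-batch array used here) might read:

    import jax
    import jax.numpy as jnp

    def top_p_filter_sketch(logits: jax.Array, top_p: float) -> jax.Array:
        # Sort ascending; the nucleus is a suffix of this ordering.
        sorted_logits = jnp.sort(logits, axis=-1)
        cumsum = jnp.cumsum(jax.nn.softmax(sorted_logits, axis=-1), axis=-1)
        # A sorted position is inside the nucleus when its suffix mass is
        # still needed to reach top_p, i.e. when cumsum > 1 - top_p.
        inside = cumsum > 1.0 - top_p
        # The cutoff is the smallest logit inside the nucleus.
        threshold = jnp.min(jnp.where(inside, sorted_logits, jnp.inf),
                            axis=-1, keepdims=True)
        return jnp.where(logits >= threshold, logits, -jnp.inf)

    logits = jnp.log(jnp.array([[0.5, 0.3, 0.15, 0.05]]))
    print(top_p_filter_sketch(logits, 0.8))  # keeps the 0.5 and 0.3 tokens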
@@ -102,6 +162,21 @@ def sample_token(
    lm_outputs: LanguageModelOutput,
    settings: SampleSettings,
) -> SampleOutput:
    """
    Samples a token from the language model outputs using the specified settings, including temperature and
    nucleus probability filtering. This function also computes the probability of the sampled token and
    identifies the top-k tokens and their probabilities.

    Parameters:
    - rngs (jax.random.PRNGKey): The random number generator state for sampling.
    - lm_outputs (LanguageModelOutput): The outputs from the language model, including logits.
    - settings (SampleSettings): The settings controlling the sampling process, including temperature,
      nucleus probability, token masking, and the active batch elements indicator.

    Returns:
    - SampleOutput: The results of the sampling process, including the sampled token ID, its probability,
      and the top-k tokens and their probabilities for each batch element.
    """
    # Expand the settings shape to match the logit shape.
    settings = SampleSettings(
        temperature=jnp.expand_dims(settings.temperature, (1, 2)),  # Input [B], output [B, 1, 1].
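
A small assumed illustration of the broadcasting step above: expanding a per-batch setting from shape [B] to [B, 1, 1] lets it divide logits of shape [B, seq, vocab] elementwise:

    import jax.numpy as jnp

    logits = jnp.ones((2, 1, 8))          # [B, seq, vocab]
    temperature = jnp.array([0.7, 1.0])   # [B]
    scaled = logits / jnp.expand_dims(temperature, (1, 2))  # [B, 1, 1] broadcasts
    print(scaled.shape)                   # (2, 1, 8)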
@@ -135,6 +210,25 @@ def sample_token(

@dataclass
class ModelRunner:
    """
    Manages the execution of a language model, including initialization, forward passes, and state management.

    Attributes:
    - model (LanguageModelConfig): The configuration for the language model.
    - bs_per_device (float): The batch size per device.
    - load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming parameters during model loading.
    - load_exclude_rules (Optional[list[str]]): Rules for excluding parameters during model loading.
    - rng_seed (int): The initial random number generator seed.
    - transform_forward (bool): Flag that determines whether to transform the forward function with Haiku.
    - checkpoint_path (str): The path to the directory containing model checkpoints for loading.

    Methods:
    - make_forward_fn: Creates the forward function for the model, potentially transforming it with Haiku.
    - initialize: Initializes the model runner with data and mesh configuration, preparing it for execution.
    - init: Initializes the model parameters using provided data.
    - get_state_sharding: Determines the sharding configuration for model parameters based on the initialization data.
    - load_or_init: Loads the model from a checkpoint, or initializes it if a checkpoint is unavailable or not specified.
    """
    model: LanguageModelConfig

    bs_per_device: float = 2.0
@@ -148,6 +242,18 @@ class ModelRunner:
    checkpoint_path: str = ""

    def make_forward_fn(self, mesh: Any):
        """
        Creates and optionally transforms the forward function of the model. This method constructs a forward
        function that takes input tokens and produces model outputs. If `transform_forward` is True, the
        method also transforms this function with Haiku's transform to prepare it for JIT compilation
        and execution on the mesh.

        Parameters:
        mesh (Any): The mesh configuration used for distributing the computation across devices.

        Returns:
        Callable: The forward function, potentially transformed by Haiku, ready for execution.
        """
        def forward(tokens):
            out = self.model.make(mesh=mesh)(tokens)
            return out, None
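
As a hedged sketch of the Haiku pattern this docstring describes (the toy module below is a stand-in, not the grok-1 model), hk.transform turns a module-building function like `forward` into a pure (init, apply) pair:

    import haiku as hk
    import jax
    import jax.numpy as jnp

    def toy_forward(tokens):
        # Stand-in for self.model.make(mesh=mesh)(tokens): any Haiku
        # modules created here are captured by hk.transform.
        return hk.Embed(vocab_size=128, embed_dim=16)(tokens), None

    forward_fn = hk.transform(toy_forward)
    tokens = jnp.zeros((1, 8), dtype=jnp.int32)
    params = forward_fn.init(jax.random.PRNGKey(0), tokens)  # explicit init
    out, _ = forward_fn.apply(params, None, tokens)          # pure apply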
@@ -162,6 +268,20 @@ class ModelRunner:
        local_mesh_config: tuple[int, int],
        between_hosts_config: tuple[int, int],
    ):
        """
        Initializes the model runner with the necessary configuration and data. This includes setting up the
        mesh for distributed computation, calculating batch sizes based on the provided configuration, and
        preparing the forward function for execution.

        Parameters:
        init_data: Initial data used for model initialization. This is typically dummy data matching the
            expected input format and dimensions of the model.
        local_mesh_config (tuple[int, int]): The configuration of the local mesh, specifying how devices
            are organized locally for parallel computation.
        between_hosts_config (tuple[int, int]): The configuration for communication between hosts in a
            distributed setup, important for multi-host environments.

        """
        num_replicas = math.prod(between_hosts_config)
        self.model.initialize()
        self.model.fprop_dtype = jnp.bfloat16
@@ -191,12 +311,37 @@ class ModelRunner:
        self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding)

    def init(self, rng: jax.Array, data) -> TrainingState:
        """
        Initializes the model parameters using the provided data and random number generator state. This method
        is only called when `transform_forward` is True, indicating that the forward function has been
        transformed by Haiku and requires explicit initialization.

        Parameters:
        rng (jax.Array): The random number generator state for initializing model parameters.
        data: The initial data used for parameter initialization, usually including input tokens and
            possibly target tokens for supervised learning tasks.

        Returns:
        TrainingState: An instance of TrainingState containing the initialized model parameters.
        """
        assert self.transform_forward
        rng, init_rng = jax.random.split(rng)
        params = self.forward.init(init_rng, data["inputs"])
        return TrainingState(params=params)

    def get_state_sharding(self, init_data):
        """
        Determines the sharding configuration for model parameters based on the initialization data. This method
        evaluates the shapes of the model parameters during initialization and applies the partition rules
        specified in the model's configuration, enabling efficient distributed computation.

        Parameters:
        init_data: The initial data used for model initialization, which helps determine the appropriate
            sharding configuration based on model parameter shapes.

        Returns:
        A sharding configuration that specifies how model parameters should be partitioned across the mesh.
        """
        assert self.transform_forward
        rng = jax.random.PRNGKey(self.rng_seed)
        rank_logger.info(f"partition rules: {self.model.partition_rules}")
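
A hypothetical sketch of the shape-only evaluation such a method typically relies on (the toy model is assumed, not the repository's): jax.eval_shape traces init without allocating parameters, and the resulting shapes are what partition rules can be matched against:

    import haiku as hk
    import jax
    import jax.numpy as jnp

    forward_fn = hk.transform(lambda x: hk.Linear(4)(x))
    dummy = jnp.zeros((1, 8))
    # Trace init without allocating real parameters: the result mirrors
    # the parameter pytree with ShapeDtypeStruct leaves.
    shapes = jax.eval_shape(forward_fn.init, jax.random.PRNGKey(0), dummy)
    print(jax.tree_util.tree_map(lambda s: s.shape, shapes))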
@@ -215,6 +360,20 @@ class ModelRunner:
        from_checkpoint: bool = True,
        init_fn: Optional[Callable] = None,
    ):
        """
        Loads model parameters from a checkpoint, or initializes them if the checkpoint is not available. This
        method provides the flexibility to start the model from a known state or to freshly initialize
        parameters based on the model configuration and the provided initial data.

        Parameters:
        init_data: Initial data used for model parameter initialization if needed.
        from_checkpoint (bool): Flag indicating whether to attempt loading parameters from a checkpoint.
        init_fn (Optional[Callable]): An optional initialization function that overrides the default
            initialization method if provided.

        Returns:
        The model state, either loaded from a checkpoint or newly initialized.
        """
        rng = jax.random.PRNGKey(self.rng_seed)

        if not self.checkpoint_path or not from_checkpoint:
@@ -251,6 +410,16 @@ class ModelRunner:

@dataclass
class Request:
    """
    Data class for storing information about a single sampling request to the model.

    Attributes:
    prompt (str): The text prompt to generate text from.
    temperature (float): Controls the randomness of the output. Lower values result in more deterministic outputs.
    nucleus_p (float): The cumulative probability threshold for nucleus sampling, focusing generation on high-probability tokens.
    rng_seed (int): Seed for the random number generator, ensuring reproducibility of the results.
    max_len (int): The maximum length of the generated text sequence.
    """
    prompt: str
    temperature: float
    nucleus_p: float
@@ -260,6 +429,24 @@ class Request:

@dataclass
class InferenceRunner:
    """
    Manages inference operations for generating text from prompts, including initializing the model and
    tokenizer, prefilling the model memory, and running the actual text generation.

    Attributes:
    name (str): A name identifier for the inference runner.
    runner (Any): An instance of ModelRunner that manages the underlying language model.
    load (str): A path to the model checkpoint for loading.
    tokenizer_path (str): Path to the SentencePiece tokenizer model file.
    local_mesh_config (Tuple[int, int]): Configuration for the local device mesh.
    between_hosts_config (Tuple[int, int]): Configuration for communication between hosts in a distributed setup.
    pad_sizes (tuple[int]): A tuple of padding sizes for processing different lengths of prompts.

    Methods:
    get_pad_bucket: Determines the appropriate padding size for a given input length.
    initialize: Initializes the model runner, tokenizer, and pre-compiles model functions for different input sizes.
    run: A generator method that accepts prompts and returns generated text, managing batch slots and sampling steps.
    """
    name: str
    runner: Any
    load: str
@@ -269,10 +456,25 @@ class InferenceRunner:
    pad_sizes: tuple[int] = (1024,)

    def get_pad_bucket(self, size):
        """
        Determines the appropriate padding size for a given input length. This method uses the predefined
        pad sizes to find the smallest pad size that is greater than or equal to the given size.

        Parameters:
        size (int): The length of the input sequence to be padded.

        Returns:
        int: The selected pad size from the predefined pad sizes that best fits the input length.
        """
        i = bisect.bisect_left(self.pad_sizes, size)
        return self.pad_sizes[min(i, len(self.pad_sizes) - 1)]
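
The bucket selection above can be exercised standalone; a sketch with assumed pad sizes:

    import bisect

    pad_sizes = (128, 512, 1024)

    def get_pad_bucket_sketch(size):
        # Index of the first bucket >= size, clamped to the largest bucket.
        i = bisect.bisect_left(pad_sizes, size)
        return pad_sizes[min(i, len(pad_sizes) - 1)]

    print(get_pad_bucket_sketch(100))   # 128
    print(get_pad_bucket_sketch(513))   # 1024
    print(get_pad_bucket_sketch(4000))  # 1024 (inputs past the largest bucket clamp)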

    def initialize(self):
        """
        Initializes the inference runner by setting up the model runner, loading the tokenizer, pre-compiling
        model functions for different input sizes, and loading or initializing model parameters. This method
        ensures the inference runner is ready to generate text from prompts.
        """
        runner = self.runner
        self.runner.transform_forward = True
        dummy_data = dict(
@@ -440,7 +642,15 @@ class InferenceRunner:
        )

    def run(self):
        """Generator that accepts prompts."""
        """
        Generator method that accepts prompts and manages the text generation process. It handles
        tokenization of the prompts, sets up the sampling settings, and orchestrates the sampling steps
        that produce the generated text.

        Yields:
        str: The generated text for each prompt received. This method operates as a coroutine, yielding
            generated text back to the caller and awaiting new prompts for further generation.
        """
        runner = self.runner
        mesh = runner.mesh
        max_len = runner.model.sequence_len
@@ -580,6 +790,18 @@ class InferenceRunner:
def make_mesh(
    local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...]
) -> jax.sharding.Mesh:
    """
    Creates a JAX mesh from the provided configuration, used for parallel computation across devices.

    Parameters:
    local_mesh_config (tuple[int, ...]): A tuple specifying the local mesh configuration. This usually corresponds
        to how local devices are arranged and partitioned for computation.
    between_hosts_config (tuple[int, ...]): A tuple specifying the mesh configuration for communication between
        different hosts, relevant in multi-host distributed setups.

    Returns:
    jax.sharding.Mesh: A mesh object that encapsulates the device topology and partitioning for parallel computation.
    """
    assert len(local_mesh_config) == 2
    assert len(between_hosts_config) == 2
    rank_logger.info("Detected %s devices in mesh", jax.device_count())
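
As an assumed illustration of building such a mesh (the grid shape and axis names below are not the repository's actual choice):

    import jax
    from jax.experimental import mesh_utils

    # Arrange all visible devices into a 2-D grid.
    devices = mesh_utils.create_device_mesh((1, jax.device_count()))
    mesh = jax.sharding.Mesh(devices, axis_names=("data", "model"))
    print(mesh.shape)  # e.g. {'data': 1, 'model': 8}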
@@ -594,6 +816,32 @@ def make_mesh(


def sample_from_model(server, prompt, max_len, temperature):
    """
    Samples tokens from a trained model given a prompt and sampling settings. This function is designed
    to interact with a generator-based server that manages the state of the model and performs the actual
    sampling. The server is expected to yield control back to this function, which then sends a request
    object containing the prompt and settings, and waits for the generated output.

    The function initializes the server (if not already done), constructs the request with the provided
    parameters, and sends this request to the server. The sampled output, typically a continuation of the
    prompt, is then received from the server.

    Parameters:
    - server (generator): The generator-based server that handles model inference. This server is expected
      to manage the lifecycle of the model, including loading parameters, setting up the environment, and
      performing the token sampling.
    - prompt (str): The initial text prompt to which the model generates a continuation. This prompt is
      encoded and fed into the model as part of the request.
    - max_len (int): The maximum length of the generated sequence, including the length of the prompt.
      This controls how long the model's output can be.
    - temperature (float): A parameter controlling the randomness of the output. Lower values make the
      model's outputs more deterministic (and potentially more repetitive), while higher values increase
      diversity at the cost of coherence.

    Returns:
    - str: The generated text as a continuation of the provided prompt, up to the maximum length specified
      by `max_len`.
    """
    next(server)
    inp = Request(
        prompt=prompt,
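
Putting the Request fields and the generator protocol together, an assumed end-to-end sketch of the call this function makes (the nucleus_p and rng_seed values are invented for illustration):

    def sample_from_model_sketch(server, prompt, max_len, temperature):
        # Prime the generator so it runs up to its first `yield`, then
        # trade a Request for the generated text via send().
        next(server)
        request = Request(
            prompt=prompt,
            temperature=temperature,
            nucleus_p=0.95,   # assumed default, not from the diff
            rng_seed=42,      # assumed seed, not from the diff
            max_len=max_len,
        )
        return server.send(request)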