diff --git a/runners.py b/runners.py
index 452c142..e537c40 100644
--- a/runners.py
+++ b/runners.py
@@ -48,6 +48,20 @@ TOP_K = 8
 
 
 class SampleSettings(NamedTuple):
+    """
+    A NamedTuple for storing settings used during the sampling process.
+
+    Attributes:
+    - temperature (ArrayLike): The temperature controls the randomness of the output. Lower values make
+      the model outputs more deterministic, while higher values increase diversity.
+    - nucleus_p (ArrayLike): The nucleus probability: the cumulative probability threshold used to
+      truncate low-probability tokens, so that sampling focuses on the smallest set of tokens whose
+      probabilities sum to at least this value.
+    - mask (ArrayLike): A binary mask indicating which tokens are allowed (1) and which are disallowed (0)
+      for sampling in each batch element.
+    - active (ArrayLike): A binary indicator for whether a given batch element is actively being used for
+      generation. Elements not in use are ignored in computations to save resources.
+    """
     temperature: ArrayLike
     nucleus_p: ArrayLike
     mask: ArrayLike
@@ -56,6 +70,15 @@ class SampleSettings(NamedTuple):
 
 
 class SampleOutput(NamedTuple):
+    """
+    A NamedTuple for storing the output of the sampling process.
+
+    Attributes:
+    - token_id (ArrayLike): The sampled token ID for each batch element.
+    - prob (ArrayLike): The probability associated with the sampled token for each batch element.
+    - top_k_token_ids (ArrayLike): The IDs of the top-k most probable tokens for each batch element.
+    - top_k_probs (ArrayLike): The probabilities associated with the top-k token IDs for each batch element.
+    """
    token_id: ArrayLike
    prob: ArrayLike
    top_k_token_ids: ArrayLike
@@ -63,6 +86,19 @@ class SampleOutput(NamedTuple):
 
 
 def insert_slice(memory: Memory, slice, length, i):
+    """
+    Inserts a slice of KVMemory into a Memory object at a specified batch index. This function rebuilds
+    the given slice so that every layer's step is set to the provided length before insertion.
+
+    Parameters:
+    - memory (Memory): The original memory object to be updated.
+    - slice: The Memory slice to insert, typically holding the KV state of a freshly processed prompt.
+    - length (int): The step length to set for each KVMemory layer in the inserted slice.
+    - i (int): The batch index at which the slice is to be inserted into the memory.
+
+    Returns:
+    - Memory: The updated memory object with the new slice inserted at the specified index.
+    """
     slice = Memory(
         layers=[
             KVMemory(layer.k, layer.v, step=jnp.array([length]))
@@ -75,6 +111,17 @@ def insert_slice(memory: Memory, slice, length, i):
 
 
 def pad_to_size(x, size):
+    """
+    Pads or truncates an array to a specified size. If the array is longer than the target size, it is
+    truncated from the left; if it is shorter, it is right-padded with zeros.
+
+    Parameters:
+    - x (ArrayLike): The array to be padded or truncated.
+    - size (int): The target size for the array.
+
+    Returns:
+    - ArrayLike: The array adjusted to the specified size.
+    """
     if x.shape[0] > size:
         # Left truncate if the context is too long.
         x = x[-size:]
@@ -82,7 +129,20 @@ def pad_to_size(x, size):
 
 
 def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:
-    """Performs nucleus filtering on logits."""
+    """
+    Performs nucleus filtering on logits.
+
+    Filters logits to retain only the subset of tokens that makes up the top-p cumulative probability
+    mass. This nucleus filtering method focuses the sampling process on a subset of plausible tokens,
+    improving the quality of the generated text.
+
+    Parameters:
+    - logits (jax.Array): The logits to filter.
+    - top_p (jax.Array): The cumulative probability threshold for filtering logits.
+
+    Returns:
+    - jax.Array: The filtered logits with low-probability tokens set to -inf, effectively removing them
+      from consideration during sampling.
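+
+    Example (illustrative sketch, not part of the original change; the logit values and the
+    4-token vocabulary are assumed):
+        logits = jnp.array([[1.0, 2.0, 3.0, 4.0]])  # [batch, vocab]
+        top_p = jnp.array([[0.9]])                  # must match logits.ndim
+        filtered = top_p_filter(logits, top_p)
+        # Entries outside the top-0.9 probability mass are pushed to a large negative
+        # value, so later sampling steps can never select them.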
+    """
     assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}"
     sorted_logits = jax.lax.sort(logits, is_stable=False)
     sorted_probs = jax.nn.softmax(sorted_logits)
@@ -102,6 +162,21 @@ def sample_token(
     lm_outputs: LanguageModelOutput,
     settings: SampleSettings,
 ) -> SampleOutput:
+    """
+    Samples a token from the language model outputs using the specified settings, including temperature
+    and nucleus probability filtering. This function also computes the probability of the sampled token
+    and identifies the top-k tokens and their probabilities.
+
+    Parameters:
+    - rngs (jax.random.PRNGKey): The random number generator state for sampling.
+    - lm_outputs (LanguageModelOutput): The outputs from the language model, including logits.
+    - settings (SampleSettings): The settings controlling the sampling process, including temperature,
+      nucleus probability, token masking, and the active batch elements indicator.
+
+    Returns:
+    - SampleOutput: The results of the sampling process, including the sampled token ID, its probability,
+      and the top-k tokens and their probabilities for each batch element.
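+
+    Example (hypothetical call; `rngs` and `lm_outputs` are assumed to come from the surrounding
+    generation loop, and `vocab_size` stands in for the real vocabulary size):
+        settings = SampleSettings(
+            temperature=jnp.array([0.7, 0.7]),                # one value per batch element
+            nucleus_p=jnp.array([0.95, 0.95]),
+            mask=jnp.ones((2, vocab_size), dtype=jnp.int32),  # allow every token
+            active=jnp.ones((2,), dtype=jnp.int32),           # both batch slots generate
+        )
+        out = sample_token(rngs, lm_outputs, settings)
+        # out.token_id holds one sampled token per batch element.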
+    """
     # Expand the settings shape to match the logit shape.
     settings = SampleSettings(
         temperature=jnp.expand_dims(settings.temperature, (1, 2)),  # Input [B], output [B, 1, 1].
@@ -135,6 +210,25 @@
 
 @dataclass
 class ModelRunner:
+    """
+    Manages the execution of a language model, including initialization, forward passes, and state
+    management.
+
+    Attributes:
+    - model (LanguageModelConfig): The configuration for the language model.
+    - bs_per_device (float): The batch size per device.
+    - load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming parameters during model loading.
+    - load_exclude_rules (Optional[list[str]]): Rules for excluding parameters during model loading.
+    - rng_seed (int): The initial random number generator seed.
+    - transform_forward (bool): Flag determining whether to transform the forward function with Haiku.
+    - checkpoint_path (str): The path to the directory containing model checkpoints for loading.
+
+    Methods:
+    - make_forward_fn: Creates the forward function for the model, potentially transforming it with Haiku.
+    - initialize: Initializes the model runner with data and mesh configuration, preparing it for execution.
+    - init: Initializes the model parameters using provided data.
+    - get_state_sharding: Determines the sharding configuration for model parameters from the initialization data.
+    - load_or_init: Loads the model from a checkpoint, or initializes it if no checkpoint is available or specified.
+    """
     model: LanguageModelConfig
 
     bs_per_device: float = 2.0
@@ -148,6 +242,18 @@ class ModelRunner:
     checkpoint_path: str = ""
 
     def make_forward_fn(self, mesh: Any):
+        """
+        Creates and optionally transforms the forward function of the model. This method constructs a
+        forward function that takes input tokens and produces model outputs. If `transform_forward` is
+        set to True, the method also transforms this function with Haiku's transform to prepare it for
+        JIT compilation and execution on the mesh.
+
+        Parameters:
+        mesh (Any): The mesh configuration to be used for distributing the computation across devices.
+
+        Returns:
+        Callable: The forward function, potentially transformed by Haiku, ready for execution.
+        """
         def forward(tokens):
             out = self.model.make(mesh=mesh)(tokens)
             return out, None
@@ -162,6 +268,20 @@ class ModelRunner:
         local_mesh_config: tuple[int, int],
         between_hosts_config: tuple[int, int],
     ):
+        """
+        Initializes the model runner with the necessary configuration and data. This includes setting up
+        the mesh for distributed computation, calculating batch sizes from the provided configuration,
+        and preparing the forward function for execution.
+
+        Parameters:
+        init_data: Initial data used for model initialization. This is typically dummy data matching the
+        expected input format and dimensions of the model.
+        local_mesh_config (tuple[int, int]): The configuration of the local mesh, specifying how devices
+        are organized locally for parallel computation.
+        between_hosts_config (tuple[int, int]): The configuration for communication between hosts in a
+        distributed setup, important for multi-host environments.
+
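+        Example (hypothetical single-host setup; the (1, 8) local mesh and the (1, 1)
+        between-hosts shape are assumptions, not values taken from this change):
+            runner.initialize(
+                init_data=dummy_data,
+                local_mesh_config=(1, 8),      # one row of eight local devices
+                between_hosts_config=(1, 1),   # a single host
+            )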
+        """
         num_replicas = math.prod(between_hosts_config)
         self.model.initialize()
         self.model.fprop_dtype = jnp.bfloat16
@@ -191,12 +311,37 @@ class ModelRunner:
         self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding)
 
     def init(self, rng: jax.Array, data) -> TrainingState:
+        """
+        Initializes the model parameters using the provided data and a random number generator state.
+        This method is only called when `transform_forward` is True, indicating that the forward function
+        has been transformed by Haiku and requires explicit initialization.
+
+        Parameters:
+        rng (jax.Array): The random number generator state for initializing model parameters.
+        data: The initial data used for parameter initialization, usually including input tokens and
+        possibly target tokens for supervised learning tasks.
+
+        Returns:
+        TrainingState: An instance of TrainingState containing the initialized model parameters.
+        """
         assert self.transform_forward
         rng, init_rng = jax.random.split(rng)
         params = self.forward.init(init_rng, data["inputs"])
         return TrainingState(params=params)
 
     def get_state_sharding(self, init_data):
+        """
+        Determines the sharding configuration for model parameters from the initialization data. This
+        method evaluates the shapes of the model parameters during initialization and applies the
+        partition rules specified in the model's configuration, facilitating efficient distributed
+        computation.
+
+        Parameters:
+        init_data: The initial data used for model initialization, which helps determine the appropriate
+        sharding configuration based on model parameter shapes.
+
+        Returns:
+        A sharding configuration that specifies how model parameters should be partitioned across the mesh.
+        """
         assert self.transform_forward
         rng = jax.random.PRNGKey(self.rng_seed)
         rank_logger.info(f"partition rules: {self.model.partition_rules}")
@@ -215,6 +360,20 @@ class ModelRunner:
         from_checkpoint: bool = True,
         init_fn: Optional[Callable] = None,
     ):
+        """
+        Loads model parameters from a checkpoint, or initializes them if no checkpoint is available. This
+        method provides the flexibility to start the model from a known state or to freshly initialize
+        parameters based on the model configuration and the provided initial data.
+
+        Parameters:
+        init_data: Initial data used for model parameter initialization if needed.
+        from_checkpoint (bool): Flag indicating whether to attempt loading parameters from a checkpoint.
+        init_fn (Optional[Callable]): An optional initialization function that overrides the default
+        initialization method if provided.
+
+        Returns:
+        The model state, either loaded from a checkpoint or newly initialized.
+        """
         rng = jax.random.PRNGKey(self.rng_seed)
 
         if not self.checkpoint_path or not from_checkpoint:
@@ -251,6 +410,16 @@ class ModelRunner:
 
 @dataclass
 class Request:
+    """
+    Data class for storing information about a single sampling request to the model.
+
+    Attributes:
+    prompt (str): The text prompt to generate text from.
+    temperature (float): Controls the randomness of the output. Lower values result in more deterministic outputs.
+    nucleus_p (float): The cumulative probability threshold for nucleus sampling, focusing generation on high-probability tokens.
+    rng_seed (int): Seed for the random number generator to ensure reproducibility of the results.
+    max_len (int): The maximum length of the generated text sequence.
+    """
     prompt: str
     temperature: float
     nucleus_p: float
@@ -260,6 +429,24 @@ class Request:
 
 @dataclass
 class InferenceRunner:
+    """
+    Manages inference operations for generating text from prompts, including initializing the model and
+    tokenizer, prefilling the model memory, and running the actual text generation.
+
+    Attributes:
+    name (str): A name identifier for the inference runner.
+    runner (Any): An instance of ModelRunner that manages the underlying language model.
+    load (str): A path to the model checkpoint for loading.
+    tokenizer_path (str): Path to the SentencePiece tokenizer model file.
+    local_mesh_config (tuple[int, int]): Configuration for the local device mesh.
+    between_hosts_config (tuple[int, int]): Configuration for communication between hosts in a distributed setup.
+    pad_sizes (tuple[int]): A tuple of padding sizes for processing different prompt lengths.
+
+    Methods:
+    get_pad_bucket: Determines the appropriate padding size for a given input length.
+    initialize: Initializes the model runner and tokenizer, and pre-compiles model functions for each pad size.
+    run: A generator method that accepts prompts and returns generated text, managing batch slots and sampling steps.
+    """
     name: str
     runner: Any
     load: str
@@ -269,10 +456,25 @@ class InferenceRunner:
     pad_sizes: tuple[int] = (1024,)
 
     def get_pad_bucket(self, size):
+        """
+        Determines the appropriate padding size for a given input length. This method searches the
+        predefined pad sizes for the smallest one that is greater than or equal to the given size,
+        falling back to the largest pad size if the input exceeds them all.
+
+        Parameters:
+        size (int): The length of the input sequence to be padded.
+
+        Returns:
+        int: The selected pad size from the predefined pad sizes that best fits the input length.
+        """
         i = bisect.bisect_left(self.pad_sizes, size)
         return self.pad_sizes[min(i, len(self.pad_sizes) - 1)]
 
     def initialize(self):
+        """
+        Initializes the inference runner by setting up the model runner, loading the tokenizer,
+        pre-compiling model functions for the different input sizes, and loading or initializing model
+        parameters. After this method completes, the runner is ready to generate text from prompts.
+        """
         runner = self.runner
         self.runner.transform_forward = True
         dummy_data = dict(
@@ -440,7 +642,15 @@ class InferenceRunner:
         )
 
     def run(self):
-        """Generator that accepts prompts."""
+        """
+        Generator method that accepts prompts and manages the text generation process. It handles
+        tokenization of the prompts, sets up the sampling settings, and orchestrates the sampling steps
+        that produce the generated text.
+
+        Yields:
+        str: The generated text for each prompt received. This method operates as a coroutine, yielding
+        generated text back to the caller and awaiting new prompts for further generation.
+        """
         runner = self.runner
         mesh = runner.mesh
         max_len = runner.model.sequence_len
@@ -580,6 +790,18 @@
 def make_mesh(
     local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...]
 ) -> jax.sharding.Mesh:
+    """
+    Creates a JAX mesh from the provided configuration, used for parallel computation across devices.
+
+    Parameters:
+    local_mesh_config (tuple[int, ...]): A tuple specifying the local mesh configuration. This usually
+    corresponds to how local devices are arranged and partitioned for computation.
+    between_hosts_config (tuple[int, ...]): A tuple specifying the mesh configuration for communication
+    between different hosts, relevant in multi-host distributed setups.
+
+    Returns:
+    jax.sharding.Mesh: A mesh object that encapsulates the device topology and partitioning for parallel
+    computation.
+    """
     assert len(local_mesh_config) == 2
     assert len(between_hosts_config) == 2
     rank_logger.info("Detected %s devices in mesh", jax.device_count())
@@ -594,6 +816,32 @@
 
 
 def sample_from_model(server, prompt, max_len, temperature):
+    """
+    Samples tokens from a trained model given a prompt and sampling settings. This function interacts
+    with a generator-based server that manages the model state and performs the actual sampling: the
+    server yields control back to this function, which sends a request object containing the prompt and
+    settings and then waits for the generated output.
+
+    The function advances the server to its first yield point, constructs the request from the provided
+    parameters, and sends the request to the server. The sampled output, a continuation of the prompt,
+    is then received from the server.
+
+    Parameters:
+    - server (generator): The generator-based server that handles model inference. The server is expected
+      to manage the lifecycle of the model, including loading parameters, setting up the environment, and
+      performing the token sampling.
+    - prompt (str): The initial text prompt to which the model generates a continuation. This prompt is
+      encoded and fed into the model as part of the request.
+    - max_len (int): The maximum length of the generated sequence, including the length of the prompt.
+      This controls how long the model's output can be.
+    - temperature (float): A parameter controlling the randomness of the output. Lower values make the
+      model's outputs more deterministic (and potentially more repetitive), while higher values increase
+      diversity at the cost of coherence.
+
+    Returns:
+    - str: The generated text as a continuation of the provided prompt, up to the maximum length
+      specified by `max_len`.
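+
+    Example (hypothetical usage; assumes `inference_runner` is an InferenceRunner that has
+    already been initialized, as in the accompanying run script):
+        gen = inference_runner.run()
+        output = sample_from_model(gen, "The answer to life is", max_len=64, temperature=0.7)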
+    """
     next(server)
     inp = Request(
         prompt=prompt,