diff --git a/checkpoint.py b/checkpoint.py index 1c6e878..4288f31 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -41,6 +41,17 @@ sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit @contextlib.contextmanager def copy_to_shm(file: str): + """ + Context manager for copying a file to shared memory. If the file is already in shared memory (/dev/shm), + yields the same file path. Otherwise, copies the file to a temporary file in shared memory, yields the path + to the temporary file, and cleans up by removing the temporary file after use. + + Parameters: + - file (str): The path to the file to be copied. + + Yields: + - str: The path to the file in shared memory. + """ if file.startswith("/dev/shm/"): # Nothing to do, the file is already in shared memory. yield file @@ -58,6 +69,17 @@ def copy_to_shm(file: str): @contextlib.contextmanager def copy_from_shm(file: str): + """ + Context manager for copying a file from shared memory to a specified path. It creates a temporary file + in shared memory, yields the path to this temporary file for operations, and then copies the temporary + file to the specified path, cleaning up the temporary file afterwards. + + Parameters: + - file (str): The target path for the file to be copied to from shared memory. + + Yields: + - str: The path to the temporary file in shared memory. + """ tmp_dir = "/dev/shm/" fd, tmp_path = tempfile.mkstemp(dir=tmp_dir) try: @@ -69,19 +91,48 @@ def copy_from_shm(file: str): def fast_unpickle(path: str) -> Any: + """ + Unpickles and loads an object from a file, optionally using shared memory for faster access. + + Parameters: + - path (str): The path to the pickle file to load. + + Returns: + - Any: The object loaded from the pickle file. + """ with copy_to_shm(path) as tmp_path: with open(tmp_path, "rb") as f: return pickle.load(f) def fast_pickle(obj: Any, path: str) -> None: + """ + Pickles and saves an object to a file, optionally using shared memory for faster access. + + Parameters: + - obj (Any): The object to be pickled and saved. + - path (str): The path where the pickle file should be saved. + """ with copy_from_shm(path) as tmp_path: with open(tmp_path, "wb") as f: pickle.dump(obj, f) def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None): - """Loads a set of arrays.""" + """ + Loads tensors from files in parallel using a ThreadPoolExecutor. This function is intended for use with + arrays that have a predefined shape and dtype, loading them from a directory where each tensor is saved + in a separate file. + + Parameters: + - shaped_arrays: A sequence of arrays providing the shapes and dtypes of the tensors to load. + - directory (str): The directory from which to load the tensors. + - mesh_config: Configuration for data parallelism across processes or devices. + - tensor_indices (Optional): Specific indices of tensors to load. If None, all tensors are loaded. + + Returns: + - List of numpy arrays or zeros depending on the process's role in data parallelism. + """ pool = ThreadPoolExecutor(max_workers=32) fs = list() num_tensors = 0 @@ -108,6 +159,17 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None): def path_tuple_to_string(path: tuple) -> str: + """ + Converts a tuple representing a path in a nested structure to a string. This function is specifically + used for handling paths within the structure of saved states or parameters, making it easier to + identify and manipulate specific elements. + + Parameters: + - path (tuple): A tuple representing a path in a nested structure. + + Returns: + - str: A string representation of the path, suitable for logging or identification. + """ pieces = [] for elem in path: if isinstance(elem, jax.tree_util.DictKey): @@ -124,6 +186,19 @@ def get_load_path_str( load_rename_rules: Optional[list[tuple[str, str]]] = None, load_exclude_rules: Optional[list[str]] = None, ) -> Optional[str]: + """ + Determines the load path for a given initial path string, applying exclusion and renaming rules. This + function is used in the context of loading states or parameters from files, allowing for flexible + mapping or exclusion based on pattern matching. + + Parameters: + - init_path_str (str): The initial path string to process. + - load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming paths based on pattern matching. + - load_exclude_rules (Optional[list[str]]): Patterns for paths to exclude from loading. + + Returns: + - Optional[str]: The processed load path string, or None if excluded. + """ # Exclusion if load_exclude_rules is not None: for search_pattern in load_exclude_rules: @@ -148,6 +223,21 @@ def replace_with_load_state( load_exclude_rules: Optional[list[str]] = None, mesh_config: tuple = (1, 1), ) -> Any: + """ + Replaces elements of an initial state with elements from a loaded state, applying renaming and exclusion + rules. This function supports conditional inclusion and transformation of state elements based on complex + criteria, facilitating flexible state restoration. + + Parameters: + - init_state (Any): The initial state before replacement. + - load_state (Any): The state from which to load replacements. + - load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming paths. + - load_exclude_rules (Optional[list[str]]): Patterns for paths to exclude from loading. + - mesh_config (tuple): Configuration for data parallelism. + + Returns: + - Any: The initial state with elements replaced from the load state. + """ flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state) flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state) load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load} @@ -186,6 +276,23 @@ def restore( state_sharding, init_state: Optional[Any] = None, ) -> Any: + """ + Restores the state from a checkpoint, optionally focusing on parameters only, and applies sharding + configurations. This function is designed for restoring model states from disk with support for distributed + environments, handling the intricacies of partitioning and host-specific configurations. + + Parameters: + - checkpoint_path (str): The path to the checkpoint directory. + - state_shapes (Any): The expected shapes of the state to restore. + - mesh: The mesh configuration for distributed environments. + - between_hosts_config: Configuration for data exchange between hosts. + - params_only (bool): Whether to restore parameters only, excluding other state parts. + - state_sharding: Sharding configuration for the state. + - init_state (Optional[Any]): The initial state to which the checkpoint data is applied. + + Returns: + - Any: The restored state, potentially sharded across the distributed environment. + """ ckpt_path = os.path.join(checkpoint_path, "ckpt-0") rank_logger.info("Loading checkpoint at {}".format(ckpt_path)) diff --git a/model.py b/model.py index 87d700d..48ffcf2 100644 --- a/model.py +++ b/model.py @@ -12,6 +12,70 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Grok-1 + +This architecture is designed for language modeling and autoregressive sequence generation, incorporating advanced features such as: + +- 8-bit Quantized Weights: Implements quantization for model parameters to reduce the memory footprint and potentially increase computational efficiency. +- Sharding and Distributed Computation: Utilizes JAX's capabilities for distributed computation across devices, optimizing parallel processing and memory usage. +- Memory Caching for Autoregressive Decoding: Features a mechanism for caching keys and values in attention layers, enhancing efficiency in sequence generation tasks. + +Core Components: +- Multi-Head Attention (MHA): Custom implementation, including rotary positional embeddings for implicit sequence position information. +- Mixture of Experts (MoE): Allows routing inputs to different "expert" networks based on the input data, increasing model capacity and expressiveness. +- Feed-Forward Networks (FFNs): Defines networks with adjustable sizes and support for quantized weights and custom linear layers with sharding. + +Training and Inference Workflow: +- Manages the model's parameters, caching mechanisms, and layer-wise states efficiently during both training and inference. +- Implements detailed sharding strategies for optimizing the distribution of computation and storage across multiple devices. + +Advanced Features: +- Custom Layer Normalization: Adopts RMSNorm for stabilizing the training of deep networks. +- Dynamic and Static Sharding: Offers flexibility in data and model parallelism, allowing dynamic adjustment of sharding constraints. + +Efficiency and Scalability: +- Efficiently manages data flow and computation, minimizing unnecessary data replication and movement across devices. +- Designed with scalability in mind, providing a foundation for training complex models on massive datasets. + +This architecture leverages JAX for high-performance numerical computing and automatic differentiation, alongside Haiku for modular and flexible deep learning model construction. + +Data Flow Through the Training Process: + +1. Input Preparation: + - The process begins with the preparation of input data, which typically involves tokenizing text data into numerical tokens that represent words or subwords in a vocabulary. + - Tokens are then batched and padded to ensure consistent sequence lengths within each batch, forming the initial input tensor for the model. + +2. Embedding Layer: + - The input tokens are passed through an embedding layer, transforming each token into a high-dimensional vector. This layer may utilize pre-trained embeddings or learn embeddings during training. + - Positional encodings or embeddings are added to these vectors to incorporate sequence position information. + +3. Transformer Layers: + - The sequence of embedding vectors is processed through multiple Transformer layers, each consisting of the following sub-layers: + a. Multi-Head Attention (MHA): Each layer computes self-attention for its input, allowing each position in the sequence to attend to all positions in the previous layer’s output. + b. Feed-Forward Network (FFN): After attention, the output for each position passes through a feed-forward network. FFNs are identical for different positions but have different parameters from layer to layer. + + - Between each sub-layer, residual connections followed by layer normalization are applied. This helps in stabilizing the training of deep networks. + +4. Caching and Memory: + - For autoregressive tasks, where the model generates sequences one token at a time, keys and values computed during the attention operations are cached. This mechanism allows reusing these computations in subsequent steps, reducing the computational load. + +5. Output Layer: + - The output from the final Transformer layer is passed through a linear layer or a decoder, transforming the high-dimensional representations back into the vocabulary space to produce logits for each token in the sequence. + +6. Loss Calculation and Backpropagation: + - The logits are compared against the true token sequences using a suitable loss function (e.g., cross-entropy for language modeling tasks). The loss quantifies the model's prediction accuracy. + - Based on the loss, gradients are computed for each parameter in the model using backpropagation. These gradients indicate how each parameter should be adjusted to minimize the loss. + +7. Parameter Update: + - Model parameters are updated using an optimization algorithm (e.g., Adam, SGD) and the computed gradients. This step adjusts the model’s weights to improve its predictions on the next iteration. + +8. Iteration and Convergence: + - Steps 1 through 7 are repeated for multiple epochs over the training dataset until the model's performance on a validation set converges or begins to degrade, indicating potential overfitting. + +This structured approach, combined with the Transformer architecture's capability to model complex dependencies in data, enables effective training of models for tasks such as language understanding, translation, and generation. +""" + import functools import logging import re @@ -36,6 +100,13 @@ rank_logger = logging.getLogger("rank") @dataclass class QuantizedWeight8bit: + """ + Represents an 8-bit quantized weight for neural network parameters. + + Attributes: + weight (jnp.array): The quantized weights. + scales (jnp.array): The scale factors used for quantization. + """ weight: jnp.array scales: jnp.array @@ -52,13 +123,28 @@ tree_util.register_pytree_node( class TrainingState(NamedTuple): - """Container for the training state.""" + """Container for the training state, encapsulating model parameters. + + Attributes: + params (hk.Params): The parameters of the model. + """ params: hk.Params def _match(qs, ks): - """Return True if regexes in qs match any window of strings in tuple ks.""" + """ + Determines if a sequence of regex patterns (qs) matches any contiguous subsequence of strings (ks). + This utility function is often used for matching parameter names or paths in a hierarchical structure. + + Args: + qs (Sequence[str]): A sequence of regex patterns to match. + ks (Tuple[str, ...]): A tuple of strings against which the patterns are matched. + + Returns: + bool: True if every pattern in qs has a corresponding match in a contiguous subsequence of ks, + otherwise False. + """ # compile regexes and force complete match qts = tuple(map(lambda x: re.compile(x + "$"), qs)) for i in range(len(ks) - len(qs) + 1): @@ -69,6 +155,19 @@ def _match(qs, ks): def with_sharding_constraint(x, constraint): + """ + Applies a sharding constraint to a JAX array. This function is used in SPMD programs to hint how the + data should be partitioned across devices. If a physical mesh is not available, it simply returns the + original array. + + Args: + x (jax.Array): The array to apply the sharding constraint to. + constraint (PartitionSpec): The sharding constraint to apply. + + Returns: + jax.Array: The array with the sharding constraint applied, affecting its distribution across devices + in distributed computation setups. + """ if jax.experimental.maps.thread_resources.env.physical_mesh.empty: return x else: @@ -76,6 +175,17 @@ def with_sharding_constraint(x, constraint): def cast_bfloat16(x): + """ + Casts the input array to bfloat16 type if it is of floating-point type. This operation is often used to + reduce memory consumption and potentially increase computation speed by using lower precision. + + Args: + x (jax.Array): The input array. + + Returns: + jax.Array: The array cast to bfloat16 if the original array was floating-point; otherwise, the array + is returned unchanged. + """ if x.dtype.kind == "f": return x.astype(jnp.bfloat16) else: @@ -83,6 +193,18 @@ def cast_bfloat16(x): def ffn_size(emb_size, widening_factor): + """ + Calculates the size of the feed-forward network (FFN) based on the embedding size and a widening factor. + + The calculated FFN size is adjusted to be a multiple of 8 for efficiency in hardware implementations. + + Args: + emb_size (int): The size of the embeddings. + widening_factor (float): The factor by which to widen the FFN relative to the embedding size. + + Returns: + int: The adjusted size of the FFN. + """ _ffn_size = int(widening_factor * emb_size) * 2 // 3 _ffn_size = _ffn_size + (8 - _ffn_size) % 8 # ensure it's a multiple of 8 logger.debug(f"emd_size: {emb_size} adjusted ffn_size: {_ffn_size}") @@ -90,6 +212,20 @@ def ffn_size(emb_size, widening_factor): def apply_rules(rules): + """ + Constructs a function to apply a set of sharding rules for transformer parameters. + + This function is used to determine the sharding specifications for model parameters based on their roles + and positions within the model architecture. + + Args: + rules (List[Tuple[Sequence[str], PartitionSpec]]): A list of tuples where each tuple contains a sequence + of strings representing the parameter path and the corresponding `PartitionSpec` to apply. + + Returns: + Callable: A function that takes a parameter path and returns the appropriate `PartitionSpec` based + on the provided rules. + """ def _apply_rules(path, value): del value # Unused. @@ -176,6 +312,14 @@ TOP_K = 8 class KVMemory(NamedTuple): + """ + Represents key-value memory slots for a transformer layer, supporting efficient autoregressive decoding by caching past computed keys and values. + + Attributes: + k (Optional[jax.Array]): Cached keys, shaped as [batch_size, sequence_len, num_kv_heads, key_size]. + v (Optional[jax.Array]): Cached values, shaped as [batch_size, sequence_len, num_kv_heads, key_size]. + step (Optional[jax.Array]): The current decoding step, indicating how many positions have been generated, shaped as [batch_size]. + """ k: Optional[jax.Array] v: Optional[jax.Array] step: Optional[jax.Array] @@ -190,6 +334,21 @@ def init_layer_memories( step: Optional[jax.Array] = None, dtype=jnp.bfloat16, ): + """ + Initializes layer memories for each transformer layer, providing a mechanism for efficient sequence generation by caching keys and values. + + Args: + batch_size (int): The number of sequences being processed in parallel. + sequence_len (int): The length of the sequences for which memory is allocated. + num_kv_heads (int): The number of key-value pairs per head in the attention mechanism. + key_size (int): The size of each key (and value) in the attention mechanism. + num_layers (int): The number of transformer layers for which memory is initialized. + step (Optional[jax.Array]): The initial decoding step for each sequence in the batch. Defaults to None, indicating no prior steps. + dtype (Any): The data type for the memory arrays, typically jnp.bfloat16 for efficiency. + + Returns: + List[KVMemory]: A list of initialized KVMemory instances for each layer in the transformer model. + """ return [ KVMemory( k=jnp.zeros((batch_size, sequence_len, num_kv_heads, key_size), dtype=dtype), @@ -201,11 +360,28 @@ def init_layer_memories( class Memory(NamedTuple): + """ + A named tuple representing the complete memory state of a transformer model, encapsulating key-value memory slots for all layers. + + Attributes: + layers (List[KVMemory]): A list of KVMemory instances, one for each layer in the transformer model. + """ # Self-attention key/value cache. layers: List[KVMemory] class Router(hk.Module): + """ + A module for routing inputs to experts in a Mixture of Experts (MoE) layer. + + Attributes: + num_selected_experts (int): Number of experts to select for each input. + data_axis (str | Tuple[str, ...]): The name(s) of the data axis for sharding. + model_axis (str | Tuple[str, ...]): The name(s) of the model axis for sharding. + shard_activations (bool): If True, shard activations according to the data and model axes. + mesh (Any): The SPMD mesh for parallel computation. + name (str): The name of the module. + """ def __init__( self, num_selected_experts: int, @@ -215,6 +391,17 @@ class Router(hk.Module): mesh: Any = None, name: str = "router", ): + """ + Initializes a router for directing inputs to experts in a Mixture of Experts (MoE) layer. + + Args: + num_selected_experts (int): The number of experts to select for each input. + data_axis (Union[str, Tuple[str, ...]]): The axis names over which data is sharded. + model_axis (Union[str, Tuple[str, ...]]): The axis names over which model parameters are sharded. + shard_activations (bool): Indicates whether to shard activations according to the data and model axes. + mesh (Any): The SPMD mesh object for parallel computation. + name (str): An optional name for the module. + """ super().__init__(name) self.shard_activations = shard_activations self.data_axis = data_axis @@ -225,6 +412,20 @@ class Router(hk.Module): def compute_routing_prob( self, inputs: jax.Array, padding_mask: Optional[jax.Array], num_experts: int ): + """ + Computes the routing probabilities for each input to be directed to each expert. + + This internal method calculates the logits that determine how inputs are distributed among the experts, + based on learned criteria. + + Args: + inputs (jax.Array): Input data to be routed. + padding_mask (Optional[jax.Array]): Mask indicating which inputs are padding and should not be considered. + num_experts (int): The total number of experts available. + + Returns: + A tuple containing the routing probabilities, logits, and a dummy placeholder for compatibility. + """ return self._compute_routing_prob(inputs, padding_mask, num_experts) @hk.transparent @@ -234,6 +435,19 @@ class Router(hk.Module): padding_mask: Optional[jax.Array], num_experts: int, ): + """ + Computes the routing probabilities for directing inputs to the appropriate experts. + + Args: + inputs (jax.Array): Input data to be routed, shaped as [batch_size, ..., input_dim]. + padding_mask (Optional[jax.Array]): An optional mask indicating padded elements in the input, + shaped as [batch_size, seq_length], where padded positions are False. + num_experts (int): The total number of experts available for routing. + + Returns: + A tuple containing routing probabilities, routing logits, and a dummy value for compatibility, + shaped as ([batch_size, seq_length, num_experts], [batch_size, seq_length, num_experts], int). + """ # Using fp32 for the routing prob computation. inputs = jax.lax.convert_element_type(inputs, jnp.float32) @@ -270,6 +484,19 @@ class Router(hk.Module): class MoELayer(hk.Module): + """ + A module implementing a Mixture of Experts (MoE) layer. + + Attributes: + num_experts (int): The number of experts in the MoE layer. + layer_fn (Callable): The function to be applied by each expert. + router (Router): The router that routes inputs to experts. + mesh (Any): The SPMD mesh for parallel computation. + shard_activations (bool): If True, shard activations across data and model axes. + data_axis (str | Tuple[str, ...]): The name(s) of the data axis for sharding. + model_axis (str | Tuple[str, ...]): The name(s) of the model axis for sharding. + name (Optional[str]): The name of the module. + """ def __init__( self, num_experts: int, @@ -281,6 +508,19 @@ class MoELayer(hk.Module): model_axis: Union[str, Tuple[str, ...]] = "model", name: Optional[str] = "moe", ): + """ + Initializes a Mixture of Experts layer with specified configuration. + + Args: + num_experts (int): The total number of experts in the MoE layer. + layer_fn (Callable): The function defining the computation performed by each expert. + router (Router): The router that directs inputs to selected experts. + mesh (Any): The optional SPMD mesh for parallel computation. + shard_activations (bool): Whether to shard activations across distributed resources. + data_axis (Union[str, Tuple[str, ...]]): Specifies how data is sharded for distributed computation. + model_axis (Union[str, Tuple[str, ...]]): Specifies how model parameters are sharded. + name (Optional[str]): An optional name for the module. + """ super().__init__(name) self.num_experts = num_experts self.layer_fn = layer_fn @@ -292,6 +532,18 @@ class MoELayer(hk.Module): @hk.transparent def _inference_call(self, inputs: jax.Array, padding_mask: Optional[jax.Array] = None): + """ + Handles the inference call to the MoE layer, distributing inputs to selected experts based on routing. + + Args: + inputs (jax.Array): Input data to be processed, shaped as [batch_size, seq_length, input_dim]. + padding_mask (Optional[jax.Array]): An optional mask for the inputs, where False indicates + positions that should not be processed (e.g., padding), shaped as [batch_size, seq_length]. + + Returns: + jax.Array: The processed outputs after passing through the selected experts, shaped as + [batch_size, seq_length, output_dim]. + """ routing_probs, _, _ = self.router.compute_routing_prob( inputs, padding_mask, self.num_experts ) @@ -401,24 +653,66 @@ class MoELayer(hk.Module): class MHAOutput(NamedTuple): - """Outputs of the multi-head attention operation.""" + """ + Represents the output of the Multi-Head Attention (MHA) operation. + + Attributes: + embeddings (jax.Array): The output embeddings from the MHA layer. + memory (Any): The updated memory state post-attention operation. + """ embeddings: jax.Array memory: Any class DecoderOutput(NamedTuple): + """ + Encapsulates the output from a decoder layer within the transformer model, including the transformed embeddings and any updated memory state. + + Attributes: + embeddings (jax.Array): The embeddings produced by the decoder layer, shaped as [batch_size, seq_length, embedding_dim]. + memory (Any): The updated memory state after processing by the decoder layer, useful for autoregressive decoding tasks. + """ embeddings: jax.Array memory: Any class TransformerOutput(NamedTuple): + """ + Represents the final output from the transformer model, including the final set of embeddings and any memory states that have been updated through the model's layers. + + Attributes: + embeddings (jax.Array): The final output embeddings from the transformer, shaped as [batch_size, seq_length, embedding_dim]. + memory (Any): The final memory state of the model after all transformer layers have been applied. + """ embeddings: jax.Array memory: Any @dataclass class TransformerConfig: + """ + Configuration class for setting up a Transformer model's architecture and its specific parameters. + + This class defines key architectural features of the transformer, including the size of embeddings, + the dimensionality of keys and values in the attention mechanism, the number of layers, and more. + It also includes configurations for advanced features like Mixture of Experts (MoE) and activation sharding. + + Attributes: + emb_size (int): The size of the embedding vectors. + key_size (int): The size of the key (and query) vectors in the attention mechanism. + num_q_heads (int): The number of heads in the query part of the multi-head attention mechanism. + num_kv_heads (int): The number of heads for keys and values in the multi-head attention. + num_layers (int): The total number of layers in the transformer model. + vocab_size (int): The size of the vocabulary that the model can understand. + widening_factor (float): The factor by which the dimensionality of the feed-forward networks is widened relative to the embedding size. + attn_output_multiplier (float): A scaling factor applied to the output of the attention mechanism, for controlling its magnitude. + shard_activations (bool): Whether to shard activations across devices for parallel processing. + num_experts (int): The number of experts in the Mixture of Experts (MoE) layer, if used. + num_selected_experts (int): The number of experts selected for each input in the MoE layer. + data_axis (Union[str, Tuple[str, ...]]): Specifies the axis names over which data is sharded for distributed computation. + model_axis (Union[str, Tuple[str, ...]]): Specifies the axis names over which model parameters are sharded for distributed computation. + """ emb_size: int key_size: int num_q_heads: int @@ -503,6 +797,10 @@ def make_attention_mask( dtype: Any = jnp.bfloat16, ): """Mask-making helper for attention weights. + + Creates an attention mask to specify which tokens in the key sequences can be attended to by each token in the query sequences. + + This utility is used in attention mechanisms to control the visibility of tokens, for purposes such as preventing future tokens from being attended to in autoregressive models. In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the attention weights will be `[batch..., heads, len_q, len_kv]` and this @@ -523,6 +821,20 @@ def make_attention_mask( class Linear(hk.Linear): + """ + Extends Haiku's Linear layer with optional sharding for use in distributed settings. + + This class allows specifying a `PartitionSpec` to shard the linear layer's weights across devices, + which can be beneficial in large-scale models processed over multiple devices or nodes. + + Args: + output_size (int): The size of the output dimension. + with_bias (bool, optional): Whether to include a bias term. Defaults to True. + sharding (Optional[P], optional): The sharding specification for distributing the layer's parameters. + mesh (Any, optional): The SPMD mesh for parallel computation. Defaults to None. + name (Optional[str], optional): An optional name for this module. Defaults to None. + shard_axis (int, optional): The axis along which to shard the input data. Defaults to 0. + """ def __init__( self, output_size: int, @@ -545,7 +857,17 @@ class Linear(hk.Linear): self, inputs: jax.Array, ) -> jax.Array: - """Computes a linear transform of the input.""" + """ + Computes a linear transform of the input data. + + This method computes the matrix multiplication between the inputs and the layer's weight matrix, optionally adding a bias term. + + Args: + inputs (jax.Array): The input tensor to be transformed, shaped as [batch_size, ..., input_features]. + + Returns: + jax.Array: The transformed tensor, shaped as [batch_size, ..., output_features]. + """ fprop_dtype = inputs.dtype if not inputs.shape: @@ -585,7 +907,21 @@ class Linear(hk.Linear): class RMSNorm(hk.RMSNorm): + """ + Implements Root Mean Square Layer Normalization. + This variant of layer normalization scales inputs by the root mean square of their elements, optionally + including a learnable scaling factor. It supports specifying a `PartitionSpec` for sharding the scale + parameters across devices in distributed settings. + + Args: + axis (Union[int, Sequence[int], slice]): The dimensions to normalize over. + eps (float, optional): A small constant added to the denominator to improve numerical stability. + Defaults to 1e-5. + name (Optional[str], optional): An optional name for this module. Defaults to None. + create_scale (bool, optional): Whether to include a learnable scaling factor. Defaults to True. + sharding (Optional[P], optional): The sharding specification for the scale parameter. Defaults to None. + """ def __init__( self, axis: Union[int, Sequence[int], slice], @@ -598,6 +934,17 @@ class RMSNorm(hk.RMSNorm): self.sharding = sharding def __call__(self, inputs: jax.Array): + """ + Applies RMS normalization to the input tensor. + + This method normalizes the inputs by their root mean square value, optionally scaling the result by a learnable parameter to adjust the representation scale. + + Args: + inputs (jax.Array): The input tensor to be normalized, shaped as [batch_size, ..., features]. + + Returns: + jax.Array: The RMS-normalized tensor, maintaining the input shape. + """ fprop_dtype = inputs.dtype param_shape = (inputs.shape[-1],) if self.create_scale: @@ -633,12 +980,19 @@ def rotate_half( class RotaryEmbedding(hk.Module): - """Applies rotary embeddings (RoPE) to the input sequence tensor, + """ + Implements Rotary Position Embedding (RoPE) to the input sequence tensor, as described in https://arxiv.org/abs/2104.09864. - Attributes: - dim (int): Dimensionality of the feature vectors - base_exponent (int): Base exponent to compute embeddings from + RoPE encodes positional information dynamically by applying a rotation to the input embeddings based on their + position in the sequence. This approach is designed to preserve the relative positional information across + different sequence lengths and tasks. + + Args: + dim (int): The dimensionality of the embeddings to be rotated, must be even. + name (Optional[str], optional): An optional name for this module. Defaults to None. + base_exponent (int, optional): The base of the exponent used to calculate rotary frequencies. + Defaults to 10000. """ def __init__( @@ -660,6 +1014,21 @@ class RotaryEmbedding(hk.Module): const_position: Optional[int] = None, t: Optional[jax.Array] = None, ) -> jax.Array: + """ + Applies Rotary Position Embedding (RoPE) to the input embeddings. + + This method dynamically encodes positional information by applying a rotation based on the position in the sequence, enhancing the model's ability to capture sequential dependencies. + + Args: + x (jax.Array): Input embeddings to apply RoPE to, shaped as [batch_size, seq_length, embedding_dim]. + seq_dim (int): The dimension index of the sequence length in the input tensor. + offset (jax.Array): The offset to apply to the position indices, useful for continuation of sequences across batches. + const_position (Optional[int]): A constant position value to use for all positions, instead of a range. + t (Optional[jax.Array]): Explicit tensor of position values, shaped as [seq_length]. + + Returns: + jax.Array: The input embeddings with RoPE applied, maintaining the input shape. + """ fprop_dtype = x.dtype # Compute the per-dimension frequencies exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32) @@ -692,6 +1061,23 @@ class RotaryEmbedding(hk.Module): class MultiHeadAttention(hk.Module): + """ + Implements the Multi-Head Attention mechanism, a key component of Transformer architectures. + + This module allows each token in the input sequence to attend to all tokens in the key and value sequences, with multiple "heads" learning different attention patterns. + + Attributes: + num_q_heads (int): The number of heads for the queries. + num_kv_heads (int): The number of heads for the keys and values. + key_size (int): The size of the key vectors. + with_bias (bool): Whether to include bias terms in the linear transformations. + value_size (Optional[int]): The size of the value vectors, defaults to the key size if not specified. + model_size (Optional[int]): The size of the output dimension from the attention mechanism, defaults to the total size of all heads if not specified. + attn_output_multiplier (float): A scaling factor applied to the output of the attention mechanism. + data_axis (Union[str, Tuple[str, ...]]): The axis names over which data is sharded for distributed computation. + model_axis (Union[str, Tuple[str, ...]]): The axis names over which model parameters are sharded for distributed computation. + name (Optional[str]): An optional name for the module. + """ def __init__( self, num_q_heads: int, @@ -706,6 +1092,21 @@ class MultiHeadAttention(hk.Module): model_axis: Union[str, Tuple[str, ...]] = "model", name: Optional[str] = None, ): + """ + Initializes the Multi-Head Attention module with the provided configuration parameters. + + Args: + num_q_heads (int): Number of query heads for the multi-head attention mechanism. + num_kv_heads (int): Number of key/value heads for the multi-head attention mechanism. + key_size (int): Dimensionality of key vectors in the attention mechanism. + with_bias (bool): Whether to include a bias term in the attention weight computation. + value_size (Optional[int]): Dimensionality of value vectors, defaults to key size if not specified. + model_size (Optional[int]): Overall size of the model's output dimension. + attn_output_multiplier (float): Multiplier for scaling the output of the attention mechanism. + data_axis (Union[str, Tuple[str, ...]]): Data sharding axis names for distributed computation. + model_axis (Union[str, Tuple[str, ...]]): Model parameter sharding axis names for distributed computation. + name (Optional[str]): An optional name for the module. + """ super().__init__(name=name) self.num_q_heads = num_q_heads self.num_kv_heads = num_kv_heads @@ -726,6 +1127,22 @@ class MultiHeadAttention(hk.Module): kv_memory: Optional[KVMemory] = None, mesh: Any = None, ) -> MHAOutput: + """ + Computes the multi-head attention over the input queries, keys, and values. + + Args: + query (jax.Array): Query vectors, shaped as [batch_size, seq_length, model_dim]. + key (Optional[jax.Array]): Key vectors. If None, uses query as key. + value (Optional[jax.Array]): Value vectors. If None, uses query as value. + mask (Optional[jax.Array]): An optional mask to prevent attention to certain positions, + shaped as [batch_size, 1, seq_length, seq_length]. + kv_memory (Optional[KVMemory]): Optional memory for keys and values to support efficient + autoregressive decoding. + mesh (Any): The SPMD mesh for parallel computation, if applicable. + + Returns: + MHAOutput: A named tuple containing the output embeddings and updated memory. + """ # In shape hints below, we suppress the leading dims [...] for brevity. # Hence e.g. [A, B] should be read in every case as [..., A, B]. sequence_length = query.shape[1] @@ -900,6 +1317,22 @@ class MultiHeadAttention(hk.Module): name: Optional[str] = None, mesh: Any = None, ) -> jax.Array: + """ + Projects input embeddings into multiple head spaces for queries, keys, or values. + + This internal method creates separate embeddings for each attention head by applying a linear transformation, + allowing the multi-head attention mechanism to explore different subspace relationships in the data. + + Args: + x (jax.Array): Input tensor to project, shaped as [batch_size, seq_length, embedding_dim]. + head_size (int): The dimensionality of each head's subspace. + num_heads (int): The number of heads to project into. + name (Optional[str]): A name for the operation, distinguishing between queries, keys, and values. + sharding (Optional[PartitionSpec]): The sharding specification for distributing computation across devices. + + Returns: + jax.Array: Projected embeddings for multiple heads, shaped as [batch_size, seq_length, num_heads, head_size]. + """ y = Linear( num_heads * head_size, with_bias=False, @@ -913,7 +1346,20 @@ class MultiHeadAttention(hk.Module): @dataclass class MHABlock(hk.Module): - """A MHA Block""" + """ + A specialized module encapsulating the Multi-Head Attention (MHA) operation within a transformer model. + + This module orchestrates the application of MHA, including the computation of queries, keys, and values, and the subsequent attention and aggregation operations. + + Attributes: + num_q_heads (int): Number of heads for the query part of the MHA. + num_kv_heads (int): Number of heads for the key and value parts of the MHA. + key_size (int): Size of the keys (and queries) in the attention mechanism. + attn_output_multiplier (float): Scaling factor applied to the output of the attention mechanism. + data_axis (Union[str, Tuple[str, ...]]): Axis names over which data is sharded for distributed computation. + model_axis (Union[str, Tuple[str, ...]]): Axis names over which model parameters are sharded for distributed computation. + mesh (Any): The SPMD mesh for parallel computation. + """ num_q_heads: int num_kv_heads: int @@ -930,6 +1376,20 @@ class MHABlock(hk.Module): mask: jax.Array, # [B, 1, T, T] or [B, 1, 1, T] or B[1, 1, 1, 1] layer_memory: Optional[KVMemory], ) -> MHAOutput: + """ + Processes inputs through a Multi-Head Attention block. + + This method applies multi-head attention to the inputs using the provided mask for attention scoring, + and optionally utilizes cached memory for keys and values to enhance efficiency in autoregressive models. + + Args: + inputs (jax.Array): Input embeddings, shaped as [batch_size, seq_length, embedding_dim]. + mask (jax.Array): Attention mask, shaped as [batch_size, 1, seq_length, seq_length], to control visibility between tokens. + layer_memory (Optional[KVMemory]): Cached keys and values from previous steps for efficient attention computation in autoregressive decoding. + + Returns: + MHAOutput: The output from the multi-head attention block, including the transformed embeddings and updated memory. + """ _, _, model_size = inputs.shape assert mask.ndim == 4, f"shape: {mask.shape}" assert mask.shape[2] in {1, inputs.shape[1]}, str(mask.shape) @@ -962,6 +1422,19 @@ class MHABlock(hk.Module): @dataclass class DenseBlock(hk.Module): + """ + Implements a dense (fully connected) block within a transformer layer, typically following multi-head attention. + + This block applies one or more linear transformations to its inputs, often including non-linear activations and potentially other operations like dropout for regularization. + + Attributes: + num_q_heads (int): The number of heads for the query in the preceding MHA layer. + num_kv_heads (int): The number of key/value pairs per head in the MHA layer. + key_size (int): The size of the keys in the MHA layer. + widening_factor (float): Factor by which the dimensionality of the feed-forward network is increased. + sharding_constraint (bool): Whether to apply a sharding constraint for distributed computation. + mesh (Any): The SPMD mesh for parallel computation. + """ num_q_heads: int num_kv_heads: int key_size: int @@ -974,6 +1447,18 @@ class DenseBlock(hk.Module): self, inputs: jax.Array, # [B, T, D] ) -> jax.Array: # [B, T, D] + """ + Applies a series of dense transformations to the inputs. + + This method constitutes the feedforward network part of a transformer layer, applying linear transformations + followed by non-linear activations to model complex data relationships beyond what attention mechanisms capture. + + Args: + inputs (jax.Array): Input embeddings from the previous layer or block, shaped as [batch_size, seq_length, embedding_dim]. + + Returns: + jax.Array: The output embeddings after applying the dense block transformations, maintaining the input shape. + """ _, _, model_size = inputs.shape h_v = Linear( ffn_size( @@ -1009,7 +1494,28 @@ class DenseBlock(hk.Module): @dataclass class DecoderLayer(hk.Module): - """A transformer stack.""" + """ + Represents a single layer in the decoder stack of a transformer model. This layer processes input embeddings through + a multi-head attention mechanism followed by position-wise feed-forward networks, with normalization and skip connections + applied as per the standard transformer architecture. + + Attributes: + num_q_heads (int): Number of query heads for multi-head attention. + num_kv_heads (int): Number of key/value pairs per attention head. + key_size (int): Size of keys in the attention mechanism. + num_layers (int): Total number of transformer layers. + num_experts (int): Number of experts in the Mixture of Experts layer, if used. + layer_index (Optional[int]): Index of this layer within the overall model; used for layer-specific configurations or logging. + num_selected_experts (int): Number of experts selected for each input in the MoE layer. + widening_factor (float): Factor by which the dimensionality of the feed-forward network is increased. + name (Optional[str]): An optional name for the layer. + data_axis (Union[str, Tuple[str, ...]]): Axis names for data sharding in distributed computation. + model_axis (Union[str, Tuple[str, ...]]): Axis names for model parameter sharding in distributed computation. + shard_activations (bool): Whether activations should be sharded across devices. + attn_output_multiplier (float): Scaling factor for the output of the attention mechanism. + mesh (Any): SPMD mesh for parallel computation. + """ + num_q_heads: int num_kv_heads: int @@ -1034,7 +1540,25 @@ class DecoderLayer(hk.Module): padding_mask: Optional[jax.Array], layer_memory: Optional[KVMemory], ) -> DecoderOutput: - """Transforms input embedding sequences to output embedding sequences.""" + """ + Transforms input embedding sequences to output embedding sequences. + Processes input embeddings through a single layer of the decoder. + + This method applies multi-head attention followed by position-wise feed-forward networks, + including any necessary normalization and skip connections, as per the transformer architecture. + + Args: + inputs (jax.Array): Input embeddings, shaped [batch_size, seq_length, model_dim]. + mask (jax.Array): Attention mask, shaped [batch_size, 1, seq_length, seq_length], used to prevent + attention to future positions. + padding_mask (Optional[jax.Array]): Mask indicating which positions are padding tokens, + to exclude them from attention calculations. + layer_memory (Optional[KVMemory]): Memory state for storing past key/value pairs for efficient + autoregressive decoding. + + Returns: + DecoderOutput: Named tuple containing output embeddings and updated memory state. + """ def layer_norm(x): return hk_rms_norm(x) @@ -1103,12 +1627,35 @@ class DecoderLayer(hk.Module): class LanguageModelOutput(NamedTuple): + """ + Represents the output of the language model after processing an input sequence of tokens. + + This output encapsulates both the logits representing the model's predictions for the next token + in the sequence and the updated model state, which includes any memory or state information + that needs to be carried over for generating subsequent tokens. + + Attributes: + logits (jax.Array): The logits for the next token predictions, shaped as [batch_size, sequence_length, vocab_size]. + model_state (Any): The updated state of the model after processing the input sequence, which may include + memory states for layers that utilize recurrence or caching for efficiency. + """ logits: jax.Array model_state: Any class InOutEmbed(hk.Embed): - """Module for embedding tokens in a low-dimensional space.""" + """ + A module for embedding input tokens into a low-dimensional space continuous vector space and for projecting the outputs of + a transformer back into the vocabulary space. This module supports tying the weights between the input + embedding and the output projection for parameter efficiency. + + Attributes: + vocab_size (Optional[int]): The size of the vocabulary. + embed_dim (Optional[int]): The dimensionality of the embedding vectors. + sharding (Optional[PartitionSpec]): Specifies how the embedding parameters should be sharded across + devices for distributed computation. + name (Optional[str]): An optional name for the module. + """ def __init__( self, @@ -1117,6 +1664,17 @@ class InOutEmbed(hk.Embed): sharding: Optional[P] = None, name: Optional[str] = None, ): + """ + Initializes an embedding module that can be used for both input token embeddings and output logits projection in a transformer model. + + This shared embedding layer helps reduce the number of parameters by tying the weights between the input embedding and the output projection layers. + + Args: + vocab_size (Optional[int]): The size of the vocabulary. + embed_dim (Optional[int]): The dimensionality of the embedding vectors. + sharding (Optional[PartitionSpec]): The sharding specification for distributing the embedding parameters across devices. + name (Optional[str]): An optional name for the module. + """ super().__init__( vocab_size=vocab_size, embed_dim=embed_dim, @@ -1126,6 +1684,15 @@ class InOutEmbed(hk.Embed): @property def embeddings(self): + """ + Retrieves the embedding matrix from the module's parameters. + + This method is useful for operations that need direct access to the embeddings, such as output projection in language models. + + Returns: + jax.Array: The embedding matrix, shaped as [vocab_size, embed_dim]. + """ + embed_mat = hk.get_parameter( "embeddings", [self.vocab_size, self.embed_dim], @@ -1140,12 +1707,41 @@ class InOutEmbed(hk.Embed): self, inputs: jax.Array, ) -> jax.Array: + """ + Projects transformer model outputs back into the vocabulary space using the transposed embedding matrix. + + This method effectively performs the inverse of the embedding operation, converting model output embeddings into logits over the vocabulary. + + Args: + inputs (jax.Array): The output embeddings from the transformer model, shaped as [batch_size, seq_length, embed_dim]. + + Returns: + jax.Array: The logits over the vocabulary for each token position, shaped as [batch_size, seq_length, vocab_size]. + """ return jnp.dot(inputs, self.embeddings.T.astype(inputs.dtype)) @dataclass class LanguageModelConfig: - """An autoregressive transformer-based language model.""" + """ + Configuration class for an autoregressive language model based on the Transformer architecture. + + Attributes: + model (TransformerConfig): The transformer model configuration. + vocab_size (int): The size of the vocabulary. + pad_token (int): The token used for padding sequences. + eos_token (int): The end-of-sentence token. + sequence_len (int): The maximum sequence length the model can handle. + model_size (int): The dimensionality of the model embeddings. + embedding_init_scale (float): Initial scale for embedding parameter initialization. + embedding_multiplier_scale (float): Multiplier for scaling the embedding vectors. + output_multiplier_scale (float): Multiplier for scaling the output logits. + name (Optional[str]): Name of the language model configuration. + fprop_dtype (Any): Data type for forward propagation computations. + model_type (Optional[str]): Type of the model, if applicable. + init_scale_override (Optional[float]): Override for the initial scale of parameters, if needed. + shard_embeddings (bool): Whether to shard embeddings across the specified axes. + """ model: Optional[TransformerConfig] vocab_size: int @@ -1200,7 +1796,23 @@ def layer_norm(x, model): @dataclass class LanguageModel(hk.Module): - """An autoregressive transformer-based language model.""" + """ + A high-level module for autoregressive language modeling using a Transformer architecture. This module + integrates components such as embedding layers, transformer blocks, and output layers to process sequences + of tokens and generate predictions for the next tokens in the sequence. + + The LanguageModel is designed for tasks such as text generation, where it can be used to produce coherent + and contextually relevant text based on a given prompt. + + Attributes: + model (Transformer): The core transformer model used for processing input token sequences. + config (LanguageModelConfig): Configuration parameters for the language model, including details about + the architecture, embeddings, and output processing. + fprop_dtype (Any): The data type to use for forward propagation calculations, typically set to jnp.bfloat16 + for efficiency. + name (Optional[str]): An optional name for the module. Useful for distinguishing between multiple instances. + mesh (Any): The SPMD mesh for parallel computation, supporting distributed training and inference. + """ model: "Transformer" config: LanguageModelConfig @@ -1217,7 +1829,22 @@ class LanguageModel(hk.Module): last_hid_only: bool = False, length: Optional[jax.Array] = None, ) -> LanguageModelOutput: - """Forward pass, producing a sequence of logits.""" + """ + Forward pass, producing a sequence of logits. + Generates logits for the next token predictions based on input tokens and optional memory state. + + Args: + tokens (jax.Array): Input tokens to the language model, shaped as [batch_size, seq_length]. + memory (Optional[Memory]): Optional memory state from previous steps, for autoregressive generation. + batch (Dict[str, jax.Array]): Additional batch information, unused here. + last_hid_only (bool): If True, returns only the last hidden state instead of logits. + length (Optional[jax.Array]): Specifies the length of each sequence in the batch for processing + only up to those lengths. + + Returns: + LanguageModelOutput: A named tuple containing the logits for next token predictions and the + updated memory state. + """ del batch # Unused. config = self.config @@ -1279,9 +1906,31 @@ class LanguageModel(hk.Module): ) def init_memory(self, batch_size: int, seq_len: int, dtype=jnp.bfloat16): + """ + Initializes the memory for the language model, suitable for autoregressive sequence generation tasks. + + Args: + batch_size (int): The number of sequences for which to initialize memory. + seq_len (int): The length of sequences for memory allocation. + dtype (Any): Data type for the memory arrays, typically jnp.bfloat16. + + Returns: + Memory: The initialized memory structure for storing keys, values, and decoding steps across transformer layers. + """ return self.model.init_memory(batch_size=batch_size, sequence_len=seq_len, dtype=dtype) def prefill_memory(self, prompts, memory): + """ + Optionally pre-fills the transformer's memory with information from a provided prompt, enhancing efficiency + in subsequent autoregressive generation steps by caching necessary computations from the prompt processing. + + Args: + prompts (jax.Array): The prompt tokens from which to generate subsequent tokens. + memory (Memory): The memory state to update with information derived from the prompts. + + Returns: + Tuple[jax.Array, Memory]: The logits produced by processing the prompts and the updated memory state. + """ # Pad to the left and right align? # Basically assume prompt is already padded model_output = self(prompts, memory=memory) @@ -1290,7 +1939,28 @@ class LanguageModel(hk.Module): @dataclass class Transformer(hk.Module): - """A transformer stack.""" + """ + Core transformer module that implements the foundational architecture of a transformer-based model. This module + is capable of processing sequences of embeddings through multiple layers of self-attention and feed-forward + networks, optionally including advanced techniques like mixture of experts (MoE) and activation sharding + for efficient large-scale parallel computation. + + Attributes: + num_q_heads (int): Number of heads in the query part of the multi-head attention mechanism. + num_kv_heads (int): Number of heads for the keys and values in the multi-head attention. + key_size (int): Dimensionality of the key (and query) vectors in the attention mechanism. + widening_factor (float): Factor by which to widen the dimensionality of the feed-forward network relative to the embeddings. + init_scale (float): Initial scale for parameter initialization. + mesh (Any): The SPMD mesh for parallel computation. + attn_output_multiplier (float): Multiplier for the output of the attention mechanism. + shard_activations (bool): Whether to shard activations across devices in distributed settings. + num_layers (int): Number of transformer layers to stack in the model. + num_experts (int): Number of experts in the MoE layer, if used. + num_selected_experts (int): Number of experts selected for each input token in the MoE layer. + data_axis (Union[str, Tuple[str, ...]]): Axis names for sharding data across devices. + model_axis (Union[str, Tuple[str, ...]]): Axis names for sharding model parameters across devices. + name (Optional[str]): An optional name for the module. + """ num_q_heads: int num_kv_heads: int @@ -1311,6 +1981,20 @@ class Transformer(hk.Module): model_axis: Union[str, Tuple[str, ...]] = "model" def init_memory(self, batch_size: int, sequence_len: int, dtype=jnp.bfloat16): + """ + Initializes the memory state for the transformer model. + + This is particularly useful for autoregressive tasks where past key and value pairs are cached + to improve efficiency in generating sequences. + + Args: + batch_size (int): The batch size for which to initialize memory states. + sequence_len (int): The sequence length for initializing the size of memory buffers. + dtype (Any): The data type for the memory arrays, typically jnp.bfloat16 for efficiency. + + Returns: + Memory: A named tuple representing the initialized memory state for each layer. + """ return Memory( layers=init_layer_memories( batch_size, @@ -1329,7 +2013,22 @@ class Transformer(hk.Module): mask: jax.Array, # [B, T] memory: Optional[Memory], ) -> TransformerOutput: - """Transforms input embedding sequences to output embedding sequences.""" + """ + Processes input embeddings through the transformer model. + Transforms input embedding sequences to output embedding sequences. + + Args: + embeddings (jax.Array): Input embeddings to be processed by the transformer, shaped as + [batch_size, seq_length, model_dim]. + mask (jax.Array): Mask indicating valid positions within the input, to control which positions + are allowed to attend to each other, shaped as [batch_size, seq_length]. + memory (Optional[Memory]): Optional memory state for the transformer to support autoregressive + decoding or similar use cases. + + Returns: + TransformerOutput: A named tuple containing the transformed embeddings and the final state + of the memory after processing. + """ fprop_dtype = embeddings.dtype _, seq_len, model_size = embeddings.shape diff --git a/run.py b/run.py index f1e157a..9a06d15 100644 --- a/run.py +++ b/run.py @@ -22,6 +22,27 @@ CKPT_PATH = "./checkpoints/" def main(): + """ + Initializes and runs a text generation model using predefined model configurations and inference settings. + + This function sets up a language model with specific configurations, including model architecture details + (e.g., embedding sizes, number of layers, attention heads, and MoE settings) and text generation settings + (e.g., vocabulary size, token identifiers). It initializes an inference runner with the model, checkpoint + path, tokenizer, and mesh configuration. The inference runner is then used to generate text based on a + given prompt and output the result. + + The process involves: + - Creating a `LanguageModelConfig` instance with specified model parameters, including transformer + configurations and quantization settings for weights. + - Initializing an `InferenceRunner` with the model configuration, batch size per device, checkpoint path, + and other relevant settings. + - Calling the `initialize` method on the inference runner to prepare the model and tokenizer for inference. + - Generating text based on a provided prompt using the `sample_from_model` function, which internally + manages the sampling process through the inference runner. + + Output: + - Prints the generated text continuation for a prompt to the standard output. + """ grok_1_model = LanguageModelConfig( vocab_size=128 * 1024, pad_token=0, diff --git a/runners.py b/runners.py index 452c142..e537c40 100644 --- a/runners.py +++ b/runners.py @@ -48,6 +48,20 @@ TOP_K = 8 class SampleSettings(NamedTuple): + """ + A NamedTuple for storing settings used during the sampling process. + + Attributes: + - temperature (ArrayLike): The temperature controls the randomness of the output. Lower values make + the model outputs more deterministic, while higher values increase diversity. + - nucleus_p (ArrayLike): The nucleus probability controls the cumulative distribution function of + token probabilities, effectively truncating low-probability tokens to focus sampling on a subset + of plausible tokens. + - mask (ArrayLike): A binary mask indicating which tokens are allowed (1) and which are disallowed (0) + for sampling in each batch element. + - active (ArrayLike): A binary indicator for whether a given batch element is actively being used for + generation. Elements not in use are ignored in computations to save resources. + """ temperature: ArrayLike nucleus_p: ArrayLike mask: ArrayLike @@ -56,6 +70,15 @@ class SampleSettings(NamedTuple): class SampleOutput(NamedTuple): + """ + A NamedTuple for storing the output from the sampling process. + + Attributes: + - token_id (ArrayLike): The sampled token ID for each batch element. + - prob (ArrayLike): The probability associated with the sampled token for each batch element. + - top_k_token_ids (ArrayLike): The IDs of the top-k most probable tokens for each batch element. + - top_k_probs (ArrayLike): The probabilities associated with the top-k token IDs for each batch element. + """ token_id: ArrayLike prob: ArrayLike top_k_token_ids: ArrayLike @@ -63,6 +86,19 @@ class SampleOutput(NamedTuple): def insert_slice(memory: Memory, slice, length, i): + """ + Inserts a slice of KVMemory into a Memory object at a specified index. This function updates the Memory + object with a new slice that includes updated steps for each layer based on the provided length. + + Parameters: + - memory (Memory): The original memory object to be updated. + - slice: A slice of the memory to be inserted, typically representing a new or modified piece of information. + - length (int): The step length to set for each KVMemory layer in the inserted slice. + - i (int): The index at which the slice is to be inserted into the memory. + + Returns: + - Memory: The updated memory object with the new slice inserted at the specified index. + """ slice = Memory( layers=[ KVMemory(layer.k, layer.v, step=jnp.array([length])) @@ -75,6 +111,17 @@ def insert_slice(memory: Memory, slice, length, i): def pad_to_size(x, size): + """ + Pads or truncates an array to a specified size. If the array is longer than the target size, it will be + truncated from the left. If it is shorter, it will be right-padded with zeros. + + Parameters: + - x (ArrayLike): The array to be padded or truncated. + - size (int): The target size for the array. + + Returns: + - ArrayLike: The array adjusted to the specified size. + """ if x.shape[0] > size: # Left truncate if the context is too long. x = x[-size:] @@ -82,7 +129,20 @@ def pad_to_size(x, size): def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array: - """Performs nucleus filtering on logits.""" + """ + Performs nucleus filtering on logits. + Filters logits to retain only a subset that corresponds to the top-p cumulative probability. This + nucleus filtering method is used to focus the sampling process on a subset of plausible tokens, improving + the quality of generated text. + + Parameters: + - logits (jax.Array): The logits to filter. + - top_p (jax.Array): The cumulative probability threshold for filtering logits. + + Returns: + - jax.Array: The filtered logits with low-probability tokens set to -inf, effectively removing them + from consideration during sampling. + """ assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}" sorted_logits = jax.lax.sort(logits, is_stable=False) sorted_probs = jax.nn.softmax(sorted_logits) @@ -102,6 +162,21 @@ def sample_token( lm_outputs: LanguageModelOutput, settings: SampleSettings, ) -> SampleOutput: + """ + Samples a token from the language model outputs using specified settings, including temperature and + nucleus probability filtering. This function also computes the probability of the sampled token and + identifies the top-k tokens and their probabilities. + + Parameters: + - rngs (jax.random.PRNGKey): The random number generator state for sampling. + - lm_outputs (LanguageModelOutput): The outputs from the language model, including logits. + - settings (SampleSettings): The settings controlling the sampling process, including temperature, + nucleus probability, token masking, and active batch elements indicator. + + Returns: + - SampleOutput: The results of the sampling process, including the sampled token ID, its probability, + and the top-k tokens and their probabilities for each batch element. + """ # Expand the settings shape to match the logit shape. settings = SampleSettings( temperature=jnp.expand_dims(settings.temperature, (1, 2)), # Input [B], output [B, 1, 1]. @@ -135,6 +210,25 @@ def sample_token( @dataclass class ModelRunner: + """ + Manages the execution of a language model, including initialization, forward passes, and state management. + + Attributes: + - model (LanguageModelConfig): The configuration for the language model. + - bs_per_device (float): The batch size per device. + - load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming parameters during model loading. + - load_exclude_rules (Optional[list[str]]): Rules for excluding parameters during model loading. + - rng_seed (int): The initial random number generator seed. + - transform_forward (bool): Flag to determine whether to transform the forward function with Haiku. + - checkpoint_path (str): The path to the directory containing model checkpoints for loading. + + Methods: + - make_forward_fn: Creates the forward function for the model, potentially transforming it with Haiku. + - initialize: Initializes the model runner with data and mesh configuration, preparing it for execution. + - init: Initializes the model parameters using provided data. + - get_state_sharding: Determines the sharding configuration for model parameters based on the initialization data. + - load_or_init: Loads the model from a checkpoint or initializes it if the checkpoint is not available or not specified. + """ model: LanguageModelConfig bs_per_device: float = 2.0 @@ -148,6 +242,18 @@ class ModelRunner: checkpoint_path: str = "" def make_forward_fn(self, mesh: Any): + """ + Creates and optionally transforms the forward function of the model. This method constructs a forward + function that takes input tokens and produces model outputs. If `transform_forward` is set to True, the + method also transforms this function using Haiku's transform method to prepare it for JIT compilation + and execution on the mesh. + + Parameters: + mesh (Any): The mesh configuration to be used for distributing the computation across devices. + + Returns: + Callable: The forward function, potentially transformed by Haiku, ready for execution. + """ def forward(tokens): out = self.model.make(mesh=mesh)(tokens) return out, None @@ -162,6 +268,20 @@ class ModelRunner: local_mesh_config: tuple[int, int], between_hosts_config: tuple[int, int], ): + """ + Initializes the model runner with necessary configurations and data. This includes setting up the mesh + for distributed computation, calculating batch sizes based on the provided configuration, and preparing + the forward function for execution. + + Parameters: + init_data: Initial data used for model initialization. This is typically dummy data matching the + expected input format and dimensions of the model. + local_mesh_config (tuple[int, int]): The configuration of the local mesh, specifying how devices + are organized locally for parallel computation. + between_hosts_config (tuple[int, int]): The configuration for communication between hosts in a + distributed setup, important for multi-host environments. + + """ num_replicas = math.prod(between_hosts_config) self.model.initialize() self.model.fprop_dtype = jnp.bfloat16 @@ -191,12 +311,37 @@ class ModelRunner: self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding) def init(self, rng: jax.Array, data) -> TrainingState: + """ + Initializes the model parameters using provided data and a random number generator state. This method + is only called when `transform_forward` is True, indicating that the forward function has been transformed + by Haiku and requires explicit initialization. + + Parameters: + rng (jax.Array): The random number generator state for initializing model parameters. + data: The initial data used for parameter initialization, usually including input tokens and + possibly target tokens for supervised learning tasks. + + Returns: + TrainingState: An instance of TrainingState containing the initialized model parameters. + """ assert self.transform_forward rng, init_rng = jax.random.split(rng) params = self.forward.init(init_rng, data["inputs"]) return TrainingState(params=params) def get_state_sharding(self, init_data): + """ + Determines the sharding configuration for model parameters based on initialization data. This method + evaluates the shape of model parameters during initialization to apply the partition rules specified + in the model's configuration, facilitating efficient distributed computation. + + Parameters: + init_data: The initial data used for model initialization, which helps determine the appropriate + sharding configuration based on model parameter shapes. + + Returns: + A sharding configuration that specifies how model parameters should be partitioned across the mesh. + """ assert self.transform_forward rng = jax.random.PRNGKey(self.rng_seed) rank_logger.info(f"partition rules: {self.model.partition_rules}") @@ -215,6 +360,20 @@ class ModelRunner: from_checkpoint: bool = True, init_fn: Optional[Callable] = None, ): + """ + Loads model parameters from a checkpoint or initializes them if the checkpoint is not available. This + method provides flexibility in starting the model from a known state or freshly initializing parameters + based on the model configuration and provided initial data. + + Parameters: + init_data: Initial data used for model parameter initialization if needed. + from_checkpoint (bool): Flag indicating whether to attempt loading parameters from a checkpoint. + init_fn (Optional[Callable]): An optional initialization function that overrides the default + initialization method if provided. + + Returns: + The model state, either loaded from a checkpoint or newly initialized. + """ rng = jax.random.PRNGKey(self.rng_seed) if not self.checkpoint_path or not from_checkpoint: @@ -251,6 +410,16 @@ class ModelRunner: @dataclass class Request: + """ + Data class for storing information about a single sampling request to the model. + + Attributes: + prompt (str): The text prompt to generate text from. + temperature (float): Controls the randomness of the output. Lower values result in more deterministic outputs. + nucleus_p (float): The cumulative probability threshold for nucleus sampling, focusing generation on high-probability tokens. + rng_seed (int): Seed for the random number generator to ensure reproducibility of the results. + max_len (int): The maximum length of the generated text sequence. + """ prompt: str temperature: float nucleus_p: float @@ -260,6 +429,24 @@ class Request: @dataclass class InferenceRunner: + """ + Manages inference operations for generating text based on prompts, including initializing the model and tokenizer, + prefilling the model memory, and running the actual text generation. + + Attributes: + name (str): A name identifier for the inference runner. + runner (Any): An instance of ModelRunner that manages the underlying language model. + load (str): A path to the model checkpoint for loading. + tokenizer_path (str): Path to the SentencePiece tokenizer model file. + local_mesh_config (Tuple[int, int]): Configuration for the local device mesh. + between_hosts_config (Tuple[int, int]): Configuration for communication between hosts in a distributed setup. + pad_sizes (tuple[int]): A tuple of padding sizes for processing different lengths of prompts. + + Methods: + get_pad_bucket: Determines the appropriate padding size for a given input length. + initialize: Initializes the model runner, tokenizer, and pre-compiles model functions for different input sizes. + run: A generator method that accepts prompts and returns generated text, managing batch slots and sampling steps. + """ name: str runner: Any load: str @@ -269,10 +456,25 @@ class InferenceRunner: pad_sizes: tuple[int] = (1024,) def get_pad_bucket(self, size): + """ + Determines the appropriate padding size for a given input length. This method uses the predefined + pad sizes to find the smallest pad size that is larger than or equal to the given size. + + Parameters: + size (int): The length of the input sequence to be padded. + + Returns: + int: The selected pad size from the predefined pad sizes that best fits the input length. + """ i = bisect.bisect_left(self.pad_sizes, size) return self.pad_sizes[min(i, len(self.pad_sizes) - 1)] def initialize(self): + """ + Initializes the inference runner by setting up the model runner, loading the tokenizer, pre-compiling + model functions for different input sizes, and loading or initializing model parameters. This method + ensures that the inference runner is ready to generate text based on prompts. + """ runner = self.runner self.runner.transform_forward = True dummy_data = dict( @@ -440,7 +642,15 @@ class InferenceRunner: ) def run(self): - """Generator that accepts prompts.""" + """ + Generator method that accepts prompts and manages the text generation process. This method handles + tokenization of prompts, setup of sampling settings, and orchestration of the sampling steps to + generate and return the generated text. + + Yields: + str: The generated text for each prompt received. This method operates as a coroutine, yielding + generated text back to the caller and awaiting new prompts for further generation. + """ runner = self.runner mesh = runner.mesh max_len = runner.model.sequence_len @@ -580,6 +790,18 @@ class InferenceRunner: def make_mesh( local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...] ) -> jax.sharding.Mesh: + """ + Creates a JAX mesh from the provided configuration, which is used for parallel computation across devices. + + Parameters: + local_mesh_config (tuple[int, ...]): A tuple specifying the local mesh configuration. This usually corresponds + to how local devices are arranged and partitioned for computation. + between_hosts_config (tuple[int, ...]): A tuple specifying the mesh configuration for communication between + different hosts, relevant in multi-host distributed setups. + + Returns: + jax.sharding.Mesh: A mesh object that encapsulates the device topology and partitioning for parallel computation. + """ assert len(local_mesh_config) == 2 assert len(between_hosts_config) == 2 rank_logger.info("Detected %s devices in mesh", jax.device_count()) @@ -594,6 +816,32 @@ def make_mesh( def sample_from_model(server, prompt, max_len, temperature): + """ + Samples tokens from a trained model given a prompt and sampling settings. This function is designed + to interact with a generator-based server that manages the state of the model and performs the actual + sampling. The server is expected to yield control back to this function, which then sends a request + object containing the prompt and settings, and waits for the generated output. + + The function initializes the server (if not already done), constructs the request with the provided + parameters, and sends this request to the server. The sampled output, typically a continuation of the + prompt, is then received from the server. + + Parameters: + - server (generator): The generator-based server that handles model inference. This server is expected + to manage the lifecycle of the model, including loading parameters, setting up the environment, and + performing the token sampling. + - prompt (str): The initial text prompt to which the model generates a continuation. This prompt is + encoded and fed into the model as part of the request. + - max_len (int): The maximum length of the generated sequence, including the length of the prompt. + This controls how long the model's output can be. + - temperature (float): A parameter controlling the randomness of the output. Lower values make the + model's outputs more deterministic (and potentially more repetitive), while higher values increase + diversity at the cost of coherence. + + Returns: + - str: The generated text as a continuation of the provided prompt, up to a maximum length specified + by `max_len`. + """ next(server) inp = Request( prompt=prompt,