diff --git a/model.py b/model.py index 62ee641..48ffcf2 100644 --- a/model.py +++ b/model.py @@ -12,6 +12,70 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Grok-1 + +This architecture is designed for language modeling and autoregressive sequence generation, incorporating advanced features such as: + +- 8-bit Quantized Weights: Implements quantization for model parameters to reduce the memory footprint and potentially increase computational efficiency. +- Sharding and Distributed Computation: Utilizes JAX's capabilities for distributed computation across devices, optimizing parallel processing and memory usage. +- Memory Caching for Autoregressive Decoding: Features a mechanism for caching keys and values in attention layers, enhancing efficiency in sequence generation tasks. + +Core Components: +- Multi-Head Attention (MHA): Custom implementation, including rotary positional embeddings for implicit sequence position information. +- Mixture of Experts (MoE): Allows routing inputs to different "expert" networks based on the input data, increasing model capacity and expressiveness. +- Feed-Forward Networks (FFNs): Defines networks with adjustable sizes and support for quantized weights and custom linear layers with sharding. + +Training and Inference Workflow: +- Manages the model's parameters, caching mechanisms, and layer-wise states efficiently during both training and inference. +- Implements detailed sharding strategies for optimizing the distribution of computation and storage across multiple devices. + +Advanced Features: +- Custom Layer Normalization: Adopts RMSNorm for stabilizing the training of deep networks. +- Dynamic and Static Sharding: Offers flexibility in data and model parallelism, allowing dynamic adjustment of sharding constraints. + +Efficiency and Scalability: +- Efficiently manages data flow and computation, minimizing unnecessary data replication and movement across devices. +- Designed with scalability in mind, providing a foundation for training complex models on massive datasets. + +This architecture leverages JAX for high-performance numerical computing and automatic differentiation, alongside Haiku for modular and flexible deep learning model construction. + +Data Flow Through the Training Process: + +1. Input Preparation: + - The process begins with the preparation of input data, which typically involves tokenizing text data into numerical tokens that represent words or subwords in a vocabulary. + - Tokens are then batched and padded to ensure consistent sequence lengths within each batch, forming the initial input tensor for the model. + +2. Embedding Layer: + - The input tokens are passed through an embedding layer, transforming each token into a high-dimensional vector. This layer may utilize pre-trained embeddings or learn embeddings during training. + - Positional encodings or embeddings are added to these vectors to incorporate sequence position information. + +3. Transformer Layers: + - The sequence of embedding vectors is processed through multiple Transformer layers, each consisting of the following sub-layers: + a. Multi-Head Attention (MHA): Each layer computes self-attention for its input, allowing each position in the sequence to attend to all positions in the previous layer’s output. + b. Feed-Forward Network (FFN): After attention, the output for each position passes through a feed-forward network. FFNs are identical for different positions but have different parameters from layer to layer. + + - Between each sub-layer, residual connections followed by layer normalization are applied. This helps in stabilizing the training of deep networks. + +4. Caching and Memory: + - For autoregressive tasks, where the model generates sequences one token at a time, keys and values computed during the attention operations are cached. This mechanism allows reusing these computations in subsequent steps, reducing the computational load. + +5. Output Layer: + - The output from the final Transformer layer is passed through a linear layer or a decoder, transforming the high-dimensional representations back into the vocabulary space to produce logits for each token in the sequence. + +6. Loss Calculation and Backpropagation: + - The logits are compared against the true token sequences using a suitable loss function (e.g., cross-entropy for language modeling tasks). The loss quantifies the model's prediction accuracy. + - Based on the loss, gradients are computed for each parameter in the model using backpropagation. These gradients indicate how each parameter should be adjusted to minimize the loss. + +7. Parameter Update: + - Model parameters are updated using an optimization algorithm (e.g., Adam, SGD) and the computed gradients. This step adjusts the model’s weights to improve its predictions on the next iteration. + +8. Iteration and Convergence: + - Steps 1 through 7 are repeated for multiple epochs over the training dataset until the model's performance on a validation set converges or begins to degrade, indicating potential overfitting. + +This structured approach, combined with the Transformer architecture's capability to model complex dependencies in data, enables effective training of models for tasks such as language understanding, translation, and generation. +""" + import functools import logging import re @@ -70,14 +134,16 @@ class TrainingState(NamedTuple): def _match(qs, ks): """ - Checks if regex patterns in qs match any contiguous subsequence of strings in ks. + Determines if a sequence of regex patterns (qs) matches any contiguous subsequence of strings (ks). + This utility function is often used for matching parameter names or paths in a hierarchical structure. Args: qs (Sequence[str]): A sequence of regex patterns to match. ks (Tuple[str, ...]): A tuple of strings against which the patterns are matched. Returns: - bool: True if there's a match for all patterns in a contiguous subsequence of ks, False otherwise. + bool: True if every pattern in qs has a corresponding match in a contiguous subsequence of ks, + otherwise False. """ # compile regexes and force complete match qts = tuple(map(lambda x: re.compile(x + "$"), qs)) @@ -90,14 +156,17 @@ def _match(qs, ks): def with_sharding_constraint(x, constraint): """ - Applies a sharding constraint to a JAX array, if a physical mesh is available. + Applies a sharding constraint to a JAX array. This function is used in SPMD programs to hint how the + data should be partitioned across devices. If a physical mesh is not available, it simply returns the + original array. Args: x (jax.Array): The array to apply the sharding constraint to. constraint (PartitionSpec): The sharding constraint to apply. Returns: - jax.Array: The array with the sharding constraint applied if a physical mesh is present. + jax.Array: The array with the sharding constraint applied, affecting its distribution across devices + in distributed computation setups. """ if jax.experimental.maps.thread_resources.env.physical_mesh.empty: return x @@ -107,13 +176,15 @@ def with_sharding_constraint(x, constraint): def cast_bfloat16(x): """ - Casts the input to bfloat16 type if it is of floating-point type. + Casts the input array to bfloat16 type if it is of floating-point type. This operation is often used to + reduce memory consumption and potentially increase computation speed by using lower precision. Args: x (jax.Array): The input array. Returns: - jax.Array: The input array casted to bfloat16 if it was floating-point, unchanged otherwise. + jax.Array: The array cast to bfloat16 if the original array was floating-point; otherwise, the array + is returned unchanged. """ if x.dtype.kind == "f": return x.astype(jnp.bfloat16) @@ -241,6 +312,14 @@ TOP_K = 8 class KVMemory(NamedTuple): + """ + Represents key-value memory slots for a transformer layer, supporting efficient autoregressive decoding by caching past computed keys and values. + + Attributes: + k (Optional[jax.Array]): Cached keys, shaped as [batch_size, sequence_len, num_kv_heads, key_size]. + v (Optional[jax.Array]): Cached values, shaped as [batch_size, sequence_len, num_kv_heads, key_size]. + step (Optional[jax.Array]): The current decoding step, indicating how many positions have been generated, shaped as [batch_size]. + """ k: Optional[jax.Array] v: Optional[jax.Array] step: Optional[jax.Array] @@ -256,19 +335,19 @@ def init_layer_memories( dtype=jnp.bfloat16, ): """ - Initializes memory slots for each layer in the transformer model. + Initializes layer memories for each transformer layer, providing a mechanism for efficient sequence generation by caching keys and values. Args: - batch_size (int): The size of the batch. - sequence_len (int): The length of the sequence. - num_kv_heads (int): The number of key/value pairs per head. - key_size (int): The size of each key. - num_layers (int): The number of layers in the transformer. - step (Optional[jax.Array]): The initial step for the memory, defaults to None. - dtype (Any): The data type of the memory arrays, defaults to jnp.bfloat16. + batch_size (int): The number of sequences being processed in parallel. + sequence_len (int): The length of the sequences for which memory is allocated. + num_kv_heads (int): The number of key-value pairs per head in the attention mechanism. + key_size (int): The size of each key (and value) in the attention mechanism. + num_layers (int): The number of transformer layers for which memory is initialized. + step (Optional[jax.Array]): The initial decoding step for each sequence in the batch. Defaults to None, indicating no prior steps. + dtype (Any): The data type for the memory arrays, typically jnp.bfloat16 for efficiency. Returns: - List[KVMemory]: A list of initialized KVMemory instances for each layer. + List[KVMemory]: A list of initialized KVMemory instances for each layer in the transformer model. """ return [ KVMemory( @@ -281,6 +360,12 @@ def init_layer_memories( class Memory(NamedTuple): + """ + A named tuple representing the complete memory state of a transformer model, encapsulating key-value memory slots for all layers. + + Attributes: + layers (List[KVMemory]): A list of KVMemory instances, one for each layer in the transformer model. + """ # Self-attention key/value cache. layers: List[KVMemory] @@ -306,6 +391,17 @@ class Router(hk.Module): mesh: Any = None, name: str = "router", ): + """ + Initializes a router for directing inputs to experts in a Mixture of Experts (MoE) layer. + + Args: + num_selected_experts (int): The number of experts to select for each input. + data_axis (Union[str, Tuple[str, ...]]): The axis names over which data is sharded. + model_axis (Union[str, Tuple[str, ...]]): The axis names over which model parameters are sharded. + shard_activations (bool): Indicates whether to shard activations according to the data and model axes. + mesh (Any): The SPMD mesh object for parallel computation. + name (str): An optional name for the module. + """ super().__init__(name) self.shard_activations = shard_activations self.data_axis = data_axis @@ -316,6 +412,20 @@ class Router(hk.Module): def compute_routing_prob( self, inputs: jax.Array, padding_mask: Optional[jax.Array], num_experts: int ): + """ + Computes the routing probabilities for each input to be directed to each expert. + + This internal method calculates the logits that determine how inputs are distributed among the experts, + based on learned criteria. + + Args: + inputs (jax.Array): Input data to be routed. + padding_mask (Optional[jax.Array]): Mask indicating which inputs are padding and should not be considered. + num_experts (int): The total number of experts available. + + Returns: + A tuple containing the routing probabilities, logits, and a dummy placeholder for compatibility. + """ return self._compute_routing_prob(inputs, padding_mask, num_experts) @hk.transparent @@ -398,6 +508,19 @@ class MoELayer(hk.Module): model_axis: Union[str, Tuple[str, ...]] = "model", name: Optional[str] = "moe", ): + """ + Initializes a Mixture of Experts layer with specified configuration. + + Args: + num_experts (int): The total number of experts in the MoE layer. + layer_fn (Callable): The function defining the computation performed by each expert. + router (Router): The router that directs inputs to selected experts. + mesh (Any): The optional SPMD mesh for parallel computation. + shard_activations (bool): Whether to shard activations across distributed resources. + data_axis (Union[str, Tuple[str, ...]]): Specifies how data is sharded for distributed computation. + model_axis (Union[str, Tuple[str, ...]]): Specifies how model parameters are sharded. + name (Optional[str]): An optional name for the module. + """ super().__init__(name) self.num_experts = num_experts self.layer_fn = layer_fn @@ -544,11 +667,11 @@ class MHAOutput(NamedTuple): class DecoderOutput(NamedTuple): """ - Encapsulates the output of a decoder layer within the transformer model. + Encapsulates the output from a decoder layer within the transformer model, including the transformed embeddings and any updated memory state. Attributes: - embeddings (jax.Array): The embeddings produced by the decoder layer. - memory (Any): The updated memory state after processing by the decoder layer. + embeddings (jax.Array): The embeddings produced by the decoder layer, shaped as [batch_size, seq_length, embedding_dim]. + memory (Any): The updated memory state after processing by the decoder layer, useful for autoregressive decoding tasks. """ embeddings: jax.Array memory: Any @@ -556,11 +679,11 @@ class DecoderOutput(NamedTuple): class TransformerOutput(NamedTuple): """ - Represents the final output of the transformer model. + Represents the final output from the transformer model, including the final set of embeddings and any memory states that have been updated through the model's layers. Attributes: - embeddings (jax.Array): The final output embeddings from the transformer. - memory (Any): The final state of the memory after passing through the transformer layers. + embeddings (jax.Array): The final output embeddings from the transformer, shaped as [batch_size, seq_length, embedding_dim]. + memory (Any): The final memory state of the model after all transformer layers have been applied. """ embeddings: jax.Array memory: Any @@ -569,25 +692,26 @@ class TransformerOutput(NamedTuple): @dataclass class TransformerConfig: """ - Configuration class for a Transformer model specifying the model's architecture and settings. + Configuration class for setting up a Transformer model's architecture and its specific parameters. + + This class defines key architectural features of the transformer, including the size of embeddings, + the dimensionality of keys and values in the attention mechanism, the number of layers, and more. + It also includes configurations for advanced features like Mixture of Experts (MoE) and activation sharding. Attributes: - emb_size (int): Embedding size used in the transformer. - key_size (int): Size of the keys in the multi-head attention mechanism. - num_q_heads (int): Number of query heads in the multi-head attention. - num_kv_heads (int): Number of key/value pairs per attention head. - num_layers (int): Number of layers in the transformer model. - vocab_size (int): The size of the vocabulary. - widening_factor (float): Factor to widen the feedforward network dimension relative to emb_size. - attn_output_multiplier (float): Multiplier for the output of the attention mechanism. - name (Optional[str]): Name of the transformer configuration. - num_experts (int): Number of experts in a mixture of experts layer. - capacity_factor (float): Capacity factor for routing in MoE layers. - num_selected_experts (int): Number of experts selected in each MoE layer. - init_scale (float): Initial scale for parameter initialization. - shard_activations (bool): If True, activations will be sharded across the specified axes. - data_axis (Union[str, Tuple[str, ...]]): Axis names over which data is sharded. - model_axis (Union[str, Tuple[str, ...]]): Axis names over which model parameters are sharded. + emb_size (int): The size of the embedding vectors. + key_size (int): The size of the key (and query) vectors in the attention mechanism. + num_q_heads (int): The number of heads in the query part of the multi-head attention mechanism. + num_kv_heads (int): The number of heads for keys and values in the multi-head attention. + num_layers (int): The total number of layers in the transformer model. + vocab_size (int): The size of the vocabulary that the model can understand. + widening_factor (float): The factor by which the dimensionality of the feed-forward networks is widened relative to the embedding size. + attn_output_multiplier (float): A scaling factor applied to the output of the attention mechanism, for controlling its magnitude. + shard_activations (bool): Whether to shard activations across devices for parallel processing. + num_experts (int): The number of experts in the Mixture of Experts (MoE) layer, if used. + num_selected_experts (int): The number of experts selected for each input in the MoE layer. + data_axis (Union[str, Tuple[str, ...]]): Specifies the axis names over which data is sharded for distributed computation. + model_axis (Union[str, Tuple[str, ...]]): Specifies the axis names over which model parameters are sharded for distributed computation. """ emb_size: int key_size: int @@ -673,6 +797,10 @@ def make_attention_mask( dtype: Any = jnp.bfloat16, ): """Mask-making helper for attention weights. + + Creates an attention mask to specify which tokens in the key sequences can be attended to by each token in the query sequences. + + This utility is used in attention mechanisms to control the visibility of tokens, for purposes such as preventing future tokens from being attended to in autoregressive models. In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the attention weights will be `[batch..., heads, len_q, len_kv]` and this @@ -729,7 +857,17 @@ class Linear(hk.Linear): self, inputs: jax.Array, ) -> jax.Array: - """Computes a linear transform of the input.""" + """ + Computes a linear transform of the input data. + + This method computes the matrix multiplication between the inputs and the layer's weight matrix, optionally adding a bias term. + + Args: + inputs (jax.Array): The input tensor to be transformed, shaped as [batch_size, ..., input_features]. + + Returns: + jax.Array: The transformed tensor, shaped as [batch_size, ..., output_features]. + """ fprop_dtype = inputs.dtype if not inputs.shape: @@ -796,6 +934,17 @@ class RMSNorm(hk.RMSNorm): self.sharding = sharding def __call__(self, inputs: jax.Array): + """ + Applies RMS normalization to the input tensor. + + This method normalizes the inputs by their root mean square value, optionally scaling the result by a learnable parameter to adjust the representation scale. + + Args: + inputs (jax.Array): The input tensor to be normalized, shaped as [batch_size, ..., features]. + + Returns: + jax.Array: The RMS-normalized tensor, maintaining the input shape. + """ fprop_dtype = inputs.dtype param_shape = (inputs.shape[-1],) if self.create_scale: @@ -865,6 +1014,21 @@ class RotaryEmbedding(hk.Module): const_position: Optional[int] = None, t: Optional[jax.Array] = None, ) -> jax.Array: + """ + Applies Rotary Position Embedding (RoPE) to the input embeddings. + + This method dynamically encodes positional information by applying a rotation based on the position in the sequence, enhancing the model's ability to capture sequential dependencies. + + Args: + x (jax.Array): Input embeddings to apply RoPE to, shaped as [batch_size, seq_length, embedding_dim]. + seq_dim (int): The dimension index of the sequence length in the input tensor. + offset (jax.Array): The offset to apply to the position indices, useful for continuation of sequences across batches. + const_position (Optional[int]): A constant position value to use for all positions, instead of a range. + t (Optional[jax.Array]): Explicit tensor of position values, shaped as [seq_length]. + + Returns: + jax.Array: The input embeddings with RoPE applied, maintaining the input shape. + """ fprop_dtype = x.dtype # Compute the per-dimension frequencies exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32) @@ -897,6 +1061,23 @@ class RotaryEmbedding(hk.Module): class MultiHeadAttention(hk.Module): + """ + Implements the Multi-Head Attention mechanism, a key component of Transformer architectures. + + This module allows each token in the input sequence to attend to all tokens in the key and value sequences, with multiple "heads" learning different attention patterns. + + Attributes: + num_q_heads (int): The number of heads for the queries. + num_kv_heads (int): The number of heads for the keys and values. + key_size (int): The size of the key vectors. + with_bias (bool): Whether to include bias terms in the linear transformations. + value_size (Optional[int]): The size of the value vectors, defaults to the key size if not specified. + model_size (Optional[int]): The size of the output dimension from the attention mechanism, defaults to the total size of all heads if not specified. + attn_output_multiplier (float): A scaling factor applied to the output of the attention mechanism. + data_axis (Union[str, Tuple[str, ...]]): The axis names over which data is sharded for distributed computation. + model_axis (Union[str, Tuple[str, ...]]): The axis names over which model parameters are sharded for distributed computation. + name (Optional[str]): An optional name for the module. + """ def __init__( self, num_q_heads: int, @@ -911,6 +1092,21 @@ class MultiHeadAttention(hk.Module): model_axis: Union[str, Tuple[str, ...]] = "model", name: Optional[str] = None, ): + """ + Initializes the Multi-Head Attention module with the provided configuration parameters. + + Args: + num_q_heads (int): Number of query heads for the multi-head attention mechanism. + num_kv_heads (int): Number of key/value heads for the multi-head attention mechanism. + key_size (int): Dimensionality of key vectors in the attention mechanism. + with_bias (bool): Whether to include a bias term in the attention weight computation. + value_size (Optional[int]): Dimensionality of value vectors, defaults to key size if not specified. + model_size (Optional[int]): Overall size of the model's output dimension. + attn_output_multiplier (float): Multiplier for scaling the output of the attention mechanism. + data_axis (Union[str, Tuple[str, ...]]): Data sharding axis names for distributed computation. + model_axis (Union[str, Tuple[str, ...]]): Model parameter sharding axis names for distributed computation. + name (Optional[str]): An optional name for the module. + """ super().__init__(name=name) self.num_q_heads = num_q_heads self.num_kv_heads = num_kv_heads @@ -1121,6 +1317,22 @@ class MultiHeadAttention(hk.Module): name: Optional[str] = None, mesh: Any = None, ) -> jax.Array: + """ + Projects input embeddings into multiple head spaces for queries, keys, or values. + + This internal method creates separate embeddings for each attention head by applying a linear transformation, + allowing the multi-head attention mechanism to explore different subspace relationships in the data. + + Args: + x (jax.Array): Input tensor to project, shaped as [batch_size, seq_length, embedding_dim]. + head_size (int): The dimensionality of each head's subspace. + num_heads (int): The number of heads to project into. + name (Optional[str]): A name for the operation, distinguishing between queries, keys, and values. + sharding (Optional[PartitionSpec]): The sharding specification for distributing computation across devices. + + Returns: + jax.Array: Projected embeddings for multiple heads, shaped as [batch_size, seq_length, num_heads, head_size]. + """ y = Linear( num_heads * head_size, with_bias=False, @@ -1134,7 +1346,20 @@ class MultiHeadAttention(hk.Module): @dataclass class MHABlock(hk.Module): - """A MHA Block""" + """ + A specialized module encapsulating the Multi-Head Attention (MHA) operation within a transformer model. + + This module orchestrates the application of MHA, including the computation of queries, keys, and values, and the subsequent attention and aggregation operations. + + Attributes: + num_q_heads (int): Number of heads for the query part of the MHA. + num_kv_heads (int): Number of heads for the key and value parts of the MHA. + key_size (int): Size of the keys (and queries) in the attention mechanism. + attn_output_multiplier (float): Scaling factor applied to the output of the attention mechanism. + data_axis (Union[str, Tuple[str, ...]]): Axis names over which data is sharded for distributed computation. + model_axis (Union[str, Tuple[str, ...]]): Axis names over which model parameters are sharded for distributed computation. + mesh (Any): The SPMD mesh for parallel computation. + """ num_q_heads: int num_kv_heads: int @@ -1151,6 +1376,20 @@ class MHABlock(hk.Module): mask: jax.Array, # [B, 1, T, T] or [B, 1, 1, T] or B[1, 1, 1, 1] layer_memory: Optional[KVMemory], ) -> MHAOutput: + """ + Processes inputs through a Multi-Head Attention block. + + This method applies multi-head attention to the inputs using the provided mask for attention scoring, + and optionally utilizes cached memory for keys and values to enhance efficiency in autoregressive models. + + Args: + inputs (jax.Array): Input embeddings, shaped as [batch_size, seq_length, embedding_dim]. + mask (jax.Array): Attention mask, shaped as [batch_size, 1, seq_length, seq_length], to control visibility between tokens. + layer_memory (Optional[KVMemory]): Cached keys and values from previous steps for efficient attention computation in autoregressive decoding. + + Returns: + MHAOutput: The output from the multi-head attention block, including the transformed embeddings and updated memory. + """ _, _, model_size = inputs.shape assert mask.ndim == 4, f"shape: {mask.shape}" assert mask.shape[2] in {1, inputs.shape[1]}, str(mask.shape) @@ -1183,6 +1422,19 @@ class MHABlock(hk.Module): @dataclass class DenseBlock(hk.Module): + """ + Implements a dense (fully connected) block within a transformer layer, typically following multi-head attention. + + This block applies one or more linear transformations to its inputs, often including non-linear activations and potentially other operations like dropout for regularization. + + Attributes: + num_q_heads (int): The number of heads for the query in the preceding MHA layer. + num_kv_heads (int): The number of key/value pairs per head in the MHA layer. + key_size (int): The size of the keys in the MHA layer. + widening_factor (float): Factor by which the dimensionality of the feed-forward network is increased. + sharding_constraint (bool): Whether to apply a sharding constraint for distributed computation. + mesh (Any): The SPMD mesh for parallel computation. + """ num_q_heads: int num_kv_heads: int key_size: int @@ -1195,6 +1447,18 @@ class DenseBlock(hk.Module): self, inputs: jax.Array, # [B, T, D] ) -> jax.Array: # [B, T, D] + """ + Applies a series of dense transformations to the inputs. + + This method constitutes the feedforward network part of a transformer layer, applying linear transformations + followed by non-linear activations to model complex data relationships beyond what attention mechanisms capture. + + Args: + inputs (jax.Array): Input embeddings from the previous layer or block, shaped as [batch_size, seq_length, embedding_dim]. + + Returns: + jax.Array: The output embeddings after applying the dense block transformations, maintaining the input shape. + """ _, _, model_size = inputs.shape h_v = Linear( ffn_size( @@ -1230,7 +1494,28 @@ class DenseBlock(hk.Module): @dataclass class DecoderLayer(hk.Module): - """A transformer stack.""" + """ + Represents a single layer in the decoder stack of a transformer model. This layer processes input embeddings through + a multi-head attention mechanism followed by position-wise feed-forward networks, with normalization and skip connections + applied as per the standard transformer architecture. + + Attributes: + num_q_heads (int): Number of query heads for multi-head attention. + num_kv_heads (int): Number of key/value pairs per attention head. + key_size (int): Size of keys in the attention mechanism. + num_layers (int): Total number of transformer layers. + num_experts (int): Number of experts in the Mixture of Experts layer, if used. + layer_index (Optional[int]): Index of this layer within the overall model; used for layer-specific configurations or logging. + num_selected_experts (int): Number of experts selected for each input in the MoE layer. + widening_factor (float): Factor by which the dimensionality of the feed-forward network is increased. + name (Optional[str]): An optional name for the layer. + data_axis (Union[str, Tuple[str, ...]]): Axis names for data sharding in distributed computation. + model_axis (Union[str, Tuple[str, ...]]): Axis names for model parameter sharding in distributed computation. + shard_activations (bool): Whether activations should be sharded across devices. + attn_output_multiplier (float): Scaling factor for the output of the attention mechanism. + mesh (Any): SPMD mesh for parallel computation. + """ + num_q_heads: int num_kv_heads: int @@ -1342,12 +1627,35 @@ class DecoderLayer(hk.Module): class LanguageModelOutput(NamedTuple): + """ + Represents the output of the language model after processing an input sequence of tokens. + + This output encapsulates both the logits representing the model's predictions for the next token + in the sequence and the updated model state, which includes any memory or state information + that needs to be carried over for generating subsequent tokens. + + Attributes: + logits (jax.Array): The logits for the next token predictions, shaped as [batch_size, sequence_length, vocab_size]. + model_state (Any): The updated state of the model after processing the input sequence, which may include + memory states for layers that utilize recurrence or caching for efficiency. + """ logits: jax.Array model_state: Any class InOutEmbed(hk.Embed): - """Module for embedding tokens in a low-dimensional space.""" + """ + A module for embedding input tokens into a low-dimensional space continuous vector space and for projecting the outputs of + a transformer back into the vocabulary space. This module supports tying the weights between the input + embedding and the output projection for parameter efficiency. + + Attributes: + vocab_size (Optional[int]): The size of the vocabulary. + embed_dim (Optional[int]): The dimensionality of the embedding vectors. + sharding (Optional[PartitionSpec]): Specifies how the embedding parameters should be sharded across + devices for distributed computation. + name (Optional[str]): An optional name for the module. + """ def __init__( self, @@ -1356,6 +1664,17 @@ class InOutEmbed(hk.Embed): sharding: Optional[P] = None, name: Optional[str] = None, ): + """ + Initializes an embedding module that can be used for both input token embeddings and output logits projection in a transformer model. + + This shared embedding layer helps reduce the number of parameters by tying the weights between the input embedding and the output projection layers. + + Args: + vocab_size (Optional[int]): The size of the vocabulary. + embed_dim (Optional[int]): The dimensionality of the embedding vectors. + sharding (Optional[PartitionSpec]): The sharding specification for distributing the embedding parameters across devices. + name (Optional[str]): An optional name for the module. + """ super().__init__( vocab_size=vocab_size, embed_dim=embed_dim, @@ -1365,6 +1684,15 @@ class InOutEmbed(hk.Embed): @property def embeddings(self): + """ + Retrieves the embedding matrix from the module's parameters. + + This method is useful for operations that need direct access to the embeddings, such as output projection in language models. + + Returns: + jax.Array: The embedding matrix, shaped as [vocab_size, embed_dim]. + """ + embed_mat = hk.get_parameter( "embeddings", [self.vocab_size, self.embed_dim], @@ -1379,6 +1707,17 @@ class InOutEmbed(hk.Embed): self, inputs: jax.Array, ) -> jax.Array: + """ + Projects transformer model outputs back into the vocabulary space using the transposed embedding matrix. + + This method effectively performs the inverse of the embedding operation, converting model output embeddings into logits over the vocabulary. + + Args: + inputs (jax.Array): The output embeddings from the transformer model, shaped as [batch_size, seq_length, embed_dim]. + + Returns: + jax.Array: The logits over the vocabulary for each token position, shaped as [batch_size, seq_length, vocab_size]. + """ return jnp.dot(inputs, self.embeddings.T.astype(inputs.dtype)) @@ -1458,18 +1797,21 @@ def layer_norm(x, model): @dataclass class LanguageModel(hk.Module): """ - A transformer-based language model for generating or evaluating sequences of tokens. + A high-level module for autoregressive language modeling using a Transformer architecture. This module + integrates components such as embedding layers, transformer blocks, and output layers to process sequences + of tokens and generate predictions for the next tokens in the sequence. - This class encapsulates a transformer model and provides methods for its initialization, - running the model forward to generate logits, and handling memory states for efficient - autoregressive generation. + The LanguageModel is designed for tasks such as text generation, where it can be used to produce coherent + and contextually relevant text based on a given prompt. Attributes: - model (Transformer): The underlying transformer model. - config (LanguageModelConfig): Configuration for the language model. - fprop_dtype (Any): Data type for forward propagation computations. - name (Optional[str]): Optional name for the module. - mesh (Any): The SPMD mesh for parallel computation, if applicable. + model (Transformer): The core transformer model used for processing input token sequences. + config (LanguageModelConfig): Configuration parameters for the language model, including details about + the architecture, embeddings, and output processing. + fprop_dtype (Any): The data type to use for forward propagation calculations, typically set to jnp.bfloat16 + for efficiency. + name (Optional[str]): An optional name for the module. Useful for distinguishing between multiple instances. + mesh (Any): The SPMD mesh for parallel computation, supporting distributed training and inference. """ model: "Transformer" @@ -1564,9 +1906,31 @@ class LanguageModel(hk.Module): ) def init_memory(self, batch_size: int, seq_len: int, dtype=jnp.bfloat16): + """ + Initializes the memory for the language model, suitable for autoregressive sequence generation tasks. + + Args: + batch_size (int): The number of sequences for which to initialize memory. + seq_len (int): The length of sequences for memory allocation. + dtype (Any): Data type for the memory arrays, typically jnp.bfloat16. + + Returns: + Memory: The initialized memory structure for storing keys, values, and decoding steps across transformer layers. + """ return self.model.init_memory(batch_size=batch_size, sequence_len=seq_len, dtype=dtype) def prefill_memory(self, prompts, memory): + """ + Optionally pre-fills the transformer's memory with information from a provided prompt, enhancing efficiency + in subsequent autoregressive generation steps by caching necessary computations from the prompt processing. + + Args: + prompts (jax.Array): The prompt tokens from which to generate subsequent tokens. + memory (Memory): The memory state to update with information derived from the prompts. + + Returns: + Tuple[jax.Array, Memory]: The logits produced by processing the prompts and the updated memory state. + """ # Pad to the left and right align? # Basically assume prompt is already padded model_output = self(prompts, memory=memory) @@ -1576,28 +1940,26 @@ class LanguageModel(hk.Module): @dataclass class Transformer(hk.Module): """ - Core transformer model class implementing a stack of transformer layers. - - This class is designed to be flexible and configurable, supporting features like - multi-head attention, feed-forward networks, and optional mixture of experts layers. - It is capable of processing sequences of embeddings and returning transformed sequences - of embeddings, along with updated memory states for autoregressive tasks. + Core transformer module that implements the foundational architecture of a transformer-based model. This module + is capable of processing sequences of embeddings through multiple layers of self-attention and feed-forward + networks, optionally including advanced techniques like mixture of experts (MoE) and activation sharding + for efficient large-scale parallel computation. Attributes: - num_q_heads (int): Number of query heads for multi-head attention. - num_kv_heads (int): Number of key/value pairs per attention head. - key_size (int): Size of keys in the attention mechanism. - widening_factor (float): Widening factor for the dimensionality of the feed-forward network. - init_scale (float): Scale for parameter initialization. + num_q_heads (int): Number of heads in the query part of the multi-head attention mechanism. + num_kv_heads (int): Number of heads for the keys and values in the multi-head attention. + key_size (int): Dimensionality of the key (and query) vectors in the attention mechanism. + widening_factor (float): Factor by which to widen the dimensionality of the feed-forward network relative to the embeddings. + init_scale (float): Initial scale for parameter initialization. mesh (Any): The SPMD mesh for parallel computation. attn_output_multiplier (float): Multiplier for the output of the attention mechanism. - shard_activations (bool): Whether to shard activations across specified axes. - num_layers (int): Number of transformer layers. - num_experts (int): Number of experts for mixture of experts layers, if applicable. - num_selected_experts (int): Number of experts selected per token, for MoE layers. - name (Optional[str]): Name of the transformer model. - data_axis (Union[str, Tuple[str, ...]]): Axis names for data sharding. - model_axis (Union[str, Tuple[str, ...]]): Axis names for model parameter sharding. + shard_activations (bool): Whether to shard activations across devices in distributed settings. + num_layers (int): Number of transformer layers to stack in the model. + num_experts (int): Number of experts in the MoE layer, if used. + num_selected_experts (int): Number of experts selected for each input token in the MoE layer. + data_axis (Union[str, Tuple[str, ...]]): Axis names for sharding data across devices. + model_axis (Union[str, Tuple[str, ...]]): Axis names for sharding model parameters across devices. + name (Optional[str]): An optional name for the module. """ num_q_heads: int