From 429a83e5d928235c641959eba91ef710dad76ef2 Mon Sep 17 00:00:00 2001 From: "Carlos D. Escobar-Valbuena" Date: Mon, 18 Mar 2024 15:57:49 -0500 Subject: [PATCH] Added docstrings to multiple modules and methods --- model.py | 363 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 350 insertions(+), 13 deletions(-) diff --git a/model.py b/model.py index 87d700d..62ee641 100644 --- a/model.py +++ b/model.py @@ -36,6 +36,13 @@ rank_logger = logging.getLogger("rank") @dataclass class QuantizedWeight8bit: + """ + Represents an 8-bit quantized weight for neural network parameters. + + Attributes: + weight (jnp.array): The quantized weights. + scales (jnp.array): The scale factors used for quantization. + """ weight: jnp.array scales: jnp.array @@ -52,13 +59,26 @@ tree_util.register_pytree_node( class TrainingState(NamedTuple): - """Container for the training state.""" + """Container for the training state, encapsulating model parameters. + + Attributes: + params (hk.Params): The parameters of the model. + """ params: hk.Params def _match(qs, ks): - """Return True if regexes in qs match any window of strings in tuple ks.""" + """ + Checks if regex patterns in qs match any contiguous subsequence of strings in ks. + + Args: + qs (Sequence[str]): A sequence of regex patterns to match. + ks (Tuple[str, ...]): A tuple of strings against which the patterns are matched. + + Returns: + bool: True if there's a match for all patterns in a contiguous subsequence of ks, False otherwise. + """ # compile regexes and force complete match qts = tuple(map(lambda x: re.compile(x + "$"), qs)) for i in range(len(ks) - len(qs) + 1): @@ -69,6 +89,16 @@ def _match(qs, ks): def with_sharding_constraint(x, constraint): + """ + Applies a sharding constraint to a JAX array, if a physical mesh is available. + + Args: + x (jax.Array): The array to apply the sharding constraint to. + constraint (PartitionSpec): The sharding constraint to apply. + + Returns: + jax.Array: The array with the sharding constraint applied if a physical mesh is present. + """ if jax.experimental.maps.thread_resources.env.physical_mesh.empty: return x else: @@ -76,6 +106,15 @@ def with_sharding_constraint(x, constraint): def cast_bfloat16(x): + """ + Casts the input to bfloat16 type if it is of floating-point type. + + Args: + x (jax.Array): The input array. + + Returns: + jax.Array: The input array casted to bfloat16 if it was floating-point, unchanged otherwise. + """ if x.dtype.kind == "f": return x.astype(jnp.bfloat16) else: @@ -83,6 +122,18 @@ def cast_bfloat16(x): def ffn_size(emb_size, widening_factor): + """ + Calculates the size of the feed-forward network (FFN) based on the embedding size and a widening factor. + + The calculated FFN size is adjusted to be a multiple of 8 for efficiency in hardware implementations. + + Args: + emb_size (int): The size of the embeddings. + widening_factor (float): The factor by which to widen the FFN relative to the embedding size. + + Returns: + int: The adjusted size of the FFN. + """ _ffn_size = int(widening_factor * emb_size) * 2 // 3 _ffn_size = _ffn_size + (8 - _ffn_size) % 8 # ensure it's a multiple of 8 logger.debug(f"emd_size: {emb_size} adjusted ffn_size: {_ffn_size}") @@ -90,6 +141,20 @@ def ffn_size(emb_size, widening_factor): def apply_rules(rules): + """ + Constructs a function to apply a set of sharding rules for transformer parameters. 
+
+    This function is used to determine the sharding specifications for model parameters based on their roles
+    and positions within the model architecture.
+
+    Args:
+        rules (List[Tuple[Sequence[str], PartitionSpec]]): A list of tuples where each tuple contains a sequence
+            of strings representing the parameter path and the corresponding `PartitionSpec` to apply.
+
+    Returns:
+        Callable: A function that takes a parameter path and returns the appropriate `PartitionSpec` based
+            on the provided rules.
+    """
     def _apply_rules(path, value):
         del value  # Unused.

@@ -190,6 +255,21 @@ def init_layer_memories(
     step: Optional[jax.Array] = None,
     dtype=jnp.bfloat16,
 ):
+    """
+    Initializes memory slots for each layer in the transformer model.
+
+    Args:
+        batch_size (int): The size of the batch.
+        sequence_len (int): The length of the sequence.
+        num_kv_heads (int): The number of key/value heads.
+        key_size (int): The size of each key.
+        num_layers (int): The number of layers in the transformer.
+        step (Optional[jax.Array]): The initial step for the memory, defaults to None.
+        dtype (Any): The data type of the memory arrays, defaults to jnp.bfloat16.
+
+    Returns:
+        List[KVMemory]: A list of initialized KVMemory instances for each layer.
+    """
     return [
         KVMemory(
             k=jnp.zeros((batch_size, sequence_len, num_kv_heads, key_size), dtype=dtype),
@@ -206,6 +286,17 @@ class Memory(NamedTuple):

 class Router(hk.Module):
+    """
+    A module for routing inputs to experts in a Mixture of Experts (MoE) layer.
+
+    Attributes:
+        num_selected_experts (int): Number of experts to select for each input.
+        data_axis (str | Tuple[str, ...]): The name(s) of the data axis for sharding.
+        model_axis (str | Tuple[str, ...]): The name(s) of the model axis for sharding.
+        shard_activations (bool): If True, shard activations according to the data and model axes.
+        mesh (Any): The SPMD mesh for parallel computation.
+        name (str): The name of the module.
+    """
     def __init__(
         self,
         num_selected_experts: int,
@@ -234,6 +325,19 @@ class Router(hk.Module):
         padding_mask: Optional[jax.Array],
         num_experts: int,
     ):
+        """
+        Computes the routing probabilities for directing inputs to the appropriate experts.
+
+        Args:
+            inputs (jax.Array): Input data to be routed, shaped as [batch_size, ..., input_dim].
+            padding_mask (Optional[jax.Array]): An optional mask indicating padded elements in the input,
+                shaped as [batch_size, seq_length], where padded positions are False.
+            num_experts (int): The total number of experts available for routing.
+
+        Returns:
+            A tuple containing routing probabilities, routing logits, and a dummy value for compatibility,
+            shaped as ([batch_size, seq_length, num_experts], [batch_size, seq_length, num_experts], int).
+        """
         # Using fp32 for the routing prob computation.
         inputs = jax.lax.convert_element_type(inputs, jnp.float32)

@@ -270,6 +374,19 @@ class MoELayer(hk.Module):
+    """
+    A module implementing a Mixture of Experts (MoE) layer.
+
+    Attributes:
+        num_experts (int): The number of experts in the MoE layer.
+        layer_fn (Callable): The function to be applied by each expert.
+        router (Router): The router that routes inputs to experts.
+        mesh (Any): The SPMD mesh for parallel computation.
+        shard_activations (bool): If True, shard activations across data and model axes.
+        data_axis (str | Tuple[str, ...]): The name(s) of the data axis for sharding.
+        model_axis (str | Tuple[str, ...]): The name(s) of the model axis for sharding.
+        name (Optional[str]): The name of the module.
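+
+    Example (illustrative sketch only, not part of the original API documentation; the
+    constructor keywords follow the Attributes above, and ``ffn`` / ``mesh`` are
+    placeholders for whatever the surrounding model provides):
+
+        router = Router(num_selected_experts=2, data_axis="data", model_axis="model",
+                        shard_activations=True, mesh=mesh, name="router")
+        moe = MoELayer(num_experts=8, layer_fn=ffn, router=router, mesh=mesh,
+                       shard_activations=True)
+        # The layer is applied inside an hk.transform-ed function to activations of
+        # shape [batch_size, seq_length, emb_size].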
+    """
     def __init__(
         self,
         num_experts: int,
@@ -292,6 +409,18 @@ class MoELayer(hk.Module):

     @hk.transparent
     def _inference_call(self, inputs: jax.Array, padding_mask: Optional[jax.Array] = None):
+        """
+        Handles the inference call to the MoE layer, distributing inputs to selected experts based on routing.
+
+        Args:
+            inputs (jax.Array): Input data to be processed, shaped as [batch_size, seq_length, input_dim].
+            padding_mask (Optional[jax.Array]): An optional mask for the inputs, where False indicates
+                positions that should not be processed (e.g., padding), shaped as [batch_size, seq_length].
+
+        Returns:
+            jax.Array: The processed outputs after passing through the selected experts, shaped as
+                [batch_size, seq_length, output_dim].
+        """
         routing_probs, _, _ = self.router.compute_routing_prob(
             inputs, padding_mask, self.num_experts
         )
@@ -401,24 +530,65 @@ class MoELayer(hk.Module):

 class MHAOutput(NamedTuple):
-    """Outputs of the multi-head attention operation."""
+    """
+    Represents the output of the Multi-Head Attention (MHA) operation.
+
+    Attributes:
+        embeddings (jax.Array): The output embeddings from the MHA layer.
+        memory (Any): The updated memory state post-attention operation.
+    """
     embeddings: jax.Array
     memory: Any

 class DecoderOutput(NamedTuple):
+    """
+    Encapsulates the output of a decoder layer within the transformer model.
+
+    Attributes:
+        embeddings (jax.Array): The embeddings produced by the decoder layer.
+        memory (Any): The updated memory state after processing by the decoder layer.
+    """
     embeddings: jax.Array
     memory: Any

 class TransformerOutput(NamedTuple):
+    """
+    Represents the final output of the transformer model.
+
+    Attributes:
+        embeddings (jax.Array): The final output embeddings from the transformer.
+        memory (Any): The final state of the memory after passing through the transformer layers.
+    """
     embeddings: jax.Array
     memory: Any

 @dataclass
 class TransformerConfig:
+    """
+    Configuration class for a Transformer model specifying the model's architecture and settings.
+
+    Attributes:
+        emb_size (int): Embedding size used in the transformer.
+        key_size (int): Size of the keys in the multi-head attention mechanism.
+        num_q_heads (int): Number of query heads in the multi-head attention.
+        num_kv_heads (int): Number of key/value heads in the attention mechanism.
+        num_layers (int): Number of layers in the transformer model.
+        vocab_size (int): The size of the vocabulary.
+        widening_factor (float): Factor to widen the feedforward network dimension relative to emb_size.
+        attn_output_multiplier (float): Multiplier for the output of the attention mechanism.
+        name (Optional[str]): Name of the transformer configuration.
+        num_experts (int): Number of experts in a mixture of experts layer.
+        capacity_factor (float): Capacity factor for routing in MoE layers.
+        num_selected_experts (int): Number of experts selected in each MoE layer.
+        init_scale (float): Initial scale for parameter initialization.
+        shard_activations (bool): If True, activations will be sharded across the specified axes.
+        data_axis (Union[str, Tuple[str, ...]]): Axis names over which data is sharded.
+        model_axis (Union[str, Tuple[str, ...]]): Axis names over which model parameters are sharded.
+    """
     emb_size: int
     key_size: int
     num_q_heads: int
@@ -523,6 +693,20 @@ def make_attention_mask(

 class Linear(hk.Linear):
+    """
+    Extends Haiku's Linear layer with optional sharding for use in distributed settings.
+
+    This class allows specifying a `PartitionSpec` to shard the linear layer's weights across devices,
+    which can be beneficial for large-scale models that span multiple devices or nodes.
+
+    Args:
+        output_size (int): The size of the output dimension.
+        with_bias (bool, optional): Whether to include a bias term. Defaults to True.
+        sharding (Optional[P], optional): The sharding specification for distributing the layer's parameters.
+        mesh (Any, optional): The SPMD mesh for parallel computation. Defaults to None.
+        name (Optional[str], optional): An optional name for this module. Defaults to None.
+        shard_axis (int, optional): The axis along which to shard the input data. Defaults to 0.
+    """
     def __init__(
         self,
         output_size: int,
@@ -585,7 +769,21 @@ class Linear(hk.Linear):

 class RMSNorm(hk.RMSNorm):
+    """
+    Implements Root Mean Square Layer Normalization.
+
+    This variant of layer normalization normalizes inputs by the root mean square of their elements, optionally
+    including a learnable scaling factor. It supports specifying a `PartitionSpec` for sharding the scale
+    parameters across devices in distributed settings.
+
+    Args:
+        axis (Union[int, Sequence[int], slice]): The dimensions to normalize over.
+        eps (float, optional): A small constant added to the denominator to improve numerical stability.
+            Defaults to 1e-5.
+        name (Optional[str], optional): An optional name for this module. Defaults to None.
+        create_scale (bool, optional): Whether to include a learnable scaling factor. Defaults to True.
+        sharding (Optional[P], optional): The sharding specification for the scale parameter. Defaults to None.
+    """
     def __init__(
         self,
         axis: Union[int, Sequence[int], slice],
@@ -633,12 +831,19 @@ def rotate_half(

 class RotaryEmbedding(hk.Module):
-    """Applies rotary embeddings (RoPE) to the input sequence tensor,
+    """
+    Applies Rotary Position Embedding (RoPE) to the input sequence tensor,
     as described in https://arxiv.org/abs/2104.09864.

-    Attributes:
-        dim (int): Dimensionality of the feature vectors
-        base_exponent (int): Base exponent to compute embeddings from
+    RoPE encodes positional information dynamically by applying a rotation to the input embeddings based on their
+    position in the sequence. This approach is designed to preserve the relative positional information across
+    different sequence lengths and tasks.
+
+    Args:
+        dim (int): The dimensionality of the embeddings to be rotated; must be even.
+        name (Optional[str], optional): An optional name for this module. Defaults to None.
+        base_exponent (int, optional): The base of the exponent used to calculate rotary frequencies.
+            Defaults to 10000.
     """

     def __init__(
@@ -726,6 +931,22 @@ class MultiHeadAttention(hk.Module):
         kv_memory: Optional[KVMemory] = None,
         mesh: Any = None,
     ) -> MHAOutput:
+        """
+        Computes the multi-head attention over the input queries, keys, and values.
+
+        Args:
+            query (jax.Array): Query vectors, shaped as [batch_size, seq_length, model_dim].
+            key (Optional[jax.Array]): Key vectors. If None, uses query as key.
+            value (Optional[jax.Array]): Value vectors. If None, uses query as value.
+            mask (Optional[jax.Array]): An optional mask to prevent attention to certain positions,
+                shaped as [batch_size, 1, seq_length, seq_length].
+            kv_memory (Optional[KVMemory]): Optional memory for keys and values to support efficient
+                autoregressive decoding.
+            mesh (Any): The SPMD mesh for parallel computation, if applicable.
+
+        Returns:
+            MHAOutput: A named tuple containing the output embeddings and updated memory.
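+
+        Example (illustrative; assumes ``mha`` is an instance of this module used inside an
+        hk.transform-ed function and ``x`` has shape [batch_size, seq_length, model_dim]):
+
+            out = mha(query=x, key=x, value=x, mask=causal_mask)
+            embeddings, new_memory = out.embeddings, out.memory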
+ """ # In shape hints below, we suppress the leading dims [...] for brevity. # Hence e.g. [A, B] should be read in every case as [..., A, B]. sequence_length = query.shape[1] @@ -1034,7 +1255,25 @@ class DecoderLayer(hk.Module): padding_mask: Optional[jax.Array], layer_memory: Optional[KVMemory], ) -> DecoderOutput: - """Transforms input embedding sequences to output embedding sequences.""" + """ + Transforms input embedding sequences to output embedding sequences. + Processes input embeddings through a single layer of the decoder. + + This method applies multi-head attention followed by position-wise feed-forward networks, + including any necessary normalization and skip connections, as per the transformer architecture. + + Args: + inputs (jax.Array): Input embeddings, shaped [batch_size, seq_length, model_dim]. + mask (jax.Array): Attention mask, shaped [batch_size, 1, seq_length, seq_length], used to prevent + attention to future positions. + padding_mask (Optional[jax.Array]): Mask indicating which positions are padding tokens, + to exclude them from attention calculations. + layer_memory (Optional[KVMemory]): Memory state for storing past key/value pairs for efficient + autoregressive decoding. + + Returns: + DecoderOutput: Named tuple containing output embeddings and updated memory state. + """ def layer_norm(x): return hk_rms_norm(x) @@ -1145,7 +1384,25 @@ class InOutEmbed(hk.Embed): @dataclass class LanguageModelConfig: - """An autoregressive transformer-based language model.""" + """ + Configuration class for an autoregressive language model based on the Transformer architecture. + + Attributes: + model (TransformerConfig): The transformer model configuration. + vocab_size (int): The size of the vocabulary. + pad_token (int): The token used for padding sequences. + eos_token (int): The end-of-sentence token. + sequence_len (int): The maximum sequence length the model can handle. + model_size (int): The dimensionality of the model embeddings. + embedding_init_scale (float): Initial scale for embedding parameter initialization. + embedding_multiplier_scale (float): Multiplier for scaling the embedding vectors. + output_multiplier_scale (float): Multiplier for scaling the output logits. + name (Optional[str]): Name of the language model configuration. + fprop_dtype (Any): Data type for forward propagation computations. + model_type (Optional[str]): Type of the model, if applicable. + init_scale_override (Optional[float]): Override for the initial scale of parameters, if needed. + shard_embeddings (bool): Whether to shard embeddings across the specified axes. + """ model: Optional[TransformerConfig] vocab_size: int @@ -1200,7 +1457,20 @@ def layer_norm(x, model): @dataclass class LanguageModel(hk.Module): - """An autoregressive transformer-based language model.""" + """ + A transformer-based language model for generating or evaluating sequences of tokens. + + This class encapsulates a transformer model and provides methods for its initialization, + running the model forward to generate logits, and handling memory states for efficient + autoregressive generation. + + Attributes: + model (Transformer): The underlying transformer model. + config (LanguageModelConfig): Configuration for the language model. + fprop_dtype (Any): Data type for forward propagation computations. + name (Optional[str]): Optional name for the module. + mesh (Any): The SPMD mesh for parallel computation, if applicable. 
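+
+    Example (illustrative sketch; ``lm`` stands for an instance of this module invoked inside an
+    hk.transform-ed function, and the output fields are those documented for LanguageModelOutput):
+
+        output = lm(tokens)  # tokens: int32 array of shape [batch_size, seq_length]
+        # ``output`` carries the next-token logits and the updated memory state, which can be
+        # passed back in via ``memory=`` on the next autoregressive step.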
+    """

     model: "Transformer"
     config: LanguageModelConfig
@@ -1217,7 +1487,22 @@ class LanguageModel(hk.Module):
         last_hid_only: bool = False,
         length: Optional[jax.Array] = None,
     ) -> LanguageModelOutput:
-        """Forward pass, producing a sequence of logits."""
+        """
+        Forward pass, producing a sequence of logits for next-token prediction from the
+        input tokens and an optional memory state.
+
+        Args:
+            tokens (jax.Array): Input tokens to the language model, shaped as [batch_size, seq_length].
+            memory (Optional[Memory]): Optional memory state from previous steps, for autoregressive generation.
+            batch (Dict[str, jax.Array]): Additional batch information, unused here.
+            last_hid_only (bool): If True, returns only the last hidden state instead of logits.
+            length (Optional[jax.Array]): Specifies the length of each sequence in the batch for processing
+                only up to those lengths.
+
+        Returns:
+            LanguageModelOutput: A named tuple containing the logits for next-token prediction and the
+                updated memory state.
+        """
         del batch  # Unused.

         config = self.config
@@ -1290,7 +1575,30 @@ class LanguageModel(hk.Module):

 @dataclass
 class Transformer(hk.Module):
-    """A transformer stack."""
+    """
+    Core transformer model class implementing a stack of transformer layers.
+
+    This class is designed to be flexible and configurable, supporting features like
+    multi-head attention, feed-forward networks, and optional mixture of experts layers.
+    It is capable of processing sequences of embeddings and returning transformed sequences
+    of embeddings, along with updated memory states for autoregressive tasks.
+
+    Attributes:
+        num_q_heads (int): Number of query heads for multi-head attention.
+        num_kv_heads (int): Number of key/value heads used for attention.
+        key_size (int): Size of keys in the attention mechanism.
+        widening_factor (float): Widening factor for the dimensionality of the feed-forward network.
+        init_scale (float): Scale for parameter initialization.
+        mesh (Any): The SPMD mesh for parallel computation.
+        attn_output_multiplier (float): Multiplier for the output of the attention mechanism.
+        shard_activations (bool): Whether to shard activations across specified axes.
+        num_layers (int): Number of transformer layers.
+        num_experts (int): Number of experts for mixture of experts layers, if applicable.
+        num_selected_experts (int): Number of experts selected per token, for MoE layers.
+        name (Optional[str]): Name of the transformer model.
+        data_axis (Union[str, Tuple[str, ...]]): Axis names for data sharding.
+        model_axis (Union[str, Tuple[str, ...]]): Axis names for model parameter sharding.
+    """
     num_q_heads: int
     num_kv_heads: int
@@ -1311,6 +1619,20 @@ class Transformer(hk.Module):
     model_axis: Union[str, Tuple[str, ...]] = "model"

     def init_memory(self, batch_size: int, sequence_len: int, dtype=jnp.bfloat16):
+        """
+        Initializes the memory state for the transformer model.
+
+        This is particularly useful for autoregressive tasks where past key and value pairs are cached
+        to improve efficiency in generating sequences.
+
+        Args:
+            batch_size (int): The batch size for which to initialize memory states.
+            sequence_len (int): The sequence length for initializing the size of memory buffers.
+            dtype (Any): The data type for the memory arrays, typically jnp.bfloat16 for efficiency.
+
+        Returns:
+            Memory: A named tuple representing the initialized memory state for each layer.
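+
+        Example (illustrative; ``transformer`` is an instance of this module):
+
+            # Allocate an empty KV cache for a batch of 2 sequences of up to 1024 tokens
+            # before starting autoregressive decoding.
+            memory = transformer.init_memory(batch_size=2, sequence_len=1024)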
+ """ return Memory( layers=init_layer_memories( batch_size, @@ -1329,7 +1651,22 @@ class Transformer(hk.Module): mask: jax.Array, # [B, T] memory: Optional[Memory], ) -> TransformerOutput: - """Transforms input embedding sequences to output embedding sequences.""" + """ + Processes input embeddings through the transformer model. + Transforms input embedding sequences to output embedding sequences. + + Args: + embeddings (jax.Array): Input embeddings to be processed by the transformer, shaped as + [batch_size, seq_length, model_dim]. + mask (jax.Array): Mask indicating valid positions within the input, to control which positions + are allowed to attend to each other, shaped as [batch_size, seq_length]. + memory (Optional[Memory]): Optional memory state for the transformer to support autoregressive + decoding or similar use cases. + + Returns: + TransformerOutput: A named tuple containing the transformed embeddings and the final state + of the memory after processing. + """ fprop_dtype = embeddings.dtype _, seq_len, model_size = embeddings.shape