Mirror of https://github.com/xai-org/grok-1.git (synced 2024-11-23 03:59:53 +03:00)

Commit fb9d433aa9 (parent 88da8c077a): Updated docstring and description of modeling process

Changed file: model.py (500)
@@ -12,6 +12,70 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Grok-1

This architecture is designed for language modeling and autoregressive sequence generation, incorporating advanced features such as:

- 8-bit Quantized Weights: Implements quantization for model parameters to reduce the memory footprint and potentially increase computational efficiency.
- Sharding and Distributed Computation: Utilizes JAX's capabilities for distributed computation across devices, optimizing parallel processing and memory usage.
- Memory Caching for Autoregressive Decoding: Features a mechanism for caching keys and values in attention layers, enhancing efficiency in sequence generation tasks.

Core Components:
- Multi-Head Attention (MHA): Custom implementation, including rotary positional embeddings for implicit sequence position information.
- Mixture of Experts (MoE): Allows routing inputs to different "expert" networks based on the input data, increasing model capacity and expressiveness.
- Feed-Forward Networks (FFNs): Defines networks with adjustable sizes and support for quantized weights and custom linear layers with sharding.

Training and Inference Workflow:
- Manages the model's parameters, caching mechanisms, and layer-wise states efficiently during both training and inference.
- Implements detailed sharding strategies for optimizing the distribution of computation and storage across multiple devices.

Advanced Features:
- Custom Layer Normalization: Adopts RMSNorm for stabilizing the training of deep networks.
- Dynamic and Static Sharding: Offers flexibility in data and model parallelism, allowing dynamic adjustment of sharding constraints.

Efficiency and Scalability:
- Efficiently manages data flow and computation, minimizing unnecessary data replication and movement across devices.
- Designed with scalability in mind, providing a foundation for training complex models on massive datasets.

This architecture leverages JAX for high-performance numerical computing and automatic differentiation, alongside Haiku for modular and flexible deep learning model construction.

Data Flow Through the Training Process:

1. Input Preparation:
- The process begins with the preparation of input data, which typically involves tokenizing text data into numerical tokens that represent words or subwords in a vocabulary.
- Tokens are then batched and padded to ensure consistent sequence lengths within each batch, forming the initial input tensor for the model.

2. Embedding Layer:
- The input tokens are passed through an embedding layer, transforming each token into a high-dimensional vector. This layer may utilize pre-trained embeddings or learn embeddings during training.
- Positional encodings or embeddings are added to these vectors to incorporate sequence position information.

3. Transformer Layers:
- The sequence of embedding vectors is processed through multiple Transformer layers, each consisting of the following sub-layers:
a. Multi-Head Attention (MHA): Each layer computes self-attention for its input, allowing each position in the sequence to attend to all positions in the previous layer’s output.
b. Feed-Forward Network (FFN): After attention, the output for each position passes through a feed-forward network. FFNs are identical for different positions but have different parameters from layer to layer.

- Around each sub-layer, a residual connection followed by layer normalization is applied. This helps stabilize the training of deep networks.

4. Caching and Memory:
- For autoregressive tasks, where the model generates sequences one token at a time, keys and values computed during the attention operations are cached. This mechanism allows reusing these computations in subsequent steps, reducing the computational load.

5. Output Layer:
- The output from the final Transformer layer is passed through a linear layer or a decoder, transforming the high-dimensional representations back into the vocabulary space to produce logits for each token in the sequence.

6. Loss Calculation and Backpropagation:
- The logits are compared against the true token sequences using a suitable loss function (e.g., cross-entropy for language modeling tasks). The loss quantifies the model's prediction accuracy.
- Based on the loss, gradients are computed for each parameter in the model using backpropagation. These gradients indicate how each parameter should be adjusted to minimize the loss.

7. Parameter Update:
- Model parameters are updated using an optimization algorithm (e.g., Adam, SGD) and the computed gradients. This step adjusts the model’s weights to improve its predictions on the next iteration.

8. Iteration and Convergence:
- Steps 1 through 7 are repeated for multiple epochs over the training dataset until the model's performance on a validation set converges or begins to degrade, indicating potential overfitting.

This structured approach, combined with the Transformer architecture's capability to model complex dependencies in data, enables effective training of models for tasks such as language understanding, translation, and generation.
"""

import functools
import logging
import re
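The loss/backpropagation/update portion of the workflow above (steps 5 through 7) can be pictured with a deliberately tiny, self-contained JAX sketch; the toy model, names, and plain SGD update below are illustrative stand-ins, not code from model.py:

import jax
import jax.numpy as jnp

def loss_fn(params, tokens, targets):
    # Toy "language model": embedding lookup followed by a tied output projection.
    logits = params["embed"][tokens] @ params["embed"].T        # [batch, seq, vocab]
    logp = jax.nn.log_softmax(logits, axis=-1)
    # Cross-entropy against the true next tokens (step 6).
    return -jnp.mean(jnp.take_along_axis(logp, targets[..., None], axis=-1))

@jax.jit
def train_step(params, tokens, targets, lr=1e-3):
    # Gradients via backpropagation, then a plain SGD parameter update (step 7);
    # a real setup would typically use an optimizer such as Adam.
    loss, grads = jax.value_and_grad(loss_fn)(params, tokens, targets)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

params = {"embed": jax.random.normal(jax.random.PRNGKey(0), (128, 16))}
tokens = jnp.zeros((2, 8), dtype=jnp.int32)
params, loss = train_step(params, tokens, jnp.roll(tokens, -1, axis=1))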
@@ -70,14 +134,16 @@ class TrainingState(NamedTuple):

def _match(qs, ks):
"""
Checks if regex patterns in qs match any contiguous subsequence of strings in ks.
Determines if a sequence of regex patterns (qs) matches any contiguous subsequence of strings (ks).
This utility function is often used for matching parameter names or paths in a hierarchical structure.

Args:
qs (Sequence[str]): A sequence of regex patterns to match.
ks (Tuple[str, ...]): A tuple of strings against which the patterns are matched.

Returns:
bool: True if there's a match for all patterns in a contiguous subsequence of ks, False otherwise.
bool: True if every pattern in qs has a corresponding match in a contiguous subsequence of ks,
otherwise False.
"""
# compile regexes and force complete match
qts = tuple(map(lambda x: re.compile(x + "$"), qs))
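As a rough illustration of the matching behaviour described in this docstring, a hypothetical re-implementation (not the function's actual body) might look like:

import re

def match_path(qs, ks):
    # Force each pattern to match a whole path component.
    qts = [re.compile(q + "$") for q in qs]
    # Slide a window of the same length as qs over ks.
    for start in range(len(ks) - len(qts) + 1):
        window = ks[start:start + len(qts)]
        if all(p.match(k) for p, k in zip(qts, window)):
            return True
    return False

# e.g. matching a parameter path in a hierarchical module tree:
assert match_path(("decoder_layer_[0-9]+", "linear"),
                  ("transformer", "decoder_layer_3", "linear", "w"))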
@@ -90,14 +156,17 @@ def _match(qs, ks):

def with_sharding_constraint(x, constraint):
"""
Applies a sharding constraint to a JAX array, if a physical mesh is available.
Applies a sharding constraint to a JAX array. This function is used in SPMD programs to hint how the
data should be partitioned across devices. If a physical mesh is not available, it simply returns the
original array.

Args:
x (jax.Array): The array to apply the sharding constraint to.
constraint (PartitionSpec): The sharding constraint to apply.

Returns:
jax.Array: The array with the sharding constraint applied if a physical mesh is present.
jax.Array: The array with the sharding constraint applied, affecting its distribution across devices
in distributed computation setups.
"""
if jax.experimental.maps.thread_resources.env.physical_mesh.empty:
return x
@@ -107,13 +176,15 @@ def with_sharding_constraint(x, constraint):

def cast_bfloat16(x):
"""
Casts the input to bfloat16 type if it is of floating-point type.
Casts the input array to bfloat16 type if it is of floating-point type. This operation is often used to
reduce memory consumption and potentially increase computation speed by using lower precision.

Args:
x (jax.Array): The input array.

Returns:
jax.Array: The input array casted to bfloat16 if it was floating-point, unchanged otherwise.
jax.Array: The array cast to bfloat16 if the original array was floating-point; otherwise, the array
is returned unchanged.
"""
if x.dtype.kind == "f":
return x.astype(jnp.bfloat16)
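A cast of this kind is typically applied over a whole parameter pytree; the sketch below uses a hypothetical params dict rather than anything from this file:

import jax
import jax.numpy as jnp

def to_bf16(x):
    # Only floating-point leaves are cast; integer arrays (e.g. step counters) are untouched.
    return x.astype(jnp.bfloat16) if x.dtype.kind == "f" else x

params = {"w": jnp.ones((4, 4), jnp.float32), "step": jnp.array(0, jnp.int32)}
params_bf16 = jax.tree_util.tree_map(to_bf16, params)
# params_bf16["w"] is now bfloat16, params_bf16["step"] stays int32.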
@@ -241,6 +312,14 @@ TOP_K = 8


class KVMemory(NamedTuple):
"""
Represents key-value memory slots for a transformer layer, supporting efficient autoregressive decoding by caching past computed keys and values.

Attributes:
k (Optional[jax.Array]): Cached keys, shaped as [batch_size, sequence_len, num_kv_heads, key_size].
v (Optional[jax.Array]): Cached values, shaped as [batch_size, sequence_len, num_kv_heads, key_size].
step (Optional[jax.Array]): The current decoding step, indicating how many positions have been generated, shaped as [batch_size].
"""
k: Optional[jax.Array]
v: Optional[jax.Array]
step: Optional[jax.Array]
@@ -256,19 +335,19 @@ def init_layer_memories(
dtype=jnp.bfloat16,
):
"""
Initializes memory slots for each layer in the transformer model.
Initializes layer memories for each transformer layer, providing a mechanism for efficient sequence generation by caching keys and values.

Args:
batch_size (int): The size of the batch.
sequence_len (int): The length of the sequence.
num_kv_heads (int): The number of key/value pairs per head.
key_size (int): The size of each key.
num_layers (int): The number of layers in the transformer.
step (Optional[jax.Array]): The initial step for the memory, defaults to None.
dtype (Any): The data type of the memory arrays, defaults to jnp.bfloat16.
batch_size (int): The number of sequences being processed in parallel.
sequence_len (int): The length of the sequences for which memory is allocated.
num_kv_heads (int): The number of key-value pairs per head in the attention mechanism.
key_size (int): The size of each key (and value) in the attention mechanism.
num_layers (int): The number of transformer layers for which memory is initialized.
step (Optional[jax.Array]): The initial decoding step for each sequence in the batch. Defaults to None, indicating no prior steps.
dtype (Any): The data type for the memory arrays, typically jnp.bfloat16 for efficiency.

Returns:
List[KVMemory]: A list of initialized KVMemory instances for each layer.
List[KVMemory]: A list of initialized KVMemory instances for each layer in the transformer model.
"""
return [
KVMemory(
@@ -281,6 +360,12 @@ def init_layer_memories(


class Memory(NamedTuple):
"""
A named tuple representing the complete memory state of a transformer model, encapsulating key-value memory slots for all layers.

Attributes:
layers (List[KVMemory]): A list of KVMemory instances, one for each layer in the transformer model.
"""
# Self-attention key/value cache.
layers: List[KVMemory]
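To make the caching scheme above concrete, here is a hypothetical sketch of initializing one layer's cache with this layout and appending a single decoding step; the helper names are illustrative only:

import jax
import jax.numpy as jnp

def init_cache(batch, seq_len, num_kv_heads, key_size, dtype=jnp.bfloat16):
    # Preallocate the full sequence length so decoding never reallocates.
    zeros = jnp.zeros((batch, seq_len, num_kv_heads, key_size), dtype)
    return {"k": zeros, "v": zeros, "step": jnp.zeros((batch,), jnp.int32)}

def append_kv(cache, new_k, new_v):
    # new_k / new_v hold the keys and values of one new token: [batch, 1, heads, key_size].
    step = cache["step"][0]  # assume all sequences in the batch share the same step here
    k = jax.lax.dynamic_update_slice(cache["k"], new_k, (0, step, 0, 0))
    v = jax.lax.dynamic_update_slice(cache["v"], new_v, (0, step, 0, 0))
    return {"k": k, "v": v, "step": cache["step"] + 1}

cache = init_cache(batch=1, seq_len=8, num_kv_heads=2, key_size=4)
cache = append_kv(cache,
                  jnp.ones((1, 1, 2, 4), jnp.bfloat16),
                  jnp.ones((1, 1, 2, 4), jnp.bfloat16))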
@@ -306,6 +391,17 @@ class Router(hk.Module):
mesh: Any = None,
name: str = "router",
):
"""
Initializes a router for directing inputs to experts in a Mixture of Experts (MoE) layer.

Args:
num_selected_experts (int): The number of experts to select for each input.
data_axis (Union[str, Tuple[str, ...]]): The axis names over which data is sharded.
model_axis (Union[str, Tuple[str, ...]]): The axis names over which model parameters are sharded.
shard_activations (bool): Indicates whether to shard activations according to the data and model axes.
mesh (Any): The SPMD mesh object for parallel computation.
name (str): An optional name for the module.
"""
super().__init__(name)
self.shard_activations = shard_activations
self.data_axis = data_axis
@@ -316,6 +412,20 @@ class Router(hk.Module):
def compute_routing_prob(
self, inputs: jax.Array, padding_mask: Optional[jax.Array], num_experts: int
):
"""
Computes the routing probabilities for each input to be directed to each expert.

This internal method calculates the logits that determine how inputs are distributed among the experts,
based on learned criteria.

Args:
inputs (jax.Array): Input data to be routed.
padding_mask (Optional[jax.Array]): Mask indicating which inputs are padding and should not be considered.
num_experts (int): The total number of experts available.

Returns:
A tuple containing the routing probabilities, logits, and a dummy placeholder for compatibility.
"""
return self._compute_routing_prob(inputs, padding_mask, num_experts)

@hk.transparent
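A minimal sketch of the routing idea described above, assuming a plain learned projection to expert logits (the real Router is more involved):

import jax
import jax.numpy as jnp

def route(inputs, w_router, num_selected_experts=2):
    # inputs: [tokens, emb]; w_router: [emb, num_experts]
    logits = inputs @ w_router                       # per-token expert logits
    probs = jax.nn.softmax(logits, axis=-1)          # routing probabilities
    top_p, top_idx = jax.lax.top_k(probs, num_selected_experts)
    return top_p, top_idx, logits

tokens = jnp.ones((4, 8))
w = jnp.zeros((8, 4))
top_p, top_idx, logits = route(tokens, w)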
@@ -398,6 +508,19 @@ class MoELayer(hk.Module):
model_axis: Union[str, Tuple[str, ...]] = "model",
name: Optional[str] = "moe",
):
"""
Initializes a Mixture of Experts layer with specified configuration.

Args:
num_experts (int): The total number of experts in the MoE layer.
layer_fn (Callable): The function defining the computation performed by each expert.
router (Router): The router that directs inputs to selected experts.
mesh (Any): The optional SPMD mesh for parallel computation.
shard_activations (bool): Whether to shard activations across distributed resources.
data_axis (Union[str, Tuple[str, ...]]): Specifies how data is sharded for distributed computation.
model_axis (Union[str, Tuple[str, ...]]): Specifies how model parameters are sharded.
name (Optional[str]): An optional name for the module.
"""
super().__init__(name)
self.num_experts = num_experts
self.layer_fn = layer_fn
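Continuing the routing sketch above, the outputs of the selected experts can be combined with their routing probabilities roughly as follows; this dense gather is purely illustrative, since a real MoE layer dispatches tokens to experts under capacity constraints:

import jax.numpy as jnp

def moe_forward(x, expert_weights, top_p, top_idx):
    # x: [tokens, emb]; expert_weights: [num_experts, emb, out]; top_p/top_idx: [tokens, k]
    selected = expert_weights[top_idx]                     # [tokens, k, emb, out]
    expert_out = jnp.einsum("te,tkeo->tko", x, selected)   # each selected expert's output
    return jnp.einsum("tko,tk->to", expert_out, top_p)     # weighted combination

x = jnp.ones((4, 8))
ew = jnp.zeros((4, 8, 8))
out = moe_forward(x, ew, jnp.full((4, 2), 0.5), jnp.zeros((4, 2), jnp.int32))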
@@ -544,11 +667,11 @@ class MHAOutput(NamedTuple):

class DecoderOutput(NamedTuple):
"""
Encapsulates the output of a decoder layer within the transformer model.
Encapsulates the output from a decoder layer within the transformer model, including the transformed embeddings and any updated memory state.

Attributes:
embeddings (jax.Array): The embeddings produced by the decoder layer.
memory (Any): The updated memory state after processing by the decoder layer.
embeddings (jax.Array): The embeddings produced by the decoder layer, shaped as [batch_size, seq_length, embedding_dim].
memory (Any): The updated memory state after processing by the decoder layer, useful for autoregressive decoding tasks.
"""
embeddings: jax.Array
memory: Any
@@ -556,11 +679,11 @@ class DecoderOutput(NamedTuple):

class TransformerOutput(NamedTuple):
"""
Represents the final output of the transformer model.
Represents the final output from the transformer model, including the final set of embeddings and any memory states that have been updated through the model's layers.

Attributes:
embeddings (jax.Array): The final output embeddings from the transformer.
memory (Any): The final state of the memory after passing through the transformer layers.
embeddings (jax.Array): The final output embeddings from the transformer, shaped as [batch_size, seq_length, embedding_dim].
memory (Any): The final memory state of the model after all transformer layers have been applied.
"""
embeddings: jax.Array
memory: Any
@@ -569,25 +692,26 @@ class TransformerOutput(NamedTuple):
@dataclass
class TransformerConfig:
"""
Configuration class for a Transformer model specifying the model's architecture and settings.
Configuration class for setting up a Transformer model's architecture and its specific parameters.

This class defines key architectural features of the transformer, including the size of embeddings,
the dimensionality of keys and values in the attention mechanism, the number of layers, and more.
It also includes configurations for advanced features like Mixture of Experts (MoE) and activation sharding.

Attributes:
emb_size (int): Embedding size used in the transformer.
key_size (int): Size of the keys in the multi-head attention mechanism.
num_q_heads (int): Number of query heads in the multi-head attention.
num_kv_heads (int): Number of key/value pairs per attention head.
num_layers (int): Number of layers in the transformer model.
vocab_size (int): The size of the vocabulary.
widening_factor (float): Factor to widen the feedforward network dimension relative to emb_size.
attn_output_multiplier (float): Multiplier for the output of the attention mechanism.
name (Optional[str]): Name of the transformer configuration.
num_experts (int): Number of experts in a mixture of experts layer.
capacity_factor (float): Capacity factor for routing in MoE layers.
num_selected_experts (int): Number of experts selected in each MoE layer.
init_scale (float): Initial scale for parameter initialization.
shard_activations (bool): If True, activations will be sharded across the specified axes.
data_axis (Union[str, Tuple[str, ...]]): Axis names over which data is sharded.
model_axis (Union[str, Tuple[str, ...]]): Axis names over which model parameters are sharded.
emb_size (int): The size of the embedding vectors.
key_size (int): The size of the key (and query) vectors in the attention mechanism.
num_q_heads (int): The number of heads in the query part of the multi-head attention mechanism.
num_kv_heads (int): The number of heads for keys and values in the multi-head attention.
num_layers (int): The total number of layers in the transformer model.
vocab_size (int): The size of the vocabulary that the model can understand.
widening_factor (float): The factor by which the dimensionality of the feed-forward networks is widened relative to the embedding size.
attn_output_multiplier (float): A scaling factor applied to the output of the attention mechanism, for controlling its magnitude.
shard_activations (bool): Whether to shard activations across devices for parallel processing.
num_experts (int): The number of experts in the Mixture of Experts (MoE) layer, if used.
num_selected_experts (int): The number of experts selected for each input in the MoE layer.
data_axis (Union[str, Tuple[str, ...]]): Specifies the axis names over which data is sharded for distributed computation.
model_axis (Union[str, Tuple[str, ...]]): Specifies the axis names over which model parameters are sharded for distributed computation.
"""
emb_size: int
key_size: int
@@ -673,6 +797,10 @@ def make_attention_mask(
dtype: Any = jnp.bfloat16,
):
"""Mask-making helper for attention weights.

Creates an attention mask to specify which tokens in the key sequences can be attended to by each token in the query sequences.

This utility is used in attention mechanisms to control the visibility of tokens, for purposes such as preventing future tokens from being attended to in autoregressive models.

In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`), the
attention weights will be `[batch..., heads, len_q, len_kv]` and this
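For the autoregressive case mentioned above, the mask reduces to a lower-triangular (causal) pattern; a minimal hypothetical helper:

import jax.numpy as jnp

def causal_mask(seq_len, dtype=jnp.bfloat16):
    # [1, 1, seq_len, seq_len], broadcastable against [batch, heads, len_q, len_kv]:
    # query position i may only attend to key positions <= i.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
    return mask[None, None, :, :].astype(dtype)

print(causal_mask(4, jnp.float32)[0, 0])
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]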
@@ -729,7 +857,17 @@ class Linear(hk.Linear):
self,
inputs: jax.Array,
) -> jax.Array:
"""Computes a linear transform of the input."""
"""
Computes a linear transform of the input data.

This method computes the matrix multiplication between the inputs and the layer's weight matrix, optionally adding a bias term.

Args:
inputs (jax.Array): The input tensor to be transformed, shaped as [batch_size, ..., input_features].

Returns:
jax.Array: The transformed tensor, shaped as [batch_size, ..., output_features].
"""

fprop_dtype = inputs.dtype
if not inputs.shape:
@@ -796,6 +934,17 @@ class RMSNorm(hk.RMSNorm):
self.sharding = sharding

def __call__(self, inputs: jax.Array):
"""
Applies RMS normalization to the input tensor.

This method normalizes the inputs by their root mean square value, optionally scaling the result by a learnable parameter to adjust the representation scale.

Args:
inputs (jax.Array): The input tensor to be normalized, shaped as [batch_size, ..., features].

Returns:
jax.Array: The RMS-normalized tensor, maintaining the input shape.
"""
fprop_dtype = inputs.dtype
param_shape = (inputs.shape[-1],)
if self.create_scale:
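The normalization described above can be sketched as a plain function (illustrative only, not the hk.RMSNorm subclass used here):

import jax.numpy as jnp

def rms_norm(x, scale, eps=1e-6):
    # x: [..., features]; scale: [features] learnable per-feature gain.
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return (x / rms) * scale

x = jnp.arange(8, dtype=jnp.float32).reshape(2, 4)
y = rms_norm(x, jnp.ones((4,)))   # same shape as x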
@@ -865,6 +1014,21 @@ class RotaryEmbedding(hk.Module):
const_position: Optional[int] = None,
t: Optional[jax.Array] = None,
) -> jax.Array:
"""
Applies Rotary Position Embedding (RoPE) to the input embeddings.

This method dynamically encodes positional information by applying a rotation based on the position in the sequence, enhancing the model's ability to capture sequential dependencies.

Args:
x (jax.Array): Input embeddings to apply RoPE to, shaped as [batch_size, seq_length, embedding_dim].
seq_dim (int): The dimension index of the sequence length in the input tensor.
offset (jax.Array): The offset to apply to the position indices, useful for continuation of sequences across batches.
const_position (Optional[int]): A constant position value to use for all positions, instead of a range.
t (Optional[jax.Array]): Explicit tensor of position values, shaped as [seq_length].

Returns:
jax.Array: The input embeddings with RoPE applied, maintaining the input shape.
"""
fprop_dtype = x.dtype
# Compute the per-dimension frequencies
exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32)
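A compact sketch of one common RoPE formulation matching the description above (pairing adjacent feature dimensions; the exact variant in model.py may differ):

import jax.numpy as jnp

def rope(x, offset=0, base=10000.0):
    # x: [batch, seq_len, dim] with an even dim; each feature pair is rotated
    # by an angle that grows with the token position.
    b, t, d = x.shape
    inv_freq = 1.0 / (base ** (jnp.arange(0, d, 2, dtype=jnp.float32) / d))
    pos = jnp.arange(t, dtype=jnp.float32) + offset          # [t]
    angles = pos[:, None] * inv_freq[None, :]                 # [t, d // 2]
    sin, cos = jnp.sin(angles), jnp.cos(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return rotated.reshape(b, t, d).astype(x.dtype)

y = rope(jnp.ones((1, 5, 8)))   # same shape as the input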
@@ -897,6 +1061,23 @@ class RotaryEmbedding(hk.Module):


class MultiHeadAttention(hk.Module):
"""
Implements the Multi-Head Attention mechanism, a key component of Transformer architectures.

This module allows each token in the input sequence to attend to all tokens in the key and value sequences, with multiple "heads" learning different attention patterns.

Attributes:
num_q_heads (int): The number of heads for the queries.
num_kv_heads (int): The number of heads for the keys and values.
key_size (int): The size of the key vectors.
with_bias (bool): Whether to include bias terms in the linear transformations.
value_size (Optional[int]): The size of the value vectors, defaults to the key size if not specified.
model_size (Optional[int]): The size of the output dimension from the attention mechanism, defaults to the total size of all heads if not specified.
attn_output_multiplier (float): A scaling factor applied to the output of the attention mechanism.
data_axis (Union[str, Tuple[str, ...]]): The axis names over which data is sharded for distributed computation.
model_axis (Union[str, Tuple[str, ...]]): The axis names over which model parameters are sharded for distributed computation.
name (Optional[str]): An optional name for the module.
"""
def __init__(
self,
num_q_heads: int,
@@ -911,6 +1092,21 @@ class MultiHeadAttention(hk.Module):
model_axis: Union[str, Tuple[str, ...]] = "model",
name: Optional[str] = None,
):
"""
Initializes the Multi-Head Attention module with the provided configuration parameters.

Args:
num_q_heads (int): Number of query heads for the multi-head attention mechanism.
num_kv_heads (int): Number of key/value heads for the multi-head attention mechanism.
key_size (int): Dimensionality of key vectors in the attention mechanism.
with_bias (bool): Whether to include a bias term in the attention weight computation.
value_size (Optional[int]): Dimensionality of value vectors, defaults to key size if not specified.
model_size (Optional[int]): Overall size of the model's output dimension.
attn_output_multiplier (float): Multiplier for scaling the output of the attention mechanism.
data_axis (Union[str, Tuple[str, ...]]): Data sharding axis names for distributed computation.
model_axis (Union[str, Tuple[str, ...]]): Model parameter sharding axis names for distributed computation.
name (Optional[str]): An optional name for the module.
"""
super().__init__(name=name)
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
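Because num_q_heads and num_kv_heads may differ, several query heads can share one key/value head; a shape-level, hypothetical sketch of that grouped attention pattern (standard 1/sqrt(d) scaling, not the module's own logic):

import jax
import jax.numpy as jnp

def grouped_attention(q, k, v):
    # q: [batch, seq, q_heads, d]; k, v: [batch, seq, kv_heads, d], q_heads % kv_heads == 0
    b, t, q_heads, d = q.shape
    kv_heads = k.shape[2]
    group = q_heads // kv_heads
    # Split query heads into groups, one group per shared key/value head.
    q = q.reshape(b, t, kv_heads, group, d)
    logits = jnp.einsum("btkgd,bskd->bkgts", q, k) * (d ** -0.5)
    # Causal mask: query position t may only attend to key positions s <= t.
    causal = jnp.tril(jnp.ones((t, t), dtype=jnp.bool_))
    logits = jnp.where(causal[None, None, None, :, :], logits, -1e30)
    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bkgts,bskd->btkgd", weights, v)
    return out.reshape(b, t, q_heads * d)   # ready for a final output projection

out = grouped_attention(jnp.ones((1, 4, 8, 16)),
                        jnp.ones((1, 4, 2, 16)),
                        jnp.ones((1, 4, 2, 16)))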
@@ -1121,6 +1317,22 @@ class MultiHeadAttention(hk.Module):
name: Optional[str] = None,
mesh: Any = None,
) -> jax.Array:
"""
Projects input embeddings into multiple head spaces for queries, keys, or values.

This internal method creates separate embeddings for each attention head by applying a linear transformation,
allowing the multi-head attention mechanism to explore different subspace relationships in the data.

Args:
x (jax.Array): Input tensor to project, shaped as [batch_size, seq_length, embedding_dim].
head_size (int): The dimensionality of each head's subspace.
num_heads (int): The number of heads to project into.
name (Optional[str]): A name for the operation, distinguishing between queries, keys, and values.
sharding (Optional[PartitionSpec]): The sharding specification for distributing computation across devices.

Returns:
jax.Array: Projected embeddings for multiple heads, shaped as [batch_size, seq_length, num_heads, head_size].
"""
y = Linear(
num_heads * head_size,
with_bias=False,
@@ -1134,7 +1346,20 @@ class MultiHeadAttention(hk.Module):

@dataclass
class MHABlock(hk.Module):
"""A MHA Block"""
"""
A specialized module encapsulating the Multi-Head Attention (MHA) operation within a transformer model.

This module orchestrates the application of MHA, including the computation of queries, keys, and values, and the subsequent attention and aggregation operations.

Attributes:
num_q_heads (int): Number of heads for the query part of the MHA.
num_kv_heads (int): Number of heads for the key and value parts of the MHA.
key_size (int): Size of the keys (and queries) in the attention mechanism.
attn_output_multiplier (float): Scaling factor applied to the output of the attention mechanism.
data_axis (Union[str, Tuple[str, ...]]): Axis names over which data is sharded for distributed computation.
model_axis (Union[str, Tuple[str, ...]]): Axis names over which model parameters are sharded for distributed computation.
mesh (Any): The SPMD mesh for parallel computation.
"""

num_q_heads: int
num_kv_heads: int
@@ -1151,6 +1376,20 @@ class MHABlock(hk.Module):
mask: jax.Array, # [B, 1, T, T] or [B, 1, 1, T] or B[1, 1, 1, 1]
layer_memory: Optional[KVMemory],
) -> MHAOutput:
"""
Processes inputs through a Multi-Head Attention block.

This method applies multi-head attention to the inputs using the provided mask for attention scoring,
and optionally utilizes cached memory for keys and values to enhance efficiency in autoregressive models.

Args:
inputs (jax.Array): Input embeddings, shaped as [batch_size, seq_length, embedding_dim].
mask (jax.Array): Attention mask, shaped as [batch_size, 1, seq_length, seq_length], to control visibility between tokens.
layer_memory (Optional[KVMemory]): Cached keys and values from previous steps for efficient attention computation in autoregressive decoding.

Returns:
MHAOutput: The output from the multi-head attention block, including the transformed embeddings and updated memory.
"""
_, _, model_size = inputs.shape
assert mask.ndim == 4, f"shape: {mask.shape}"
assert mask.shape[2] in {1, inputs.shape[1]}, str(mask.shape)
@@ -1183,6 +1422,19 @@ class MHABlock(hk.Module):

@dataclass
class DenseBlock(hk.Module):
"""
Implements a dense (fully connected) block within a transformer layer, typically following multi-head attention.

This block applies one or more linear transformations to its inputs, often including non-linear activations and potentially other operations like dropout for regularization.

Attributes:
num_q_heads (int): The number of heads for the query in the preceding MHA layer.
num_kv_heads (int): The number of key/value pairs per head in the MHA layer.
key_size (int): The size of the keys in the MHA layer.
widening_factor (float): Factor by which the dimensionality of the feed-forward network is increased.
sharding_constraint (bool): Whether to apply a sharding constraint for distributed computation.
mesh (Any): The SPMD mesh for parallel computation.
"""
num_q_heads: int
num_kv_heads: int
key_size: int
@@ -1195,6 +1447,18 @@ class DenseBlock(hk.Module):
self,
inputs: jax.Array, # [B, T, D]
) -> jax.Array: # [B, T, D]
"""
Applies a series of dense transformations to the inputs.

This method constitutes the feedforward network part of a transformer layer, applying linear transformations
followed by non-linear activations to model complex data relationships beyond what attention mechanisms capture.

Args:
inputs (jax.Array): Input embeddings from the previous layer or block, shaped as [batch_size, seq_length, embedding_dim].

Returns:
jax.Array: The output embeddings after applying the dense block transformations, maintaining the input shape.
"""
_, _, model_size = inputs.shape
h_v = Linear(
ffn_size(
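At its simplest, the feed-forward block described above projects up by the widening factor, applies a nonlinearity, and projects back down; the sketch below is a plain (ungated, unsharded) stand-in, not the real DenseBlock:

import jax
import jax.numpy as jnp

def ffn(x, w_up, w_down):
    # x: [batch, seq, model]; w_up: [model, widening * model]; w_down: [widening * model, model]
    h = jax.nn.gelu(x @ w_up)     # widen, then apply the nonlinearity
    return h @ w_down             # project back to the model dimension

model_size, widening_factor = 16, 4
x = jnp.ones((2, 3, model_size))
y = ffn(x,
        jnp.zeros((model_size, widening_factor * model_size)),
        jnp.zeros((widening_factor * model_size, model_size)))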
@@ -1230,7 +1494,28 @@ class DenseBlock(hk.Module):

@dataclass
class DecoderLayer(hk.Module):
"""A transformer stack."""
"""
Represents a single layer in the decoder stack of a transformer model. This layer processes input embeddings through
a multi-head attention mechanism followed by position-wise feed-forward networks, with normalization and skip connections
applied as per the standard transformer architecture.

Attributes:
num_q_heads (int): Number of query heads for multi-head attention.
num_kv_heads (int): Number of key/value pairs per attention head.
key_size (int): Size of keys in the attention mechanism.
num_layers (int): Total number of transformer layers.
num_experts (int): Number of experts in the Mixture of Experts layer, if used.
layer_index (Optional[int]): Index of this layer within the overall model; used for layer-specific configurations or logging.
num_selected_experts (int): Number of experts selected for each input in the MoE layer.
widening_factor (float): Factor by which the dimensionality of the feed-forward network is increased.
name (Optional[str]): An optional name for the layer.
data_axis (Union[str, Tuple[str, ...]]): Axis names for data sharding in distributed computation.
model_axis (Union[str, Tuple[str, ...]]): Axis names for model parameter sharding in distributed computation.
shard_activations (bool): Whether activations should be sharded across devices.
attn_output_multiplier (float): Scaling factor for the output of the attention mechanism.
mesh (Any): SPMD mesh for parallel computation.
"""


num_q_heads: int
num_kv_heads: int
@@ -1342,12 +1627,35 @@ class DecoderLayer(hk.Module):


class LanguageModelOutput(NamedTuple):
"""
Represents the output of the language model after processing an input sequence of tokens.

This output encapsulates both the logits representing the model's predictions for the next token
in the sequence and the updated model state, which includes any memory or state information
that needs to be carried over for generating subsequent tokens.

Attributes:
logits (jax.Array): The logits for the next token predictions, shaped as [batch_size, sequence_length, vocab_size].
model_state (Any): The updated state of the model after processing the input sequence, which may include
memory states for layers that utilize recurrence or caching for efficiency.
"""
logits: jax.Array
model_state: Any


class InOutEmbed(hk.Embed):
"""Module for embedding tokens in a low-dimensional space."""
"""
A module for embedding input tokens into a continuous, low-dimensional vector space and for projecting the outputs of
a transformer back into the vocabulary space. This module supports tying the weights between the input
embedding and the output projection for parameter efficiency.

Attributes:
vocab_size (Optional[int]): The size of the vocabulary.
embed_dim (Optional[int]): The dimensionality of the embedding vectors.
sharding (Optional[PartitionSpec]): Specifies how the embedding parameters should be sharded across
devices for distributed computation.
name (Optional[str]): An optional name for the module.
"""

def __init__(
self,
@@ -1356,6 +1664,17 @@ class InOutEmbed(hk.Embed):
sharding: Optional[P] = None,
name: Optional[str] = None,
):
"""
Initializes an embedding module that can be used for both input token embeddings and output logits projection in a transformer model.

This shared embedding layer helps reduce the number of parameters by tying the weights between the input embedding and the output projection layers.

Args:
vocab_size (Optional[int]): The size of the vocabulary.
embed_dim (Optional[int]): The dimensionality of the embedding vectors.
sharding (Optional[PartitionSpec]): The sharding specification for distributing the embedding parameters across devices.
name (Optional[str]): An optional name for the module.
"""
super().__init__(
vocab_size=vocab_size,
embed_dim=embed_dim,
@@ -1365,6 +1684,15 @@ class InOutEmbed(hk.Embed):

@property
def embeddings(self):
"""
Retrieves the embedding matrix from the module's parameters.

This method is useful for operations that need direct access to the embeddings, such as output projection in language models.

Returns:
jax.Array: The embedding matrix, shaped as [vocab_size, embed_dim].
"""

embed_mat = hk.get_parameter(
"embeddings",
[self.vocab_size, self.embed_dim],
@@ -1379,6 +1707,17 @@ class InOutEmbed(hk.Embed):
self,
inputs: jax.Array,
) -> jax.Array:
"""
Projects transformer model outputs back into the vocabulary space using the transposed embedding matrix.

This method effectively performs the inverse of the embedding operation, converting model output embeddings into logits over the vocabulary.

Args:
inputs (jax.Array): The output embeddings from the transformer model, shaped as [batch_size, seq_length, embed_dim].

Returns:
jax.Array: The logits over the vocabulary for each token position, shaped as [batch_size, seq_length, vocab_size].
"""
return jnp.dot(inputs, self.embeddings.T.astype(inputs.dtype))
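The tied input/output embedding described above can be pictured with a small stand-alone sketch (hypothetical helpers, mirroring the jnp.dot projection shown here):

import jax.numpy as jnp

def embed(tokens, embeddings):
    # tokens: [batch, seq] int32; embeddings: [vocab, embed_dim]
    return embeddings[tokens]                              # [batch, seq, embed_dim]

def decode(hidden, embeddings):
    # hidden: [batch, seq, embed_dim] -> logits over the vocabulary, via the transposed matrix
    return jnp.dot(hidden, embeddings.T.astype(hidden.dtype))

emb = jnp.ones((32, 8))
logits = decode(embed(jnp.zeros((1, 4), jnp.int32), emb), emb)   # [1, 4, 32]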
@@ -1458,18 +1797,21 @@ def layer_norm(x, model):
@dataclass
class LanguageModel(hk.Module):
"""
A transformer-based language model for generating or evaluating sequences of tokens.
A high-level module for autoregressive language modeling using a Transformer architecture. This module
integrates components such as embedding layers, transformer blocks, and output layers to process sequences
of tokens and generate predictions for the next tokens in the sequence.

This class encapsulates a transformer model and provides methods for its initialization,
running the model forward to generate logits, and handling memory states for efficient
autoregressive generation.
The LanguageModel is designed for tasks such as text generation, where it can be used to produce coherent
and contextually relevant text based on a given prompt.

Attributes:
model (Transformer): The underlying transformer model.
config (LanguageModelConfig): Configuration for the language model.
fprop_dtype (Any): Data type for forward propagation computations.
name (Optional[str]): Optional name for the module.
mesh (Any): The SPMD mesh for parallel computation, if applicable.
model (Transformer): The core transformer model used for processing input token sequences.
config (LanguageModelConfig): Configuration parameters for the language model, including details about
the architecture, embeddings, and output processing.
fprop_dtype (Any): The data type to use for forward propagation calculations, typically set to jnp.bfloat16
for efficiency.
name (Optional[str]): An optional name for the module. Useful for distinguishing between multiple instances.
mesh (Any): The SPMD mesh for parallel computation, supporting distributed training and inference.
"""

model: "Transformer"
@@ -1564,9 +1906,31 @@ class LanguageModel(hk.Module):
)

def init_memory(self, batch_size: int, seq_len: int, dtype=jnp.bfloat16):
"""
Initializes the memory for the language model, suitable for autoregressive sequence generation tasks.

Args:
batch_size (int): The number of sequences for which to initialize memory.
seq_len (int): The length of sequences for memory allocation.
dtype (Any): Data type for the memory arrays, typically jnp.bfloat16.

Returns:
Memory: The initialized memory structure for storing keys, values, and decoding steps across transformer layers.
"""
return self.model.init_memory(batch_size=batch_size, sequence_len=seq_len, dtype=dtype)

def prefill_memory(self, prompts, memory):
"""
Optionally pre-fills the transformer's memory with information from a provided prompt, enhancing efficiency
in subsequent autoregressive generation steps by caching necessary computations from the prompt processing.

Args:
prompts (jax.Array): The prompt tokens from which to generate subsequent tokens.
memory (Memory): The memory state to update with information derived from the prompts.

Returns:
Tuple[jax.Array, Memory]: The logits produced by processing the prompts and the updated memory state.
"""
# Pad to the left and right align?
# Basically assume prompt is already padded
model_output = self(prompts, memory=memory)
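Putting memory initialization and prompt prefilling together, autoregressive generation typically follows a prefill-then-step loop; the sketch below uses a hypothetical model_fn(tokens, memory) -> (logits, memory) stand-in rather than the module's real API:

import jax.numpy as jnp

def greedy_generate(model_fn, prompt_tokens, memory, num_steps):
    # Prefill: run the whole prompt once so its keys/values are cached in memory.
    logits, memory = model_fn(prompt_tokens, memory)
    next_token = jnp.argmax(logits[:, -1, :], axis=-1)[:, None]   # [batch, 1]
    generated = [next_token]
    for _ in range(num_steps - 1):
        # Each subsequent call only feeds the newest token; the cache supplies the rest.
        logits, memory = model_fn(next_token, memory)
        next_token = jnp.argmax(logits[:, -1, :], axis=-1)[:, None]
        generated.append(next_token)
    return jnp.concatenate(generated, axis=1)

# Dummy stand-in model for demonstration: uniform logits, memory passed through unchanged.
dummy = lambda tokens, memory: (jnp.zeros((tokens.shape[0], tokens.shape[1], 10)), memory)
out = greedy_generate(dummy, jnp.zeros((1, 3), jnp.int32), memory=None, num_steps=4)  # [1, 4]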
@@ -1576,28 +1940,26 @@ class LanguageModel(hk.Module):
@dataclass
class Transformer(hk.Module):
"""
Core transformer model class implementing a stack of transformer layers.

This class is designed to be flexible and configurable, supporting features like
multi-head attention, feed-forward networks, and optional mixture of experts layers.
It is capable of processing sequences of embeddings and returning transformed sequences
of embeddings, along with updated memory states for autoregressive tasks.
Core transformer module that implements the foundational architecture of a transformer-based model. This module
is capable of processing sequences of embeddings through multiple layers of self-attention and feed-forward
networks, optionally including advanced techniques like mixture of experts (MoE) and activation sharding
for efficient large-scale parallel computation.

Attributes:
num_q_heads (int): Number of query heads for multi-head attention.
num_kv_heads (int): Number of key/value pairs per attention head.
key_size (int): Size of keys in the attention mechanism.
widening_factor (float): Widening factor for the dimensionality of the feed-forward network.
init_scale (float): Scale for parameter initialization.
num_q_heads (int): Number of heads in the query part of the multi-head attention mechanism.
num_kv_heads (int): Number of heads for the keys and values in the multi-head attention.
key_size (int): Dimensionality of the key (and query) vectors in the attention mechanism.
widening_factor (float): Factor by which to widen the dimensionality of the feed-forward network relative to the embeddings.
init_scale (float): Initial scale for parameter initialization.
mesh (Any): The SPMD mesh for parallel computation.
attn_output_multiplier (float): Multiplier for the output of the attention mechanism.
shard_activations (bool): Whether to shard activations across specified axes.
num_layers (int): Number of transformer layers.
num_experts (int): Number of experts for mixture of experts layers, if applicable.
num_selected_experts (int): Number of experts selected per token, for MoE layers.
name (Optional[str]): Name of the transformer model.
data_axis (Union[str, Tuple[str, ...]]): Axis names for data sharding.
model_axis (Union[str, Tuple[str, ...]]): Axis names for model parameter sharding.
shard_activations (bool): Whether to shard activations across devices in distributed settings.
num_layers (int): Number of transformer layers to stack in the model.
num_experts (int): Number of experts in the MoE layer, if used.
num_selected_experts (int): Number of experts selected for each input token in the MoE layer.
data_axis (Union[str, Tuple[str, ...]]): Axis names for sharding data across devices.
model_axis (Union[str, Tuple[str, ...]]): Axis names for sharding model parameters across devices.
name (Optional[str]): An optional name for the module.
"""

num_q_heads: int