# Copyright 2024 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Grok-1

This architecture is designed for language modeling and autoregressive sequence
generation, incorporating advanced features such as:

- 8-bit quantized weights: quantizes model parameters to reduce the memory
  footprint and potentially increase computational efficiency.
- Sharding and distributed computation: uses JAX's capabilities for distributed
  computation across devices, optimizing parallel processing and memory usage.
- Memory caching for autoregressive decoding: caches keys and values in the
  attention layers, making sequence generation more efficient.

Core components:
- Multi-Head Attention (MHA): a custom implementation, including rotary
  positional embeddings for implicit sequence position information.
- Mixture of Experts (MoE): routes inputs to different "expert" networks based
  on the input data, increasing model capacity and expressiveness.
- Feed-Forward Networks (FFNs): networks with adjustable sizes, support for
  quantized weights, and custom linear layers with sharding.

Training and inference workflow:
- Manages the model's parameters, caching mechanisms, and layer-wise states
  efficiently during both training and inference.
- Implements detailed sharding strategies for distributing computation and
  storage across multiple devices.

Advanced features:
- Custom layer normalization: adopts RMSNorm to stabilize the training of deep
  networks.
- Dynamic and static sharding: offers flexibility in data and model
  parallelism, allowing dynamic adjustment of sharding constraints.

Efficiency and scalability:
- Manages data flow and computation efficiently, minimizing unnecessary data
  replication and movement across devices.
- Designed with scalability in mind, providing a foundation for training
  complex models on massive datasets.

This architecture leverages JAX for high-performance numerical computing and
automatic differentiation, alongside Haiku for modular and flexible deep
learning model construction.

Data flow through the training process:
1. Input preparation:
   - Text is tokenized into numerical tokens that represent words or subwords
     in a vocabulary.
   - Tokens are batched and padded to a consistent sequence length within each
     batch, forming the initial input tensor for the model.
2. Embedding layer:
   - Input tokens pass through an embedding layer that maps each token to a
     high-dimensional vector. Embeddings may be pre-trained or learned during
     training.
   - Positional encodings or embeddings are added to these vectors to
     incorporate sequence position information.
3. Transformer layers:
   - The sequence of embedding vectors is processed through multiple
     Transformer layers, each consisting of the following sub-layers:
     a. Multi-Head Attention (MHA): each layer computes self-attention over its
        input, letting every position attend to all positions in the previous
        layer's output.
     b. Feed-Forward Network (FFN): after attention, the output at each
        position passes through a feed-forward network. FFNs are identical
        across positions but have different parameters from layer to layer.
   - Between sub-layers, residual connections followed by layer normalization
     are applied, which helps stabilize the training of deep networks.
4. Caching and memory:
   - For autoregressive tasks, where the model generates sequences one token at
     a time, the keys and values computed during attention are cached so they
     can be reused in subsequent steps, reducing the computational load.
5. Output layer:
   - The output of the final Transformer layer is passed through a linear layer
     or decoder, projecting the high-dimensional representations back into the
     vocabulary space to produce logits for each token in the sequence.
6. Loss calculation and backpropagation:
   - The logits are compared against the true token sequences using a suitable
     loss function (e.g., cross-entropy for language modeling). The loss
     quantifies the model's prediction accuracy.
   - Gradients of the loss with respect to every parameter are computed via
     backpropagation; they indicate how each parameter should be adjusted to
     minimize the loss.
7. Parameter update:
   - Model parameters are updated using an optimization algorithm (e.g., Adam,
     SGD) and the computed gradients, adjusting the model's weights to improve
     its predictions on the next iteration.
8. Iteration and convergence:
   - Steps 1 through 7 are repeated for multiple epochs over the training
     dataset until the model's performance on a validation set converges or
     begins to degrade, indicating potential overfitting.

This structured approach, combined with the Transformer architecture's
capability to model complex dependencies in data, enables effective training of
models for tasks such as language understanding, translation, and generation.
"""
import functools
import logging
import re
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import haiku as hk
import jax
import jax.experimental.maps
import jax.numpy as jnp
from jax import config, tree_util
from jax.experimental.shard_map import shard_map
from jax.lax import with_sharding_constraint as pjit_sharding_constraint
from jax.sharding import PartitionSpec
from jax.sharding import PartitionSpec as P

config.update("jax_spmd_mode", "allow_all")

logger = logging.getLogger(__name__)
rank_logger = logging.getLogger("rank")

@dataclass
class QuantizedWeight8bit:
    """
    Represents an 8-bit quantized weight for neural network parameters.

    Attributes:
        weight (jnp.array): The quantized weights.
        scales (jnp.array): The scale factors used for quantization.
    """

    weight: jnp.array
    scales: jnp.array

    @property
    def shape(self):
        return self.weight.shape


tree_util.register_pytree_node(
    QuantizedWeight8bit,
    lambda qw: ([qw.weight, qw.scales], ()),
    lambda _, children: QuantizedWeight8bit(children[0], children[1]),
)
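
# Illustrative sketch (added for exposition; `_example_dequantize` is not part
# of the original module): the forward passes in this file recover an
# approximate full-precision weight by multiplying the quantized tensor with
# its scales, mirroring the `w.astype(s.dtype) * s` multiplication used by the
# quantized matmuls later in this file.
def _example_dequantize(qw: QuantizedWeight8bit) -> jax.Array:
    """Recovers an approximate full-precision weight from its 8-bit form."""
    return qw.weight.astype(qw.scales.dtype) * qw.scales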

class TrainingState(NamedTuple):
    """Container for the training state, encapsulating model parameters.

    Attributes:
        params (hk.Params): The parameters of the model.
    """

    params: hk.Params

def _match(qs, ks):
    """
    Determines whether a sequence of regex patterns (qs) matches any contiguous
    subsequence of strings (ks). This utility is typically used to match
    parameter names or paths in a hierarchical structure.

    Args:
        qs (Sequence[str]): A sequence of regex patterns to match.
        ks (Tuple[str, ...]): A tuple of strings against which the patterns are matched.

    Returns:
        bool: True if every pattern in qs matches the corresponding string in some
            contiguous subsequence of ks, otherwise False.
    """
    # Compile regexes and force complete match.
    qts = tuple(map(lambda x: re.compile(x + "$"), qs))
    for i in range(len(ks) - len(qs) + 1):
        matches = [x.match(y) for x, y in zip(qts, ks[i:])]
        if matches and all(matches):
            return True
    return False
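
# Illustrative example (added for exposition; `_example_match_usage` is
# hypothetical): `_match` slides the pattern window over the key path, so a
# rule like ("decoder_layer_[0-9]+", "linear") matches the path below.
def _example_match_usage() -> bool:
    path = ("transformer", "decoder_layer_7", "linear", "w")
    return _match(("decoder_layer_[0-9]+", "linear"), path)  # True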

def with_sharding_constraint(x, constraint):
    """
    Applies a sharding constraint to a JAX array. In SPMD programs this hints how
    the data should be partitioned across devices. If no physical mesh is
    available, the original array is returned unchanged.

    Args:
        x (jax.Array): The array to apply the sharding constraint to.
        constraint (PartitionSpec): The sharding constraint to apply.

    Returns:
        jax.Array: The array with the sharding constraint applied, affecting its
            distribution across devices in distributed computation setups.
    """
    if jax.experimental.maps.thread_resources.env.physical_mesh.empty:
        return x
    else:
        return pjit_sharding_constraint(x, constraint)

def cast_bfloat16(x):
    """
    Casts the input array to bfloat16 if it is of floating-point type. This is
    often used to reduce memory consumption and potentially increase computation
    speed by using lower precision.

    Args:
        x (jax.Array): The input array.

    Returns:
        jax.Array: The array cast to bfloat16 if the original array was
            floating-point; otherwise, the array is returned unchanged.
    """
    if x.dtype.kind == "f":
        return x.astype(jnp.bfloat16)
    else:
        return x

def ffn_size(emb_size, widening_factor):
    """
    Calculates the size of the feed-forward network (FFN) from the embedding size
    and a widening factor. The result is rounded up to a multiple of 8 for
    efficiency on hardware.

    Args:
        emb_size (int): The size of the embeddings.
        widening_factor (float): The factor by which to widen the FFN relative to the embedding size.

    Returns:
        int: The adjusted size of the FFN.
    """
    _ffn_size = int(widening_factor * emb_size) * 2 // 3
    _ffn_size = _ffn_size + (8 - _ffn_size) % 8  # ensure it's a multiple of 8
    logger.debug(f"emb_size: {emb_size} adjusted ffn_size: {_ffn_size}")
    return _ffn_size
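
# Worked example (added for exposition): with emb_size=6144 and
# widening_factor=8.0 (values resembling the published Grok-1 configuration),
# int(8.0 * 6144) * 2 // 3 == 32768, which is already a multiple of 8, so
# ffn_size(6144, 8.0) == 32768.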

def apply_rules(rules):
    """
    Constructs a function that applies a set of sharding rules to transformer
    parameters. The returned function determines the sharding specification for
    each parameter based on its role and position within the model architecture.

    Args:
        rules (List[Tuple[Sequence[str], PartitionSpec]]): A list of tuples where each tuple
            contains a sequence of strings representing the parameter path and the
            corresponding `PartitionSpec` to apply.

    Returns:
        Callable: A function that takes a parameter path and returns the appropriate
            `PartitionSpec` based on the provided rules.
    """

    def _apply_rules(path, value):
        del value  # Unused.
        path_list = [str(i.key).split("/") for i in path if isinstance(i, jax.tree_util.DictKey)]
        flattened_path = jax.tree_util.tree_flatten(path_list)[0]
        for rule, replacement in rules:
            if _match(rule, flattened_path):
                if isinstance(replacement, PartitionSpec):
                    if "layer_stack" in flattened_path:
                        replacement = PartitionSpec(None, *replacement)
                rank_logger.debug(f"Apply {replacement} to {flattened_path} with rule {rule}")
                return replacement
        rank_logger.info(f"{flattened_path} no matching found!")
        return None

    return _apply_rules

TRANSFORMER_PARTITION_RULES = [
    # attention
    (("multi_head_attention", "(query|key|value)", "w"), P("data", "model")),
    (("multi_head_attention", "(query|key|value)", "b"), P(None)),
    (("multi_head_attention", "linear", "w"), P("model", "data")),
    (("multi_head_attention", "linear", "b"), P(None)),
    # mlp
    ((r"decoder_layer_[0-9]+", "linear", "w"), P("data", "model")),
    ((r"decoder_layer_[0-9]+", "linear", "b"), P(None)),
    ((r"decoder_layer_[0-9]+", "linear_v", "w"), P("data", "model")),
    ((r"decoder_layer_[0-9]+", "linear_v", "b"), P(None)),
    (
        (r"decoder_layer_[0-9]+", "linear_1", "w"),
        P(
            "model",
            "data",
        ),
    ),
    ((r"decoder_layer_[0-9]+", "linear_1", "b"), P(None)),
    # layer norms
    ((r"decoder_layer_[0-9]+", "layer_norm", "offset"), P(None)),
    ((r"decoder_layer_[0-9]+", "layer_norm", "scale"), P(None)),
    ((r"decoder_layer_[0-9]+", "layer_norm_1", "offset"), P(None)),
    ((r"decoder_layer_[0-9]+", "layer_norm_1", "scale"), P(None)),
    # rms norms
    ((r"decoder_layer_[0-9]+", "rms_norm", "scale"), P(None)),
    ((r"decoder_layer_[0-9]+", "rms_norm_1", "scale"), P(None)),
    ((r"decoder_layer_[0-9]+", "rms_norm_2", "scale"), P(None)),
    ((r"decoder_layer_[0-9]+", "rms_norm_3", "scale"), P(None)),
    # router
    (("router", "w"), P("data")),
    # moe mlp
    (("moe", "linear", "w"), P(None, "data", "model")),
    (("moe", "linear", "b"), P(None)),
    (("moe", "linear_v", "w"), P(None, "data", "model")),
    (("moe", "linear_v", "b"), P(None)),
    (("moe", "linear_1", "w"), P(None, "model", "data")),
    (("moe", "linear_1", "b"), P(None)),
    # layer norms
    (("moe", "layer_norm", "offset"), P(None)),
    (("moe", "layer_norm", "scale"), P(None)),
    (("moe", "layer_norm_1", "offset"), P(None)),
    (("moe", "layer_norm_1", "scale"), P(None)),
    # rms norms
    (("moe", "rms_norm", "scale"), P(None)),
    (("moe", "rms_norm_1", "scale"), P(None)),
    (("moe", "rms_norm_2", "scale"), P(None)),
    (("moe", "rms_norm_3", "scale"), P(None)),
]

LM_PARTITION_RULES = [
    # Embedding layer.
    (
        ("language_model", "positional_embeddings"),
        P(None, ("data", "model")),
    ),
    (
        ("language_model", "in_out_embed", "embeddings"),
        P(None, ("data", "model")),
    ),
    # Final RMSNorm.
    (("language_model", "rms_norm"), P(None)),
]
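
# Illustrative sketch (added for exposition; `_example_partition_specs` is
# hypothetical, and the traversal via `jax.tree_util.tree_map_with_path` is an
# assumption about how the rule tables above are typically consumed):
def _example_partition_specs(params: hk.Params):
    """Maps every parameter in a pytree to its PartitionSpec (or None)."""
    spec_fn = apply_rules(TRANSFORMER_PARTITION_RULES + LM_PARTITION_RULES)
    return jax.tree_util.tree_map_with_path(spec_fn, params)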

TOP_K = 8


class KVMemory(NamedTuple):
    """
    Represents key-value memory slots for a transformer layer, supporting efficient
    autoregressive decoding by caching previously computed keys and values.

    Attributes:
        k (Optional[jax.Array]): Cached keys, shaped as [batch_size, sequence_len, num_kv_heads, key_size].
        v (Optional[jax.Array]): Cached values, shaped as [batch_size, sequence_len, num_kv_heads, key_size].
        step (Optional[jax.Array]): The current decoding step, indicating how many positions have been generated, shaped as [batch_size].
    """

    k: Optional[jax.Array]
    v: Optional[jax.Array]
    step: Optional[jax.Array]

def init_layer_memories(
    batch_size: int,
    sequence_len: int,
    num_kv_heads: int,
    key_size: int,
    num_layers: int,
    step: Optional[jax.Array] = None,
    dtype=jnp.bfloat16,
):
    """
    Initializes the key-value memory of each transformer layer, enabling efficient
    sequence generation by caching keys and values.

    Args:
        batch_size (int): The number of sequences processed in parallel.
        sequence_len (int): The sequence length for which memory is allocated.
        num_kv_heads (int): The number of key-value heads in the attention mechanism.
        key_size (int): The size of each key (and value) vector.
        num_layers (int): The number of transformer layers for which memory is initialized.
        step (Optional[jax.Array]): The initial decoding step for each sequence in the batch.
            Defaults to None, indicating no prior steps.
        dtype (Any): The data type of the memory arrays, typically jnp.bfloat16 for efficiency.

    Returns:
        List[KVMemory]: A list of initialized KVMemory instances, one per layer.
    """
    return [
        KVMemory(
            k=jnp.zeros((batch_size, sequence_len, num_kv_heads, key_size), dtype=dtype),
            v=jnp.zeros((batch_size, sequence_len, num_kv_heads, key_size), dtype=dtype),
            step=step,
        )
        for _ in range(num_layers)
    ]
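
# Illustrative example (added for exposition; the shapes follow the constructor
# above, while the concrete sizes are hypothetical):
def _example_empty_cache() -> List[KVMemory]:
    """Allocates an all-zero KV cache for a small 2-layer decoder."""
    return init_layer_memories(
        batch_size=1,
        sequence_len=256,
        num_kv_heads=8,
        key_size=128,
        num_layers=2,
        step=jnp.zeros(1, dtype=jnp.int32),  # start decoding at position 0
    )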

class Memory(NamedTuple):
    """
    The complete memory state of a transformer model, encapsulating the key-value
    memory slots of all layers.

    Attributes:
        layers (List[KVMemory]): A list of KVMemory instances, one for each layer in the transformer model.
    """

    # Self-attention key/value cache.
    layers: List[KVMemory]

class Router(hk.Module):
    """
    A module for routing inputs to experts in a Mixture of Experts (MoE) layer.

    Attributes:
        num_selected_experts (int): Number of experts selected for each input.
        data_axis (str | Tuple[str, ...]): The name(s) of the data axis for sharding.
        model_axis (str | Tuple[str, ...]): The name(s) of the model axis for sharding.
        shard_activations (bool): If True, shard activations according to the data and model axes.
        mesh (Any): The SPMD mesh for parallel computation.
        name (str): The name of the module.
    """

    def __init__(
        self,
        num_selected_experts: int,
        data_axis: Union[str, Tuple[str, ...]] = "data",
        model_axis: Union[str, Tuple[str, ...]] = "model",
        shard_activations: bool = False,
        mesh: Any = None,
        name: str = "router",
    ):
        """
        Initializes a router that directs inputs to experts in a Mixture of Experts (MoE) layer.

        Args:
            num_selected_experts (int): The number of experts to select for each input.
            data_axis (Union[str, Tuple[str, ...]]): The axis names over which data is sharded.
            model_axis (Union[str, Tuple[str, ...]]): The axis names over which model parameters are sharded.
            shard_activations (bool): Whether to shard activations according to the data and model axes.
            mesh (Any): The SPMD mesh object for parallel computation.
            name (str): An optional name for the module.
        """
        super().__init__(name)
        self.shard_activations = shard_activations
        self.data_axis = data_axis
        self.model_axis = model_axis
        self.mesh = mesh
        self.num_selected_experts = num_selected_experts

    def compute_routing_prob(
        self, inputs: jax.Array, padding_mask: Optional[jax.Array], num_experts: int
    ):
        """
        Computes the probability of routing each input to each expert. The underlying
        logits, learned during training, determine how inputs are distributed among
        the experts.

        Args:
            inputs (jax.Array): Input data to be routed.
            padding_mask (Optional[jax.Array]): Mask indicating which inputs are padding and should not be considered.
            num_experts (int): The total number of experts available.

        Returns:
            A tuple containing the routing probabilities, the logits, and a dummy placeholder for compatibility.
        """
        return self._compute_routing_prob(inputs, padding_mask, num_experts)

    @hk.transparent
    def _compute_routing_prob(
        self,
        inputs: jax.Array,
        padding_mask: Optional[jax.Array],
        num_experts: int,
    ):
        """
        Computes the routing probabilities for directing inputs to the appropriate experts.

        Args:
            inputs (jax.Array): Input data to be routed, shaped as [batch_size, ..., input_dim].
            padding_mask (Optional[jax.Array]): An optional mask indicating padded elements in the input,
                shaped as [batch_size, seq_length], where padded positions are False.
            num_experts (int): The total number of experts available for routing.

        Returns:
            A tuple containing routing probabilities, routing logits, and a dummy value for compatibility,
            shaped as ([batch_size, seq_length, num_experts], [batch_size, seq_length, num_experts], int).
        """
        # Using fp32 for the routing prob computation.
        inputs = jax.lax.convert_element_type(inputs, jnp.float32)
        # [batch_size, seq_len, num_experts]
        routing_logits = self._router_weights(inputs, num_experts, sharding=P("data"))
        assert routing_logits.dtype == jnp.float32
        routing_probs = jax.nn.softmax(routing_logits)
        if padding_mask is not None:
            routing_probs *= padding_mask
        return routing_probs, routing_logits, 0

    @hk.transparent
    def _router_weights(
        self,
        x: jax.Array,
        num_experts: int,
        sharding: Optional[P] = None,
    ):
        fprop_dtype = x.dtype
        if not x.shape:
            raise ValueError("Input must not be scalar.")
        input_size = self.input_size = x.shape[-1]
        w = hk.get_parameter(
            "w", [input_size, num_experts], jnp.float32, init=hk.initializers.Constant(0)
        )
        if sharding:
            w = with_sharding_constraint(w, sharding)
        out = jnp.dot(x, w.astype(fprop_dtype))
        return out
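
# Illustrative example (added for exposition; `_example_top_k_routing` is
# hypothetical): the MoE layer below consumes the router output by keeping only
# the top-k probabilities per token, via `jax.lax.top_k(routing_probs, k=...)`.
def _example_top_k_routing(
    routing_probs: jax.Array, k: int = 2
) -> Tuple[jax.Array, jax.Array]:
    """Selects the k highest-probability experts (and their gates) per token."""
    expert_gate, expert_index = jax.lax.top_k(routing_probs, k=k)
    return expert_gate, expert_index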

class MoELayer(hk.Module):
    """
    A module implementing a Mixture of Experts (MoE) layer.

    Attributes:
        num_experts (int): The number of experts in the MoE layer.
        layer_fn (Callable): The function applied by each expert.
        router (Router): The router that routes inputs to experts.
        mesh (Any): The SPMD mesh for parallel computation.
        shard_activations (bool): If True, shard activations across data and model axes.
        data_axis (str | Tuple[str, ...]): The name(s) of the data axis for sharding.
        model_axis (str | Tuple[str, ...]): The name(s) of the model axis for sharding.
        name (Optional[str]): The name of the module.
    """

    def __init__(
        self,
        num_experts: int,
        layer_fn: Callable,
        router: Router,
        mesh: Any = None,
        shard_activations: bool = False,
        data_axis: Union[str, Tuple[str, ...]] = "data",
        model_axis: Union[str, Tuple[str, ...]] = "model",
        name: Optional[str] = "moe",
    ):
        """
        Initializes a Mixture of Experts layer with the specified configuration.

        Args:
            num_experts (int): The total number of experts in the MoE layer.
            layer_fn (Callable): The function defining the computation performed by each expert.
            router (Router): The router that directs inputs to the selected experts.
            mesh (Any): The optional SPMD mesh for parallel computation.
            shard_activations (bool): Whether to shard activations across distributed resources.
            data_axis (Union[str, Tuple[str, ...]]): Specifies how data is sharded for distributed computation.
            model_axis (Union[str, Tuple[str, ...]]): Specifies how model parameters are sharded.
            name (Optional[str]): An optional name for the module.
        """
        super().__init__(name)
        self.num_experts = num_experts
        self.layer_fn = layer_fn
        self.router = router
        self.mesh = mesh
        self.shard_activations = shard_activations
        self.data_axis = data_axis
        self.model_axis = model_axis

    @hk.transparent
    def _inference_call(self, inputs: jax.Array, padding_mask: Optional[jax.Array] = None):
        """
        Handles an inference call to the MoE layer, distributing inputs to the
        experts selected by the router.

        Args:
            inputs (jax.Array): Input data to be processed, shaped as [batch_size, seq_length, input_dim].
            padding_mask (Optional[jax.Array]): An optional mask for the inputs, where False indicates
                positions that should not be processed (e.g., padding), shaped as [batch_size, seq_length].

        Returns:
            jax.Array: The processed outputs after passing through the selected experts, shaped as
                [batch_size, seq_length, output_dim].
        """
        routing_probs, _, _ = self.router.compute_routing_prob(
            inputs, padding_mask, self.num_experts
        )
        expert_gate, expert_index = jax.lax.top_k(routing_probs, k=self.router.num_selected_experts)
        tmp = jnp.reshape(inputs, (inputs.shape[0] * inputs.shape[1], inputs.shape[2]))
        broad_inputs = jnp.tile(tmp[:, jnp.newaxis, :], (1, self.router.num_selected_experts, 1))
        broad_inputs = jnp.reshape(
            broad_inputs, (broad_inputs.shape[0] * broad_inputs.shape[1], broad_inputs.shape[2])
        )
        init_fn, _ = hk.transform(self.layer_fn)
        vmapped_init_fn = jax.vmap(init_fn, in_axes=0, out_axes=0)
        lifted_init_fn = hk.experimental.transparent_lift(vmapped_init_fn)
        # Fetch the vmapped params of the DenseBlock.
        params = lifted_init_fn(
            jax.random.split(jax.random.PRNGKey(1), self.num_experts),
            jnp.zeros((self.num_experts, 1, 1, inputs.shape[-1])),
        )

        # Index and prob are in the shape [m, 2], indicating which token is assigned to which experts.
        # b: num_expert
        # m: token or sequence dim
        # k: input embed dim
        # n: output embed dim
        # e: the number of experts chosen for each token
        @functools.partial(
            shard_map,
            mesh=self.mesh,
            in_specs=(
                P(self.data_axis, None),
                P(None, None, self.model_axis),
                P(None, None, self.model_axis),
                P(None),
                P(None),
            ),
            out_specs=P(self.data_axis, self.model_axis),
            check_rep=False,
        )
        def moe_slow_matmul1(input, weight, scales, index, prob):
            weight = weight * scales
            one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0)
            all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight)
            output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output)
            return output

        @functools.partial(
            shard_map,
            mesh=self.mesh,
            in_specs=(
                P(self.data_axis, self.model_axis),
                P(None, self.model_axis, None),
                P(None, self.model_axis, None),
                P(None),
                P(None),
            ),
            out_specs=P(self.data_axis, None),
            check_rep=False,
        )
        def moe_slow_matmul2(input, weight, scales, index, prob):
            weight = weight * scales
            one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0)
            all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight)
            output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output)
            return jax.lax.psum(output, axis_name="model")

        if hasattr(params["linear"]["w"], "scales"):
            x = moe_slow_matmul1(
                broad_inputs,
                params["linear_v"]["w"].weight,
                params["linear_v"]["w"].scales,
                expert_index,
                expert_gate,
            )
            y = moe_slow_matmul1(
                broad_inputs,
                params["linear"]["w"].weight,
                params["linear"]["w"].scales,
                expert_index,
                expert_gate,
            )
            y = jax.nn.gelu(y)
            out = moe_slow_matmul2(
                x * y,
                params["linear_1"]["w"].weight,
                params["linear_1"]["w"].scales,
                expert_index,
                expert_gate,
            )
            out = jnp.reshape(
                out,
                [
                    inputs.shape[0],
                    inputs.shape[1],
                    self.router.num_selected_experts,
                    out.shape[-1],
                ],
            )
            out = expert_gate[:, :, :, None].astype(jnp.bfloat16) * out
            out = jnp.sum(out, axis=2)
            out = out.astype(jnp.bfloat16)
        else:
            # This is only here so that we can construct a valid init_fn with this code.
            return inputs
        return out

    def __call__(self, inputs: jax.Array, padding_mask: jax.Array):
        return self._inference_call(inputs)
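
# Illustrative sketch (added for exposition; `_example_one_hot_dispatch` is
# hypothetical): the one-hot einsum pattern used in `moe_slow_matmul1/2` above
# is equivalent to computing every expert's output and then picking, for each
# token, the row of its assigned expert.
def _example_one_hot_dispatch(
    inputs: jax.Array, weights: jax.Array, index: jax.Array
) -> jax.Array:
    """inputs: [m, k]; weights: [b, k, n]; index: [m] expert id per token."""
    num_experts = weights.shape[0]
    one_hot = jax.nn.one_hot(index, num_experts, axis=0)  # [b, m]
    all_out = jnp.einsum("mk,bkn->bmn", inputs, weights)  # every expert, every token
    return jnp.einsum("bm,bmn->mn", one_hot, all_out)  # keep assigned expert only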

class MHAOutput(NamedTuple):
    """
    Represents the output of the Multi-Head Attention (MHA) operation.

    Attributes:
        embeddings (jax.Array): The output embeddings from the MHA layer.
        memory (Any): The updated memory state after the attention operation.
    """

    embeddings: jax.Array
    memory: Any

class DecoderOutput(NamedTuple):
    """
    Encapsulates the output of a decoder layer within the transformer model,
    including the transformed embeddings and any updated memory state.

    Attributes:
        embeddings (jax.Array): The embeddings produced by the decoder layer, shaped as [batch_size, seq_length, embedding_dim].
        memory (Any): The updated memory state after processing by the decoder layer, useful for autoregressive decoding tasks.
    """

    embeddings: jax.Array
    memory: Any

class TransformerOutput(NamedTuple):
    """
    Represents the final output of the transformer model, including the final set
    of embeddings and any memory states updated through the model's layers.

    Attributes:
        embeddings (jax.Array): The final output embeddings from the transformer, shaped as [batch_size, seq_length, embedding_dim].
        memory (Any): The final memory state of the model after all transformer layers have been applied.
    """

    embeddings: jax.Array
    memory: Any

@dataclass
class TransformerConfig:
    """
    Configuration for the Transformer model's architecture.

    This class defines key architectural parameters of the transformer, including
    the size of embeddings, the dimensionality of keys and values in the attention
    mechanism, the number of layers, and more. It also configures advanced features
    such as Mixture of Experts (MoE) and activation sharding.

    Attributes:
        emb_size (int): The size of the embedding vectors.
        key_size (int): The size of the key (and query) vectors in the attention mechanism.
        num_q_heads (int): The number of query heads in the multi-head attention mechanism.
        num_kv_heads (int): The number of key/value heads in the multi-head attention.
        num_layers (int): The total number of layers in the transformer model.
        vocab_size (int): The size of the vocabulary the model can understand.
        widening_factor (float): The factor by which the dimensionality of the feed-forward networks is widened relative to the embedding size.
        attn_output_multiplier (float): A scaling factor applied to the output of the attention mechanism to control its magnitude.
        shard_activations (bool): Whether to shard activations across devices for parallel processing.
        num_experts (int): The number of experts in the Mixture of Experts (MoE) layer, if used.
        num_selected_experts (int): The number of experts selected for each input in the MoE layer.
        data_axis (Union[str, Tuple[str, ...]]): The axis names over which data is sharded for distributed computation.
        model_axis (Union[str, Tuple[str, ...]]): The axis names over which model parameters are sharded for distributed computation.
    """

    emb_size: int
    key_size: int
    num_q_heads: int
    num_kv_heads: int
    num_layers: int
    vocab_size: int = 128 * 1024
    widening_factor: float = 4.0
    attn_output_multiplier: float = 1.0
    name: Optional[str] = None

    num_experts: int = -1
    capacity_factor: float = 1.0
    num_selected_experts: int = 1

    init_scale: float = 1.0
    shard_activations: bool = False

    # Used for activation sharding.
    data_axis: Union[str, Tuple[str, ...]] = "data"
    model_axis: Union[str, Tuple[str, ...]] = "model"

    def __post_init__(self):
        if isinstance(self.data_axis, list):
            self.data_axis = tuple(self.data_axis)
        if isinstance(self.model_axis, list):
            self.model_axis = tuple(self.model_axis)

    def partition_rules(self):
        return TRANSFORMER_PARTITION_RULES

    def make(self, mesh=None) -> "Transformer":
        data_axis = tuple(self.data_axis) if isinstance(self.data_axis, list) else self.data_axis
        model_axis = (
            tuple(self.model_axis) if isinstance(self.model_axis, list) else self.model_axis
        )
        return Transformer(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            widening_factor=self.widening_factor,
            key_size=self.key_size,
            init_scale=self.init_scale,
            mesh=mesh,
            attn_output_multiplier=self.attn_output_multiplier,
            shard_activations=self.shard_activations,
            num_layers=self.num_layers,
            num_experts=self.num_experts,
            num_selected_experts=self.num_selected_experts,
            data_axis=data_axis,
            model_axis=model_axis,
        )

    def get_memory_sharding(self):
        return Memory(
            layers=[
                KVMemory(
                    k=P(self.data_axis, self.model_axis),
                    v=P(self.data_axis, self.model_axis),
                    step=P(self.data_axis),
                )
                for _ in range(self.num_layers)
            ],
        )
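
# Illustrative example (added for exposition): constructing a small config; the
# numbers are hypothetical, not the released Grok-1 settings. A model instance
# is then built with `_example_config().make(mesh=None)`.
def _example_config() -> TransformerConfig:
    return TransformerConfig(
        emb_size=512,
        key_size=64,
        num_q_heads=8,
        num_kv_heads=2,
        num_layers=4,
        vocab_size=32 * 1024,
        widening_factor=8.0,
        num_experts=8,
        num_selected_experts=2,
        shard_activations=True,
    )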

def hk_rms_norm(
    x: jax.Array,
    fixed_scale=False,
    sharding=P(None),
) -> jax.Array:
    """Applies RMS normalization to x with default settings."""
    ln = RMSNorm(axis=-1, create_scale=not fixed_scale, sharding=sharding)
    return ln(x)

def make_attention_mask(
    query_input: jax.Array,
    key_input: jax.Array,
    pairwise_fn: Callable[..., Any] = jnp.multiply,
    dtype: Any = jnp.bfloat16,
):
    """Mask-making helper for attention weights.

    Creates an attention mask specifying which tokens in the key sequence each
    token in the query sequence may attend to. Attention mechanisms use such
    masks to control the visibility of tokens, e.g., to prevent attending to
    future tokens in autoregressive models.

    In the case of 1-d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`), the
    attention weights will be `[batch..., heads, len_q, len_kv]` and this
    function will produce `[batch..., 1, len_q, len_kv]`.

    Args:
        query_input: a batched, flat input of query_length size.
        key_input: a batched, flat input of key_length size.
        pairwise_fn: broadcasting elementwise comparison function.
        dtype: mask return dtype.

    Returns:
        A `[batch..., 1, len_q, len_kv]` shaped mask for 1-d attention.
    """
    mask = pairwise_fn(jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2))
    mask = jnp.expand_dims(mask, axis=-3)
    return mask.astype(dtype)
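
# Illustrative example (added for exposition; `_example_causal_mask` is
# hypothetical): a causal mask arises from comparing query and key positions
# with >=, so position i may only attend to positions j <= i.
def _example_causal_mask(seq_len: int) -> jax.Array:
    positions = jnp.arange(seq_len)
    return make_attention_mask(
        positions[None, :], positions[None, :], pairwise_fn=jnp.greater_equal
    )  # [1, 1, T, T]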

class Linear(hk.Linear):
    """
    Extends Haiku's Linear layer with optional sharding for use in distributed settings.

    This class allows specifying a `PartitionSpec` to shard the linear layer's weights
    across devices, which is beneficial for large-scale models processed over multiple
    devices or nodes.

    Args:
        output_size (int): The size of the output dimension.
        with_bias (bool, optional): Whether to include a bias term. Defaults to True.
        sharding (Optional[P], optional): The sharding specification for distributing the layer's parameters.
        mesh (Any, optional): The SPMD mesh for parallel computation. Defaults to None.
        name (Optional[str], optional): An optional name for this module. Defaults to None.
        shard_axis (int, optional): The axis along which to shard the input data. Defaults to 0.
    """

    def __init__(
        self,
        output_size: int,
        with_bias: bool = True,
        sharding: Optional[P] = None,
        mesh: Any = None,
        name: Optional[str] = None,
        shard_axis: int = 0,
    ):
        super().__init__(
            output_size=output_size,
            with_bias=with_bias,
            name=name,
        )
        self.sharding = sharding
        self.mesh = mesh
        self.shard_axis = shard_axis

    def __call__(
        self,
        inputs: jax.Array,
    ) -> jax.Array:
        """
        Computes a linear transform of the input data: a matrix multiplication between
        the inputs and the layer's weight matrix, optionally followed by a bias term.

        Args:
            inputs (jax.Array): The input tensor to be transformed, shaped as [batch_size, ..., input_features].

        Returns:
            jax.Array: The transformed tensor, shaped as [batch_size, ..., output_features].
        """
        fprop_dtype = inputs.dtype
        if not inputs.shape:
            raise ValueError("Input must not be scalar.")
        input_size = self.input_size = inputs.shape[-1]
        output_size = self.output_size
        w = hk.get_parameter(
            "w", [input_size, output_size], jnp.float32, init=hk.initializers.Constant(0)
        )
        if hasattr(w, "scales"):
            shape = inputs.shape
            inputs = jnp.reshape(inputs, (-1, shape[-1]))

            @functools.partial(
                shard_map,
                mesh=self.mesh,
                in_specs=(self.sharding, self.sharding),
                out_specs=self.sharding,
                check_rep=False,
            )
            def mul(w, s):
                return w.astype(s.dtype) * s

            w = mul(w.weight, w.scales)
        out = jnp.dot(inputs, w.astype(fprop_dtype))
        if self.with_bias:
            b = hk.get_parameter(
                "b", [self.output_size], jnp.float32, init=hk.initializers.Constant(0)
            )
            b = jnp.broadcast_to(b, out.shape)
            out = out + b.astype(fprop_dtype)
        return out

class RMSNorm(hk.RMSNorm):
    """
    Implements Root Mean Square (RMS) layer normalization.

    This variant of layer normalization scales inputs by the root mean square of
    their elements, optionally including a learnable scaling factor. It supports
    specifying a `PartitionSpec` for sharding the scale parameters across devices
    in distributed settings.

    Args:
        axis (Union[int, Sequence[int], slice]): The dimensions to normalize over.
        eps (float, optional): A small constant added to the denominator to improve numerical stability.
            Defaults to 1e-5.
        name (Optional[str], optional): An optional name for this module. Defaults to None.
        create_scale (bool, optional): Whether to include a learnable scaling factor. Defaults to True.
        sharding (Optional[P], optional): The sharding specification for the scale parameter. Defaults to None.
    """

    def __init__(
        self,
        axis: Union[int, Sequence[int], slice],
        eps: float = 1e-5,
        name: Optional[str] = None,
        create_scale: bool = True,
        sharding: Optional[P] = None,
    ):
        super().__init__(axis, eps, create_scale=create_scale, name=name)
        self.sharding = sharding

    def __call__(self, inputs: jax.Array):
        """
        Applies RMS normalization to the input tensor, normalizing the inputs by their
        root mean square value and optionally scaling the result by a learnable parameter.

        Args:
            inputs (jax.Array): The input tensor to be normalized, shaped as [batch_size, ..., features].

        Returns:
            jax.Array: The RMS-normalized tensor, maintaining the input shape.
        """
        fprop_dtype = inputs.dtype
        param_shape = (inputs.shape[-1],)
        if self.create_scale:
            scale = hk.get_parameter(
                "scale",
                param_shape,
                dtype=jnp.float32,
                init=hk.initializers.Constant(0),
            )
            if self.sharding:
                scale = with_sharding_constraint(scale, self.sharding)
            scale = jnp.broadcast_to(scale.astype(jnp.float32), inputs.shape)
        else:
            scale = jnp.array(1.0, dtype=jnp.float32)
        inputs = inputs.astype(jnp.float32)
        scale = scale.astype(jnp.float32)
        mean_squared = jnp.mean(jnp.square(inputs), axis=[-1], keepdims=True)
        mean_squared = jnp.broadcast_to(mean_squared, inputs.shape)
        normed_inputs = inputs * jax.lax.rsqrt(mean_squared + self.eps)
        outputs = scale * normed_inputs
        return outputs.astype(fprop_dtype)
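
# Functional sketch (added for exposition; `_example_rms_norm` is hypothetical)
# of the normalization above: y = scale * x / sqrt(mean(x^2) + eps), in fp32.
def _example_rms_norm(x: jax.Array, scale: jax.Array, eps: float = 1e-5) -> jax.Array:
    x = x.astype(jnp.float32)
    mean_squared = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return scale * x * jax.lax.rsqrt(mean_squared + eps)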

def rotate_half(
    x: jax.Array,
) -> jax.Array:
    """Obtain the rotated counterpart of each feature."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate((-x2, x1), axis=-1)
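
# Worked example (added for exposition): rotate_half(jnp.array([1., 2., 3., 4.]))
# splits the features into (1, 2) and (3, 4) and returns [-3., -4., 1., 2.].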

class RotaryEmbedding(hk.Module):
    """
    Applies Rotary Position Embedding (RoPE) to an input sequence tensor,
    as described in https://arxiv.org/abs/2104.09864.

    RoPE encodes positional information dynamically by applying a rotation to the
    input embeddings based on their position in the sequence. This approach
    preserves relative positional information across different sequence lengths
    and tasks.

    Args:
        dim (int): The dimensionality of the embeddings to be rotated; must be even.
        name (Optional[str], optional): An optional name for this module. Defaults to None.
        base_exponent (int, optional): The base of the exponent used to calculate rotary frequencies.
            Defaults to 10000.
    """

    def __init__(
        self,
        dim: int,
        name: Optional[str] = None,
        base_exponent: int = 10000,
    ):
        super().__init__(name)
        self.dim = dim
        self.base_exponent = base_exponent
        assert self.dim % 2 == 0

    def __call__(
        self,
        x: jax.Array,
        seq_dim: int,
        offset: jax.Array,
        const_position: Optional[int] = None,
        t: Optional[jax.Array] = None,
    ) -> jax.Array:
        """
        Applies Rotary Position Embedding (RoPE) to the input embeddings, encoding
        positional information by rotating embeddings according to their sequence position.

        Args:
            x (jax.Array): Input embeddings to apply RoPE to, shaped as [batch_size, seq_length, embedding_dim].
            seq_dim (int): The dimension index of the sequence length in the input tensor.
            offset (jax.Array): The offset applied to the position indices, useful for continuing sequences across batches.
            const_position (Optional[int]): A constant position value to use for all positions, instead of a range.
            t (Optional[jax.Array]): Explicit tensor of position values, shaped as [seq_length].

        Returns:
            jax.Array: The input embeddings with RoPE applied, maintaining the input shape.
        """
        fprop_dtype = x.dtype
        # Compute the per-dimension frequencies.
        exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32)
        inv_freq = jnp.asarray(
            1.0 / (self.base_exponent ** (exponents / self.dim)), dtype=jnp.float32
        )

        if jnp.shape(offset) == ():
            # Offset can be a scalar or one offset per batch element.
            offset = jnp.expand_dims(offset, 0)

        # Compute the per-element phase (to pass into sin and cos).
        if const_position:
            t = const_position * jnp.ones(
                (
                    1,
                    x.shape[seq_dim],
                ),
                dtype=jnp.float32,
            )
        elif t is None:
            t = jnp.arange(x.shape[seq_dim], dtype=jnp.float32) + jnp.expand_dims(offset, -1)
        phase = jnp.einsum("bi,j->bij", t, inv_freq)
        phase = jnp.tile(phase, reps=(1, 2))[:, :, None, :]

        x = x * jnp.cos(phase) + rotate_half(x) * jnp.sin(phase)
        x = x.astype(fprop_dtype)

        return x
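
# Illustrative sketch (added for exposition; `_example_rope_phase` is
# hypothetical): the rotation angles computed inside __call__ above, for a
# sequence starting at position 0 and without the batch dimension.
def _example_rope_phase(seq_len: int, dim: int, base: int = 10000) -> jax.Array:
    """Per-position, per-dimension rotation angles used by RoPE."""
    exponents = jnp.arange(0, dim, 2, dtype=jnp.float32)
    inv_freq = 1.0 / (base ** (exponents / dim))
    t = jnp.arange(seq_len, dtype=jnp.float32)
    return jnp.einsum("i,j->ij", t, inv_freq)  # [seq_len, dim // 2]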

class MultiHeadAttention(hk.Module):
    """
    Implements the Multi-Head Attention mechanism, a key component of Transformer
    architectures. This module lets each token in the input sequence attend to all
    tokens in the key and value sequences, with multiple "heads" learning different
    attention patterns.

    Attributes:
        num_q_heads (int): The number of heads for the queries.
        num_kv_heads (int): The number of heads for the keys and values.
        key_size (int): The size of the key vectors.
        with_bias (bool): Whether to include bias terms in the linear transformations.
        value_size (Optional[int]): The size of the value vectors; defaults to the key size if not specified.
        model_size (Optional[int]): The size of the output dimension of the attention mechanism; defaults to the total size of all query heads if not specified.
        attn_output_multiplier (float): A scaling factor applied to the output of the attention mechanism.
        data_axis (Union[str, Tuple[str, ...]]): The axis names over which data is sharded for distributed computation.
        model_axis (Union[str, Tuple[str, ...]]): The axis names over which model parameters are sharded for distributed computation.
        name (Optional[str]): An optional name for the module.
    """

    def __init__(
        self,
        num_q_heads: int,
        num_kv_heads: int,
        key_size: int,
        *,
        with_bias: bool = True,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        attn_output_multiplier: float = 1.0,
        data_axis: Union[str, Tuple[str, ...]] = "data",
        model_axis: Union[str, Tuple[str, ...]] = "model",
        name: Optional[str] = None,
    ):
        """
        Initializes the Multi-Head Attention module with the provided configuration.

        Args:
            num_q_heads (int): Number of query heads for the multi-head attention mechanism.
            num_kv_heads (int): Number of key/value heads for the multi-head attention mechanism.
            key_size (int): Dimensionality of the key vectors in the attention mechanism.
            with_bias (bool): Whether to include a bias term in the attention weight computation.
            value_size (Optional[int]): Dimensionality of the value vectors; defaults to key_size if not specified.
            model_size (Optional[int]): Overall size of the model's output dimension.
            attn_output_multiplier (float): Multiplier for scaling the output of the attention mechanism.
            data_axis (Union[str, Tuple[str, ...]]): Data sharding axis names for distributed computation.
            model_axis (Union[str, Tuple[str, ...]]): Model parameter sharding axis names for distributed computation.
            name (Optional[str]): An optional name for the module.
        """
        super().__init__(name=name)
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.key_size = key_size
        self.value_size = value_size or key_size
        self.model_size = model_size or key_size * num_q_heads
        self.data_axis = data_axis
        self.model_axis = model_axis
        self.attn_output_multiplier = attn_output_multiplier
        self.with_bias = with_bias

    def __call__(
        self,
        query: jax.Array,
        key: Optional[jax.Array],
        value: Optional[jax.Array],
        mask: Optional[jax.Array] = None,
        kv_memory: Optional[KVMemory] = None,
        mesh: Any = None,
    ) -> MHAOutput:
        """
        Computes multi-head attention over the input queries, keys, and values.

        Args:
            query (jax.Array): Query vectors, shaped as [batch_size, seq_length, model_dim].
            key (Optional[jax.Array]): Key vectors. If None, uses query as key.
            value (Optional[jax.Array]): Value vectors. If None, uses query as value.
            mask (Optional[jax.Array]): An optional mask to prevent attention to certain positions,
                shaped as [batch_size, 1, seq_length, seq_length].
            kv_memory (Optional[KVMemory]): Optional memory of keys and values to support efficient
                autoregressive decoding.
            mesh (Any): The SPMD mesh for parallel computation, if applicable.

        Returns:
            MHAOutput: A named tuple containing the output embeddings and the updated memory.
        """
        # In shape hints below, we suppress the leading dims [...] for brevity.
        # Hence e.g. [A, B] should be read in every case as [..., A, B].
        sequence_length = query.shape[1]
        projection = self._linear_projection
        use_memory = False
        if kv_memory is not None:
            if kv_memory.k is None:
                assert kv_memory.v is None
                assert key is not None
                assert value is not None
            else:
                assert kv_memory.v is not None
                use_memory = True
        else:
            assert key is not None
            assert value is not None

        # Check that the keys and values have consistent batch size and sequence length.
        if not use_memory:
            assert key.shape[:2] == value.shape[:2], f"key/value shape: {key.shape}/{value.shape}"

        if mask is not None:
            assert mask.ndim == 4
            assert mask.shape[0] in {
                1,
                query.shape[0],
            }, f"mask/query shape: {mask.shape}/{query.shape}"
            if not use_memory:
                assert key.shape[0] in {
                    1,
                    query.shape[0],
                }, f"key/query shape: {key.shape}/{query.shape}"
            assert mask.shape[1] == 1
            assert mask.shape[2] in {
                1,
                query.shape[1],
            }, f"mask/query shape: {mask.shape}/{query.shape}"
            if not use_memory:
                assert mask.shape[3] in {
                    1,
                    key.shape[1],
                }, f"mask/key shape: {mask.shape}/{key.shape}"

        # Compute key/query/values (overload K/Q/V to denote the respective sizes).
        assert self.num_q_heads % self.num_kv_heads == 0
        query_heads = projection(
            query,
            self.key_size,
            self.num_q_heads,
            name="query",
            sharding=P("data", "model"),
            mesh=mesh,
        )  # [B, T', H, Q=K]
        new_memory = None
        key_heads = projection(
            key,
            self.key_size,
            self.num_kv_heads,
            name="key",
            sharding=P("data", "model"),
            mesh=mesh,
        )  # [B, T, H, K]
        value_heads = projection(
            value,
            self.value_size,
            self.num_kv_heads,
            name="value",
            sharding=P("data", "model"),
            mesh=mesh,
        )  # [B, T, H, V]

        rotate = RotaryEmbedding(dim=self.key_size, base_exponent=int(1e4))
        key_heads = rotate(key_heads, seq_dim=1, offset=(kv_memory.step if kv_memory else 0))
        query_heads = rotate(query_heads, seq_dim=1, offset=(kv_memory.step if kv_memory else 0))

        @functools.partial(jax.vmap)
        def update_into(mem, start, update):
            return jax.lax.dynamic_update_slice_in_dim(mem, update, start, axis=0)

        if kv_memory:
            if mesh is not None:

                @functools.partial(
                    shard_map,
                    mesh=mesh,
                    in_specs=(
                        P("data", None, "model"),
                        P("data"),
                        P("data", None, "model"),
                    ),
                    out_specs=P("data", None, "model"),
                    check_rep=False,
                )
                def update_into_shmap(mems, starts, updates):
                    return update_into(mems, starts, updates)

                key_heads = update_into_shmap(kv_memory.k, kv_memory.step, key_heads)
                value_heads = update_into_shmap(kv_memory.v, kv_memory.step, value_heads)
            else:
                key_heads = update_into(kv_memory.k, kv_memory.step, key_heads)
                value_heads = update_into(kv_memory.v, kv_memory.step, value_heads)

            new_step = kv_memory.step + sequence_length
            memory_mask = jnp.arange(kv_memory.k.shape[1]) < new_step[:, None]
            memory_mask = memory_mask[:, None, None, :]  # [B, 1, 1, T]
            if mask is not None:
                mask = memory_mask * mask
            else:
                mask = memory_mask

            new_memory = KVMemory(
                k=key_heads,
                v=value_heads,
                step=new_step,
            )
        # Add separate dimension for grouped query heads.
        query_heads = with_sharding_constraint(query_heads, P(self.data_axis, None, "model", None))
        key_heads = with_sharding_constraint(key_heads, P(self.data_axis, None, "model", None))
        value_heads = with_sharding_constraint(value_heads, P(self.data_axis, None, "model", None))
        b, t, h, d = query_heads.shape
        _, _, kv_h, _ = key_heads.shape
        assert h % kv_h == 0, f"query_heads {h} must be a multiple of kv_heads {kv_h}"
        query_heads = jnp.reshape(query_heads, (b, t, kv_h, h // kv_h, d))
        query_heads = with_sharding_constraint(
            query_heads, P(self.data_axis, None, "model", None, None)
        )

        # Compute attention weights.
        # Attention softmax is always carried out in fp32.
        attn_logits = jnp.einsum("...thHd,...Thd->...hHtT", query_heads, key_heads).astype(
            jnp.float32
        )
        attn_logits *= self.attn_output_multiplier
        max_attn_val = jnp.array(30.0, dtype=attn_logits.dtype)
        attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)

        if mask is not None:
            # Broadcast the mask over the grouped query-head dimension.
            mask = mask[:, :, None, :, :]
            if mask.ndim != attn_logits.ndim:
                raise ValueError(
                    f"Mask dimensionality {mask.ndim} must match logits dimensionality "
                    f"{attn_logits.ndim} for {mask.shape}/{attn_logits.shape}."
                )
            attn_logits = jnp.where(mask, attn_logits, -1e30)
        attn_weights = jax.nn.softmax(attn_logits).astype(query.dtype)  # [H, T', T]

        # Weight the values by the attention and flatten the head vectors.
        attn = jnp.einsum("...hHtT,...Thd->...thHd", attn_weights, value_heads)
        attn = with_sharding_constraint(attn, P(self.data_axis, None, "model", None, None))
        leading_dims = attn.shape[:2]
        attn = jnp.reshape(attn, (*leading_dims, -1))  # [T', H*V]
        attn = with_sharding_constraint(attn, P(self.data_axis, None, "model"))

        # Apply another projection to get the final embeddings.
        final_projection = Linear(
            self.model_size,
            with_bias=False,
            sharding=P("model", "data"),
            mesh=mesh,
        )
        return MHAOutput(final_projection(attn), new_memory)

    @hk.transparent
    def _linear_projection(
        self,
        x: jax.Array,
        head_size: int,
        num_heads: int,
        sharding: Optional[P] = None,
        name: Optional[str] = None,
        mesh: Any = None,
    ) -> jax.Array:
        """
        Projects input embeddings into separate head spaces for queries, keys, or values.
        Applying a learned linear transformation per head lets the multi-head attention
        mechanism explore different subspace relationships in the data.

        Args:
            x (jax.Array): Input tensor to project, shaped as [batch_size, seq_length, embedding_dim].
            head_size (int): The dimensionality of each head's subspace.
            num_heads (int): The number of heads to project into.
            sharding (Optional[PartitionSpec]): The sharding specification for distributing computation across devices.
            name (Optional[str]): A name for the operation, distinguishing between queries, keys, and values.
            mesh (Any): The SPMD mesh for parallel computation, if applicable.

        Returns:
            jax.Array: Projected embeddings for all heads, shaped as [batch_size, seq_length, num_heads, head_size].
        """
        y = Linear(
            num_heads * head_size,
            with_bias=False,
            name=name,
            sharding=sharding,
            mesh=mesh,
        )(x)
        *leading_dims, _ = x.shape
        return y.reshape((*leading_dims, num_heads, head_size))


@dataclass
class MHABlock(hk.Module):
    """
    A specialized module encapsulating the Multi-Head Attention (MHA) operation within a transformer model.
    This module orchestrates the application of MHA, including the computation of queries, keys, and values, and the subsequent attention and aggregation operations.

    Attributes:
        num_q_heads (int): Number of heads for the query part of the MHA.
        num_kv_heads (int): Number of heads for the key and value parts of the MHA.
        key_size (int): Size of the keys (and queries) in the attention mechanism.
        attn_output_multiplier (float): Scaling factor applied to the output of the attention mechanism.
        data_axis (Union[str, Tuple[str, ...]]): Axis names over which data is sharded for distributed computation.
        model_axis (Union[str, Tuple[str, ...]]): Axis names over which model parameters are sharded for distributed computation.
        mesh (Any): The SPMD mesh for parallel computation.
    """

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    attn_output_multiplier: float = 1.0
    mesh: Any = None
    data_axis: Union[str, Tuple[str, ...]] = "data"
    model_axis: Union[str, Tuple[str, ...]] = "model"

    @hk.transparent
    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T] or [1, 1, 1, 1]
        layer_memory: Optional[KVMemory],
    ) -> MHAOutput:
        """
        Processes inputs through a Multi-Head Attention block.

        This method applies multi-head attention to the inputs using the provided mask for attention scoring,
        and optionally utilizes cached memory for keys and values to enhance efficiency in autoregressive models.

        Args:
            inputs (jax.Array): Input embeddings, shaped as [batch_size, seq_length, embedding_dim].
            mask (jax.Array): Attention mask, shaped as [batch_size, 1, seq_length, seq_length], to control visibility between tokens.
            layer_memory (Optional[KVMemory]): Cached keys and values from previous steps for efficient attention computation in autoregressive decoding.

        Returns:
            MHAOutput: The output from the multi-head attention block, including the transformed embeddings and updated memory.
        """
        _, _, model_size = inputs.shape
        assert mask.ndim == 4, f"shape: {mask.shape}"
        assert mask.shape[2] in {1, inputs.shape[1]}, str(mask.shape)
        assert mask.shape[3] in {1, inputs.shape[1]}, str(mask.shape)
        side_input = inputs

        def attn_block(query, key, value, mask, memory) -> MHAOutput:
            return MultiHeadAttention(
                num_q_heads=self.num_q_heads,
                num_kv_heads=self.num_kv_heads,
                key_size=self.key_size,
                model_size=model_size,
                data_axis=self.data_axis,
                model_axis=self.model_axis,
                attn_output_multiplier=self.attn_output_multiplier,
            )(
                query,
                key,
                value,
                mask,
                memory,
                mesh=self.mesh,
            )

        attn_output = attn_block(inputs, side_input, side_input, mask, layer_memory)
        h_attn = attn_output.embeddings
        return attn_output._replace(embeddings=h_attn)
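

# Sketch (hypothetical values, not part of the model): the three mask shapes
# the asserts in MHABlock accept. [B, 1, T, T] gives per-query visibility
# (e.g. causal), [B, 1, 1, T] masks keys only (e.g. padding), and a
# [1, 1, 1, 1] mask broadcasts a single "attend to everything" flag.
def _demo_mask_shapes():
    batch, seq_len = 2, 4
    causal = jnp.tril(jnp.ones((batch, 1, seq_len, seq_len)))  # [B, 1, T, T]
    padding_only = jnp.ones((batch, 1, 1, seq_len))  # [B, 1, 1, T]
    attend_all = jnp.ones((1, 1, 1, 1))  # broadcastable over everything
    return causal, padding_only, attend_all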


@dataclass
class DenseBlock(hk.Module):
    """
    Implements a dense (fully connected) block within a transformer layer, typically following multi-head attention.
    This block applies a gated pair of linear transformations with a GELU non-linearity, followed by a down-projection back to the model dimension.

    Attributes:
        num_q_heads (int): The number of query heads in the preceding MHA layer.
        num_kv_heads (int): The number of key/value heads in the preceding MHA layer.
        key_size (int): The size of the keys in the MHA layer.
        widening_factor (float): Factor by which the dimensionality of the feed-forward network is increased.
        sharding_constraint (bool): Whether to apply a sharding constraint for distributed computation.
        mesh (Any): The SPMD mesh for parallel computation.
    """

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    widening_factor: float = 4.0
    sharding_constraint: bool = False
    mesh: Any = None

    @hk.transparent
    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
    ) -> jax.Array:  # [B, T, D]
        """
        Applies a series of dense transformations to the inputs.

        This method constitutes the feed-forward network part of a transformer layer, applying linear transformations
        followed by non-linear activations to model complex data relationships beyond what attention mechanisms capture.

        Args:
            inputs (jax.Array): Input embeddings from the previous layer or block, shaped as [batch_size, seq_length, embedding_dim].

        Returns:
            jax.Array: The output embeddings after applying the dense block transformations, maintaining the input shape.
        """
        _, _, model_size = inputs.shape
        h_v = Linear(
            ffn_size(
                model_size,
                self.widening_factor,
            ),
            with_bias=False,
            mesh=self.mesh,
            sharding=P("data", "model"),
            name="linear_v",
        )(inputs)
        h_w1 = jax.nn.gelu(
            Linear(
                ffn_size(
                    model_size,
                    self.widening_factor,
                ),
                with_bias=False,
                mesh=self.mesh,
                sharding=P("data", "model"),
            )(inputs)
        )
        h_dense = Linear(
            model_size,
            with_bias=False,
            sharding=P("model", "data"),
            mesh=self.mesh,
            shard_axis=1,
        )(h_w1 * h_v)
        return h_dense
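

# Sketch (hypothetical shapes, not part of the model): the gated feed-forward
# computation DenseBlock performs, written as plain matmuls. The GELU branch
# gates the "linear_v" branch elementwise before the final down-projection.
def _demo_gated_ffn():
    model_size, hidden = 8, 32  # hidden plays the role of ffn_size(...)
    x = jnp.ones((2, 5, model_size))  # [B, T, D]
    w_1 = jnp.ones((model_size, hidden))  # gated (GELU) branch
    w_v = jnp.ones((model_size, hidden))  # linear branch
    w_out = jnp.ones((hidden, model_size))  # down-projection
    return (jax.nn.gelu(x @ w_1) * (x @ w_v)) @ w_out  # [B, T, D]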


@dataclass
class DecoderLayer(hk.Module):
    """
    Represents a single layer in the decoder stack of a transformer model. This layer processes input embeddings through
    a multi-head attention mechanism followed by position-wise feed-forward networks, with normalization and skip connections
    applied as per the standard transformer architecture.

    Attributes:
        num_q_heads (int): Number of query heads for multi-head attention.
        num_kv_heads (int): Number of key/value heads for multi-head attention.
        key_size (int): Size of keys in the attention mechanism.
        num_layers (int): Total number of transformer layers.
        num_experts (int): Number of experts in the Mixture of Experts layer, if used.
        layer_index (Optional[int]): Index of this layer within the overall model; used for layer-specific configurations or logging.
        num_selected_experts (int): Number of experts selected for each input in the MoE layer.
        widening_factor (float): Factor by which the dimensionality of the feed-forward network is increased.
        name (Optional[str]): An optional name for the layer.
        data_axis (Union[str, Tuple[str, ...]]): Axis names for data sharding in distributed computation.
        model_axis (Union[str, Tuple[str, ...]]): Axis names for model parameter sharding in distributed computation.
        shard_activations (bool): Whether activations should be sharded across devices.
        attn_output_multiplier (float): Scaling factor for the output of the attention mechanism.
        mesh (Any): SPMD mesh for parallel computation.
    """

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    num_layers: int
    # MoE.
    num_experts: int
    layer_index: Optional[int] = None
    num_selected_experts: int = 1
    widening_factor: float = 4.0
    name: Optional[str] = None
    data_axis: Union[str, Tuple[str, ...]] = "data"
    model_axis: Union[str, Tuple[str, ...]] = "model"
    shard_activations: bool = False
    attn_output_multiplier: float = 1.0
    mesh: Any = None

    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T]
        padding_mask: Optional[jax.Array],
        layer_memory: Optional[KVMemory],
    ) -> DecoderOutput:
        """
        Transforms input embedding sequences to output embedding sequences.

        Processes input embeddings through a single layer of the decoder.
        This method applies multi-head attention followed by position-wise feed-forward networks,
        including the necessary normalization and skip connections, as per the transformer architecture.

        Args:
            inputs (jax.Array): Input embeddings, shaped [batch_size, seq_length, model_dim].
            mask (jax.Array): Attention mask, shaped [batch_size, 1, seq_length, seq_length], used to prevent
                attention to future positions.
            padding_mask (Optional[jax.Array]): Mask indicating which positions are padding tokens,
                to exclude them from attention calculations.
            layer_memory (Optional[KVMemory]): Memory state for storing past key/value pairs for efficient
                autoregressive decoding.

        Returns:
            DecoderOutput: Named tuple containing output embeddings and updated memory state.
        """

        def layer_norm(x):
            return hk_rms_norm(x)

        if self.shard_activations:
            sharding = P(self.data_axis, None, self.model_axis)
        else:
            sharding = P(self.data_axis, None)
        h = with_sharding_constraint(inputs, sharding)

        attn_output = MHABlock(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            key_size=self.key_size,
            attn_output_multiplier=self.attn_output_multiplier,
            mesh=self.mesh,
            data_axis=self.data_axis,
            model_axis=self.model_axis,
        )(layer_norm(h), mask, layer_memory)
        h_attn = attn_output.embeddings
        h_attn = layer_norm(h_attn)
        h += h_attn
        h = with_sharding_constraint(h, sharding)

        def base_dense_block(h):
            h = DenseBlock(
                num_q_heads=self.num_q_heads,
                num_kv_heads=self.num_kv_heads,
                key_size=self.key_size,
                widening_factor=self.widening_factor,
                sharding_constraint=False,
                mesh=self.mesh,
            )(h)
            return h

        if self.num_experts > 1:
            rank_logger.debug("Using MoE!")
            router = Router(
                num_selected_experts=self.num_selected_experts,
                shard_activations=self.shard_activations,
                data_axis=self.data_axis,
                model_axis=self.model_axis,
                mesh=self.mesh,
            )
            h_dense = MoELayer(
                num_experts=self.num_experts,
                mesh=self.mesh,
                layer_fn=base_dense_block,
                router=router,
                shard_activations=self.shard_activations,
                data_axis=self.data_axis,
                model_axis=self.model_axis,
            )(layer_norm(h), padding_mask)
        else:
            h_dense = base_dense_block(layer_norm(h))

        h_dense = layer_norm(h_dense)
        h += h_dense
        h = with_sharding_constraint(h, sharding)

        return DecoderOutput(
            embeddings=h,
            memory=attn_output.memory,
        )
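

# Sketch (hypothetical sublayer callables, not the model's code path): the
# residual pattern DecoderLayer applies above. RMSNorm wraps each sublayer on
# both sides ("sandwich" normalization) before its output rejoins the
# residual stream.
def _demo_residual_pattern(h, attn_fn, ffn_fn, norm_fn):
    h = h + norm_fn(attn_fn(norm_fn(h)))  # attention sublayer
    h = h + norm_fn(ffn_fn(norm_fn(h)))  # feed-forward sublayer
    return h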


class LanguageModelOutput(NamedTuple):
    """
    Represents the output of the language model after processing an input sequence of tokens.

    This output encapsulates both the logits representing the model's predictions for the next token
    in the sequence and the updated model state, which includes any memory or state information
    that needs to be carried over for generating subsequent tokens.

    Attributes:
        logits (jax.Array): The logits for the next token predictions, shaped as [batch_size, sequence_length, vocab_size].
        model_state (Any): The updated state of the model after processing the input sequence, which may include
            memory states for layers that utilize recurrence or caching for efficiency.
    """

    logits: jax.Array
    model_state: Any


class InOutEmbed(hk.Embed):
    """
    A module for embedding input tokens into a continuous vector space and for projecting the outputs of
    a transformer back into the vocabulary space. This module supports tying the weights between the input
    embedding and the output projection for parameter efficiency.

    Attributes:
        vocab_size (Optional[int]): The size of the vocabulary.
        embed_dim (Optional[int]): The dimensionality of the embedding vectors.
        sharding (Optional[PartitionSpec]): Specifies how the embedding parameters should be sharded across
            devices for distributed computation.
        name (Optional[str]): An optional name for the module.
    """

    def __init__(
        self,
        vocab_size: Optional[int] = None,
        embed_dim: Optional[int] = None,
        sharding: Optional[P] = None,
        name: Optional[str] = None,
    ):
        """
        Initializes an embedding module that can be used for both input token embeddings and output logits projection in a transformer model.
        This shared embedding layer helps reduce the number of parameters by tying the weights between the input embedding and the output projection layers.

        Args:
            vocab_size (Optional[int]): The size of the vocabulary.
            embed_dim (Optional[int]): The dimensionality of the embedding vectors.
            sharding (Optional[PartitionSpec]): The sharding specification for distributing the embedding parameters across devices.
            name (Optional[str]): An optional name for the module.
        """
        super().__init__(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            name=name,
        )
        self.sharding = sharding

    @property
    def embeddings(self):
        """
        Retrieves the embedding matrix from the module's parameters.
        This property is useful for operations that need direct access to the embeddings, such as output projection in language models.

        Returns:
            jax.Array: The embedding matrix, shaped as [vocab_size, embed_dim].
        """
        embed_mat = hk.get_parameter(
            "embeddings",
            [self.vocab_size, self.embed_dim],
            dtype=jnp.float32,
            init=hk.initializers.Constant(0),
        )
        if self.sharding:
            embed_mat = with_sharding_constraint(embed_mat, self.sharding)
        return embed_mat

    def decode(
        self,
        inputs: jax.Array,
    ) -> jax.Array:
        """
        Projects transformer model outputs back into the vocabulary space using the transposed embedding matrix.
        This method effectively performs the inverse of the embedding operation, converting model output embeddings into logits over the vocabulary.

        Args:
            inputs (jax.Array): The output embeddings from the transformer model, shaped as [batch_size, seq_length, embed_dim].

        Returns:
            jax.Array: The logits over the vocabulary for each token position, shaped as [batch_size, seq_length, vocab_size].
        """
        return jnp.dot(inputs, self.embeddings.T.astype(inputs.dtype))
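

# Sketch (hypothetical sizes, not part of the model): weight tying as done by
# InOutEmbed. One matrix embeds tokens on the way in and, transposed, projects
# hidden states back to vocabulary logits on the way out.
def _demo_tied_embeddings():
    vocab_size, embed_dim = 11, 8
    embed_mat = jnp.ones((vocab_size, embed_dim))
    tokens = jnp.zeros((2, 5), dtype=jnp.int32)  # [B, T]
    hidden = embed_mat[tokens]  # embed: [B, T, D]
    logits = jnp.dot(hidden, embed_mat.T)  # decode: [B, T, V]
    return logits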


@dataclass
class LanguageModelConfig:
    """
    Configuration class for an autoregressive language model based on the Transformer architecture.

    Attributes:
        model (TransformerConfig): The transformer model configuration.
        vocab_size (int): The size of the vocabulary.
        pad_token (int): The token used for padding sequences.
        eos_token (int): The end-of-sequence token.
        sequence_len (int): The maximum sequence length the model can handle.
        model_size (int): The dimensionality of the model embeddings.
        embedding_init_scale (float): Initial scale for embedding parameter initialization.
        embedding_multiplier_scale (float): Multiplier for scaling the embedding vectors.
        output_multiplier_scale (float): Multiplier for scaling the output logits.
        name (Optional[str]): Name of the language model configuration.
        fprop_dtype (Any): Data type for forward propagation computations.
        model_type (Optional[str]): Type of the model, if applicable.
        init_scale_override (Optional[float]): Override for the initial scale of parameters, if needed.
        shard_embeddings (bool): Whether to shard embeddings across the specified axes.
    """

    model: Optional[TransformerConfig]
    vocab_size: int
    pad_token: int
    eos_token: int
    sequence_len: int
    model_size: int = 0
    embedding_init_scale: float = 1.0
    embedding_multiplier_scale: float = 1.0
    output_multiplier_scale: float = 1.0
    name: Optional[str] = None
    fprop_dtype: Any = jnp.bfloat16
    model_type: Optional[str] = None
    init_scale_override: Optional[float] = None
    shard_embeddings: bool = True

    _initialized = False

    def initialize(self):
        # We cannot specify [] as a default value (it is mutable), hence None.
        model_config = self.model
        assert model_config is not None, "Model could not be initialized."
        assert self.init_scale_override is None, (
            "Overriding the model initialization scale is supported only for predefined models."
        )
        if self.model_size == 0:
            self.model_size = model_config.emb_size
        self._initialized = True
        return self

    def make(self, *args, **kwargs):
        if not self._initialized:
            logger.warning(
                f"LanguageModel {self.name} is not initialized. Initializing for one replica."
            )
            self.initialize()

        return LanguageModel(
            model=self.model.make(*args, **kwargs),
            config=self,
            fprop_dtype=self.fprop_dtype,
            mesh=kwargs.get("mesh", None),
        )

    def partition_rules(self):
        return LM_PARTITION_RULES + self.model.partition_rules()


def layer_norm(x, model):
    # `model` is accepted for interface compatibility but unused: the same
    # RMSNorm is applied regardless of the model configuration.
    return hk_rms_norm(x)
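

# Sketch (illustrative only; the real hk_rms_norm is defined elsewhere in this
# file): RMSNorm rescales by the root-mean-square of the features instead of
# subtracting a mean, which is cheaper than LayerNorm and stable in deep stacks.
# `scale` stands for the learned per-feature gain.
def _demo_rms_norm(x, scale, eps=1e-5):
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return x / rms * scale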


@dataclass
class LanguageModel(hk.Module):
    """
    A high-level module for autoregressive language modeling using a Transformer architecture. This module
    integrates components such as embedding layers, transformer blocks, and output layers to process sequences
    of tokens and generate predictions for the next tokens in the sequence.

    The LanguageModel is designed for tasks such as text generation, where it can be used to produce coherent
    and contextually relevant text based on a given prompt.

    Attributes:
        model (Transformer): The core transformer model used for processing input token sequences.
        config (LanguageModelConfig): Configuration parameters for the language model, including details about
            the architecture, embeddings, and output processing.
        fprop_dtype (Any): The data type to use for forward propagation calculations, typically set to jnp.bfloat16
            for efficiency.
        name (Optional[str]): An optional name for the module. Useful for distinguishing between multiple instances.
        mesh (Any): The SPMD mesh for parallel computation, supporting distributed training and inference.
    """

    model: "Transformer"
    config: LanguageModelConfig
    fprop_dtype: Any = jnp.bfloat16
    name: Optional[str] = None
    mesh: Any = None

    def __call__(
        self,
        tokens: jax.Array,
        memory: Optional[Memory] = None,
        *,
        batch: Dict[str, jax.Array] = {},
        last_hid_only: bool = False,
        length: Optional[jax.Array] = None,
    ) -> LanguageModelOutput:
        """
        Forward pass, producing a sequence of logits.

        Generates logits for the next token predictions based on input tokens and optional memory state.

        Args:
            tokens (jax.Array): Input tokens to the language model, shaped as [batch_size, seq_length].
            memory (Optional[Memory]): Optional memory state from previous steps, for autoregressive generation.
            batch (Dict[str, jax.Array]): Additional batch information, unused here.
            last_hid_only (bool): If True, returns only the last hidden state instead of logits.
            length (Optional[jax.Array]): Specifies the length of each sequence in the batch, so that only
                the final valid position of each sequence is decoded.

        Returns:
            LanguageModelOutput: A named tuple containing the logits for next token predictions and the
                updated memory state.
        """
        del batch  # Unused.

        config = self.config

        input_mask = jnp.greater(tokens, config.pad_token)

        # Embed the input tokens and positions.
        in_out_embed = InOutEmbed(
            self.config.vocab_size,
            embed_dim=self.config.model_size,
            sharding=P(None, ("data", "model")),
        )
        input_embeddings = in_out_embed(tokens).astype(config.fprop_dtype)
        input_embeddings = with_sharding_constraint(
            input_embeddings, P("data", None, self.model.model_axis)
        )
        input_embeddings *= config.embedding_multiplier_scale

        model_output = self.model(
            input_embeddings,
            input_mask,
            memory=memory,
        )  # [B, T, D]
        embeddings, model_state = model_output.embeddings, model_output.memory
        if self.model.shard_activations:
            embeddings = with_sharding_constraint(
                embeddings, P("data", None, self.model.model_axis)
            )
        else:
            embeddings = with_sharding_constraint(embeddings, P("data", None))
        rank_logger.debug(f"Final embedding shape: {embeddings.shape}")
        embeddings = layer_norm(embeddings, self.model)
        assert embeddings.dtype == self.fprop_dtype

        if last_hid_only:
            last_step = jnp.maximum(jnp.sum(input_mask.astype(jnp.int32), axis=1) - 1, 0)
            last_hid = jax.vmap(lambda x, i: x[i], in_axes=0, out_axes=0)(embeddings, last_step)
            return last_hid

        if length is not None:
            last_step = jnp.maximum(length.astype(jnp.int32) - 1, 0)
            embeddings = jax.vmap(lambda x, i: x[i], in_axes=0, out_axes=0)(embeddings, last_step)
            embeddings = jnp.expand_dims(embeddings, axis=1)

        # Decode the embeddings (here, we use tied weights).
        rank_logger.info(embeddings.shape)
        out = in_out_embed.decode(embeddings)
        rank_logger.info(out.shape)
        out *= config.output_multiplier_scale

        if self.model.shard_activations:
            out = with_sharding_constraint(out, P("data", None, self.model.model_axis))
        else:
            out = with_sharding_constraint(out, P("data", None))

        return LanguageModelOutput(
            logits=out,
            model_state=model_state,
        )

    def init_memory(self, batch_size: int, seq_len: int, dtype=jnp.bfloat16):
        """
        Initializes the memory for the language model, suitable for autoregressive sequence generation tasks.

        Args:
            batch_size (int): The number of sequences for which to initialize memory.
            seq_len (int): The length of sequences for memory allocation.
            dtype (Any): Data type for the memory arrays, typically jnp.bfloat16.

        Returns:
            Memory: The initialized memory structure for storing keys, values, and decoding steps across transformer layers.
        """
        return self.model.init_memory(batch_size=batch_size, sequence_len=seq_len, dtype=dtype)

    def prefill_memory(self, prompts, memory):
        """
        Pre-fills the transformer's memory from a prompt, so that subsequent autoregressive
        generation steps can reuse the cached key/value computations from prompt processing.

        Args:
            prompts (jax.Array): The prompt tokens from which to generate subsequent tokens.
            memory (Memory): The memory state to update with information derived from the prompts.

        Returns:
            Tuple[jax.Array, Memory]: The logits produced by processing the prompts and the updated memory state.
        """
        # Pad to the left and right-align?
        # Basically assume the prompt is already padded.
        model_output = self(prompts, memory=memory)
        return model_output.logits, model_output.model_state
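

# Sketch (hypothetical driver, not part of this module): how prefill_memory
# and the cached Memory are intended to be used for generation. `lm` stands
# for an already-applied LanguageModel (e.g. inside hk.transform), `max_len`
# for the total prompt-plus-generation budget, and `sample` for any sampling
# rule mapping [B, vocab] logits to [B, 1] token ids.
def _demo_generate(lm, prompts, max_len, num_steps, sample):
    # Memory buffers must be large enough for the prompt plus generated tokens.
    memory = lm.init_memory(batch_size=prompts.shape[0], seq_len=max_len)
    logits, memory = lm.prefill_memory(prompts, memory)  # cache the prompt
    tokens = sample(logits[:, -1])  # first token after the prompt, [B, 1]
    outputs = [tokens]
    for _ in range(num_steps - 1):
        out = lm(tokens, memory=memory)  # one-token step reusing the cache
        memory = out.model_state
        tokens = sample(out.logits[:, -1])
        outputs.append(tokens)
    return jnp.concatenate(outputs, axis=1)  # [B, num_steps]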


@dataclass
class Transformer(hk.Module):
    """
    Core transformer module that implements the foundational architecture of a transformer-based model. This module
    is capable of processing sequences of embeddings through multiple layers of self-attention and feed-forward
    networks, optionally including advanced techniques like mixture of experts (MoE) and activation sharding
    for efficient large-scale parallel computation.

    Attributes:
        num_q_heads (int): Number of heads in the query part of the multi-head attention mechanism.
        num_kv_heads (int): Number of heads for the keys and values in the multi-head attention.
        key_size (int): Dimensionality of the key (and query) vectors in the attention mechanism.
        widening_factor (float): Factor by which to widen the dimensionality of the feed-forward network relative to the embeddings.
        init_scale (float): Initial scale for parameter initialization.
        mesh (Any): The SPMD mesh for parallel computation.
        attn_output_multiplier (float): Multiplier for the output of the attention mechanism.
        shard_activations (bool): Whether to shard activations across devices in distributed settings.
        num_layers (int): Number of transformer layers to stack in the model.
        num_experts (int): Number of experts in the MoE layer, if used.
        num_selected_experts (int): Number of experts selected for each input token in the MoE layer.
        data_axis (Union[str, Tuple[str, ...]]): Axis names for sharding data across devices.
        model_axis (Union[str, Tuple[str, ...]]): Axis names for sharding model parameters across devices.
        name (Optional[str]): An optional name for the module.
    """

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    widening_factor: float
    init_scale: float
    mesh: Any
    attn_output_multiplier: float
    shard_activations: bool
    num_layers: int
    # MoE.
    num_experts: int
    num_selected_experts: int
    name: Optional[str] = None
    # Used for activation sharding.
    data_axis: Union[str, Tuple[str, ...]] = "data"
    model_axis: Union[str, Tuple[str, ...]] = "model"

    def init_memory(self, batch_size: int, sequence_len: int, dtype=jnp.bfloat16):
        """
        Initializes the memory state for the transformer model.

        This is particularly useful for autoregressive tasks where past key and value pairs are cached
        to improve efficiency in generating sequences.

        Args:
            batch_size (int): The batch size for which to initialize memory states.
            sequence_len (int): The sequence length for initializing the size of memory buffers.
            dtype (Any): The data type for the memory arrays, typically jnp.bfloat16 for efficiency.

        Returns:
            Memory: A named tuple representing the initialized memory state for each layer.
        """
        return Memory(
            layers=init_layer_memories(
                batch_size,
                sequence_len,
                self.num_kv_heads,
                self.key_size,
                self.num_layers,
                step=jnp.zeros(batch_size, dtype=jnp.int32),
                dtype=dtype,
            ),
        )

    def __call__(
        self,
        embeddings: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, T]
        memory: Optional[Memory],
    ) -> TransformerOutput:
        """
        Processes input embeddings through the transformer model.

        Transforms input embedding sequences to output embedding sequences.

        Args:
            embeddings (jax.Array): Input embeddings to be processed by the transformer, shaped as
                [batch_size, seq_length, model_dim].
            mask (jax.Array): Mask indicating valid positions within the input, to control which positions
                are allowed to attend to each other, shaped as [batch_size, seq_length].
            memory (Optional[Memory]): Optional memory state for the transformer to support autoregressive
                decoding or similar use cases.

        Returns:
            TransformerOutput: A named tuple containing the transformed embeddings and the final state
                of the memory after processing.
        """
        fprop_dtype = embeddings.dtype
        _, seq_len, model_size = embeddings.shape
        padding_mask = mask.copy()
        mask = mask[:, None, None, :]  # [B, H=1, T'=1, T]

        # Compute causal mask for autoregressive sequence modelling.
        causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype(
            fprop_dtype
        )  # [B=1, H=1, T, T]
        mask = mask * causal_mask  # [B, H=1, T, T]

        h = embeddings
        kv_memories = []

        def block(
            h,
            mask,
            padding_mask,
            memory,
            layer_index: Optional[int] = None,
            widening_factor: Optional[int] = None,
            name: Optional[str] = None,
        ) -> DecoderOutput:
            return DecoderLayer(
                num_q_heads=self.num_q_heads,
                num_kv_heads=self.num_kv_heads,
                key_size=self.key_size,
                widening_factor=widening_factor or self.widening_factor,
                num_layers=self.num_layers,
                mesh=self.mesh,
                data_axis=self.data_axis,
                model_axis=self.model_axis,
                attn_output_multiplier=self.attn_output_multiplier,
                shard_activations=self.shard_activations,
                # MoE.
                num_experts=self.num_experts,
                num_selected_experts=self.num_selected_experts,
                name=name,
                layer_index=layer_index,
            )(
                h,
                mask,
                padding_mask,
                memory,
            )

        for i in range(self.num_layers):
            decoder_output = block(
                h,
                mask,
                padding_mask,
                memory.layers[i] if memory else None,
                layer_index=i,
                name=f"decoder_layer_{i}",
            )
            h, new_kv_memory = (
                decoder_output.embeddings,
                decoder_output.memory,
            )
            kv_memories.append(new_kv_memory)

        return TransformerOutput(
            embeddings=h,
            memory=Memory(layers=kv_memories),
        )
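

# Sketch (hypothetical sizes, mirroring the mask arithmetic in
# Transformer.__call__): combining a per-token validity mask with a
# lower-triangular causal mask yields the [B, 1, T, T] attention mask that
# every decoder layer consumes.
def _demo_causal_padding_mask():
    mask = jnp.array([[1, 1, 1, 0]], dtype=jnp.bfloat16)  # [B, T]; last token is padding
    mask = mask[:, None, None, :]  # [B, 1, 1, T]
    causal = jnp.tril(jnp.ones((1, 1, 4, 4))).astype(mask.dtype)  # [1, 1, T, T]
    return mask * causal  # [B, 1, T, T]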