Added docstrings to multiple modules and methods

Carlos D. Escobar-Valbuena, 2024-03-18 15:57:49 -05:00, committed by GitHub
parent d6d9447e2d
commit 429a83e5d9

model.py (363 lines changed)

@@ -36,6 +36,13 @@ rank_logger = logging.getLogger("rank")
@dataclass
class QuantizedWeight8bit:
"""
Represents an 8-bit quantized weight for neural network parameters.
Attributes:
weight (jnp.array): The quantized weights.
scales (jnp.array): The scale factors used for quantization.
"""
weight: jnp.array
scales: jnp.array
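For intuition, dequantization can be sketched as multiplying the stored weights by their scale factors; this is a minimal illustration and not necessarily the exact code path used when the checkpoint is loaded:

import jax.numpy as jnp

def dequantize_sketch(weight: jnp.ndarray, scales: jnp.ndarray) -> jnp.ndarray:
    # Illustrative only: recover an approximate full-precision tensor by
    # rescaling the quantized values; broadcasting handles per-block scales.
    return weight.astype(jnp.bfloat16) * scales.astype(jnp.bfloat16)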
@@ -52,13 +59,26 @@ tree_util.register_pytree_node(
class TrainingState(NamedTuple):
"""Container for the training state, encapsulating model parameters.
Attributes:
params (hk.Params): The parameters of the model.
"""
params: hk.Params
def _match(qs, ks):
"""
Checks if regex patterns in qs match any contiguous subsequence of strings in ks.
Args:
qs (Sequence[str]): A sequence of regex patterns to match.
ks (Tuple[str, ...]): A tuple of strings against which the patterns are matched.
Returns:
bool: True if there's a match for all patterns in a contiguous subsequence of ks, False otherwise.
"""
# compile regexes and force complete match
qts = tuple(map(lambda x: re.compile(x + "$"), qs))
for i in range(len(ks) - len(qs) + 1):
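Filling in the elided loop body, the matching logic can be reconstructed as the sketch below (the example parameter path is hypothetical):

import re
from typing import Sequence, Tuple

def match_sketch(qs: Sequence[str], ks: Tuple[str, ...]) -> bool:
    # Anchor each pattern with "$" so it must match a whole path component,
    # then slide a window of len(qs) over ks looking for a full match.
    qts = tuple(re.compile(q + "$") for q in qs)
    for i in range(len(ks) - len(qs) + 1):
        if all(q.match(k) for q, k in zip(qts, ks[i:i + len(qs)])):
            return True
    return False

# A rule such as ("decoder_layer_[0-9]+", "linear", "w") matches the tail of a
# (hypothetical) parameter path:
assert match_sketch(("decoder_layer_[0-9]+", "linear", "w"),
                    ("transformer", "decoder_layer_3", "linear", "w"))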
@@ -69,6 +89,16 @@ def _match(qs, ks):
def with_sharding_constraint(x, constraint):
"""
Applies a sharding constraint to a JAX array, if a physical mesh is available.
Args:
x (jax.Array): The array to apply the sharding constraint to.
constraint (PartitionSpec): The sharding constraint to apply.
Returns:
jax.Array: The array with the sharding constraint applied if a physical mesh is present.
"""
if jax.experimental.maps.thread_resources.env.physical_mesh.empty:
return x
else:
@@ -76,6 +106,15 @@ def with_sharding_constraint(x, constraint):
def cast_bfloat16(x):
"""
Casts the input to bfloat16 type if it is of floating-point type.
Args:
x (jax.Array): The input array.
Returns:
jax.Array: The input array casted to bfloat16 if it was floating-point, unchanged otherwise.
"""
if x.dtype.kind == "f":
return x.astype(jnp.bfloat16)
else:
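A quick, self-contained sketch of the behaviour described above: floating-point arrays are downcast, everything else passes through unchanged.

import jax.numpy as jnp

def cast_bfloat16_sketch(x):
    # Only float kinds ("f") are downcast; integer and bool arrays are returned as-is.
    return x.astype(jnp.bfloat16) if x.dtype.kind == "f" else x

assert cast_bfloat16_sketch(jnp.ones(2, dtype=jnp.float32)).dtype == jnp.bfloat16
assert cast_bfloat16_sketch(jnp.arange(2)).dtype.kind == "i"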
@@ -83,6 +122,18 @@ def cast_bfloat16(x):
def ffn_size(emb_size, widening_factor):
"""
Calculates the size of the feed-forward network (FFN) based on the embedding size and a widening factor.
The calculated FFN size is adjusted to be a multiple of 8 for efficiency in hardware implementations.
Args:
emb_size (int): The size of the embeddings.
widening_factor (float): The factor by which to widen the FFN relative to the embedding size.
Returns:
int: The adjusted size of the FFN.
"""
_ffn_size = int(widening_factor * emb_size) * 2 // 3
_ffn_size = _ffn_size + (8 - _ffn_size) % 8  # ensure it's a multiple of 8
logger.debug(f"emd_size: {emb_size} adjusted ffn_size: {_ffn_size}")
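To make the arithmetic concrete, here is the same formula as a standalone sketch with hypothetical sizes (not values taken from this repository):

def ffn_size_sketch(emb_size: int, widening_factor: float) -> int:
    # 2/3 of the widened size, then rounded up to the next multiple of 8.
    size = int(widening_factor * emb_size) * 2 // 3
    return size + (8 - size) % 8

assert ffn_size_sketch(100, 4.0) == 272     # 266 rounds up to 272, a multiple of 8
assert ffn_size_sketch(6144, 8.0) == 32768  # already a multiple of 8, unchanged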
@@ -90,6 +141,20 @@ def ffn_size(emb_size, widening_factor):
def apply_rules(rules):
"""
Constructs a function to apply a set of sharding rules for transformer parameters.
This function is used to determine the sharding specifications for model parameters based on their roles
and positions within the model architecture.
Args:
rules (List[Tuple[Sequence[str], PartitionSpec]]): A list of tuples where each tuple contains a sequence
of strings representing the parameter path and the corresponding `PartitionSpec` to apply.
Returns:
Callable: A function that takes a parameter path and returns the appropriate `PartitionSpec` based
on the provided rules.
"""
def _apply_rules(path, value):
del value  # Unused.
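The expected shape of `rules` can be illustrated with a hypothetical rule set (the path regexes and axis names below are invented for illustration, not taken from this file):

from jax.sharding import PartitionSpec as P

# Each entry pairs a sequence of path regexes with the PartitionSpec to apply to
# any parameter whose flattened module path contains a matching window.
example_rules = [
    (("multi_head_attention", "(query|key|value)", "w"), P("data", "model")),
    (("decoder_layer_[0-9]+", "linear", "w"), P("model", "data")),
    (("rms_norm", "scale"), P(None)),
]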
@@ -190,6 +255,21 @@ def init_layer_memories(
step: Optional[jax.Array] = None,
dtype=jnp.bfloat16,
):
"""
Initializes memory slots for each layer in the transformer model.
Args:
batch_size (int): The size of the batch.
sequence_len (int): The length of the sequence.
num_kv_heads (int): The number of key/value pairs per head.
key_size (int): The size of each key.
num_layers (int): The number of layers in the transformer.
step (Optional[jax.Array]): The initial step for the memory, defaults to None.
dtype (Any): The data type of the memory arrays, defaults to jnp.bfloat16.
Returns:
List[KVMemory]: A list of initialized KVMemory instances for each layer.
"""
return [
KVMemory(
k=jnp.zeros((batch_size, sequence_len, num_kv_heads, key_size), dtype=dtype),
@@ -206,6 +286,17 @@ class Memory(NamedTuple):
class Router(hk.Module):
"""
A module for routing inputs to experts in a Mixture of Experts (MoE) layer.
Attributes:
num_selected_experts (int): Number of experts to select for each input.
data_axis (str | Tuple[str, ...]): The name(s) of the data axis for sharding.
model_axis (str | Tuple[str, ...]): The name(s) of the model axis for sharding.
shard_activations (bool): If True, shard activations according to the data and model axes.
mesh (Any): The SPMD mesh for parallel computation.
name (str): The name of the module.
"""
def __init__(
self,
num_selected_experts: int,
@@ -234,6 +325,19 @@ class Router(hk.Module):
padding_mask: Optional[jax.Array],
num_experts: int,
):
"""
Computes the routing probabilities for directing inputs to the appropriate experts.
Args:
inputs (jax.Array): Input data to be routed, shaped as [batch_size, ..., input_dim].
padding_mask (Optional[jax.Array]): An optional mask indicating padded elements in the input,
shaped as [batch_size, seq_length], where padded positions are False.
num_experts (int): The total number of experts available for routing.
Returns:
A tuple containing routing probabilities, routing logits, and a dummy value for compatibility,
shaped as ([batch_size, seq_length, num_experts], [batch_size, seq_length, num_experts], int).
"""
# Using fp32 for the routing prob computation.
inputs = jax.lax.convert_element_type(inputs, jnp.float32)
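As a hedged sketch of the idea (omitting sharding and the Haiku parameter machinery): project each token onto one logit per expert in float32, then softmax over the expert axis.

import jax
import jax.numpy as jnp

def routing_prob_sketch(inputs: jax.Array, w_router: jax.Array):
    # inputs: [batch, seq, model_dim]; w_router: [model_dim, num_experts].
    logits = jnp.einsum("btd,de->bte",
                        inputs.astype(jnp.float32), w_router.astype(jnp.float32))
    probs = jax.nn.softmax(logits, axis=-1)  # routing probabilities per token
    return probs, logits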
@@ -270,6 +374,19 @@ class Router(hk.Module):
class MoELayer(hk.Module):
"""
A module implementing a Mixture of Experts (MoE) layer.
Attributes:
num_experts (int): The number of experts in the MoE layer.
layer_fn (Callable): The function to be applied by each expert.
router (Router): The router that routes inputs to experts.
mesh (Any): The SPMD mesh for parallel computation.
shard_activations (bool): If True, shard activations across data and model axes.
data_axis (str | Tuple[str, ...]): The name(s) of the data axis for sharding.
model_axis (str | Tuple[str, ...]): The name(s) of the model axis for sharding.
name (Optional[str]): The name of the module.
"""
def __init__(
self,
num_experts: int,
@@ -292,6 +409,18 @@ class MoELayer(hk.Module):
@hk.transparent
def _inference_call(self, inputs: jax.Array, padding_mask: Optional[jax.Array] = None):
"""
Handles the inference call to the MoE layer, distributing inputs to selected experts based on routing.
Args:
inputs (jax.Array): Input data to be processed, shaped as [batch_size, seq_length, input_dim].
padding_mask (Optional[jax.Array]): An optional mask for the inputs, where False indicates
positions that should not be processed (e.g., padding), shaped as [batch_size, seq_length].
Returns:
jax.Array: The processed outputs after passing through the selected experts, shaped as
[batch_size, seq_length, output_dim].
"""
routing_probs, _, _ = self.router.compute_routing_prob(
inputs, padding_mask, self.num_experts
)
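The routing probabilities are then used to pick experts per token; a minimal sketch of that selection step (capacity limits and sharded dispatch omitted):

import jax
import jax.numpy as jnp

def select_experts_sketch(routing_probs: jax.Array, num_selected_experts: int):
    # Take the top-k experts per token and renormalise their routing weights.
    weights, expert_indices = jax.lax.top_k(routing_probs, num_selected_experts)
    weights = weights / jnp.sum(weights, axis=-1, keepdims=True)
    return weights, expert_indices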
@@ -401,24 +530,65 @@ class MoELayer(hk.Module):
class MHAOutput(NamedTuple):
"""
Represents the output of the Multi-Head Attention (MHA) operation.
Attributes:
embeddings (jax.Array): The output embeddings from the MHA layer.
memory (Any): The updated memory state post-attention operation.
"""
embeddings: jax.Array
memory: Any
class DecoderOutput(NamedTuple):
"""
Encapsulates the output of a decoder layer within the transformer model.
Attributes:
embeddings (jax.Array): The embeddings produced by the decoder layer.
memory (Any): The updated memory state after processing by the decoder layer.
"""
embeddings: jax.Array
memory: Any
class TransformerOutput(NamedTuple):
"""
Represents the final output of the transformer model.
Attributes:
embeddings (jax.Array): The final output embeddings from the transformer.
memory (Any): The final state of the memory after passing through the transformer layers.
"""
embeddings: jax.Array
memory: Any
@dataclass
class TransformerConfig:
"""
Configuration class for a Transformer model specifying the model's architecture and settings.
Attributes:
emb_size (int): Embedding size used in the transformer.
key_size (int): Size of the keys in the multi-head attention mechanism.
num_q_heads (int): Number of query heads in the multi-head attention.
num_kv_heads (int): Number of key/value pairs per attention head.
num_layers (int): Number of layers in the transformer model.
vocab_size (int): The size of the vocabulary.
widening_factor (float): Factor to widen the feedforward network dimension relative to emb_size.
attn_output_multiplier (float): Multiplier for the output of the attention mechanism.
name (Optional[str]): Name of the transformer configuration.
num_experts (int): Number of experts in a mixture of experts layer.
capacity_factor (float): Capacity factor for routing in MoE layers.
num_selected_experts (int): Number of experts selected in each MoE layer.
init_scale (float): Initial scale for parameter initialization.
shard_activations (bool): If True, activations will be sharded across the specified axes.
data_axis (Union[str, Tuple[str, ...]]): Axis names over which data is sharded.
model_axis (Union[str, Tuple[str, ...]]): Axis names over which model parameters are sharded.
"""
emb_size: int
key_size: int
num_q_heads: int
@@ -523,6 +693,20 @@ def make_attention_mask(
class Linear(hk.Linear):
"""
Extends Haiku's Linear layer with optional sharding for use in distributed settings.
This class allows specifying a `PartitionSpec` to shard the linear layer's weights across devices,
which can be beneficial in large-scale models processed over multiple devices or nodes.
Args:
output_size (int): The size of the output dimension.
with_bias (bool, optional): Whether to include a bias term. Defaults to True.
sharding (Optional[P], optional): The sharding specification for distributing the layer's parameters.
mesh (Any, optional): The SPMD mesh for parallel computation. Defaults to None.
name (Optional[str], optional): An optional name for this module. Defaults to None.
shard_axis (int, optional): The axis along which to shard the input data. Defaults to 0.
"""
def __init__(
self,
output_size: int,
@@ -585,7 +769,21 @@ class Linear(hk.Linear):
class RMSNorm(hk.RMSNorm):
"""
Implements Root Mean Square Layer Normalization.
This variant of layer normalization scales inputs by the root mean square of their elements, optionally
including a learnable scaling factor. It supports specifying a `PartitionSpec` for sharding the scale
parameters across devices in distributed settings.
Args:
axis (Union[int, Sequence[int], slice]): The dimensions to normalize over.
eps (float, optional): A small constant added to the denominator to improve numerical stability.
Defaults to 1e-5.
name (Optional[str], optional): An optional name for this module. Defaults to None.
create_scale (bool, optional): Whether to include a learnable scaling factor. Defaults to True.
sharding (Optional[P], optional): The sharding specification for the scale parameter. Defaults to None.
"""
def __init__(
self,
axis: Union[int, Sequence[int], slice],
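For reference, the normalisation itself reduces to the following sketch (per-feature scale assumed, eps as documented above):

import jax.numpy as jnp

def rms_norm_sketch(x: jnp.ndarray, scale: jnp.ndarray, eps: float = 1e-5) -> jnp.ndarray:
    # Divide by the root mean square over the last axis, then apply the
    # learnable scale; no mean subtraction, unlike LayerNorm.
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return x / rms * scale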
@@ -633,12 +831,19 @@ def rotate_half(
class RotaryEmbedding(hk.Module):
"""
Implements Rotary Position Embedding (RoPE) to the input sequence tensor,
as described in https://arxiv.org/abs/2104.09864.
RoPE encodes positional information dynamically by applying a rotation to the input embeddings based on their
position in the sequence. This approach is designed to preserve the relative positional information across
different sequence lengths and tasks.
Args:
dim (int): The dimensionality of the embeddings to be rotated, must be even.
name (Optional[str], optional): An optional name for this module. Defaults to None.
base_exponent (int, optional): The base of the exponent used to calculate rotary frequencies.
Defaults to 10000.
"""
def __init__(
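A self-contained sketch of the rotation RoPE applies, using the split-half pairing convention; the repository's `rotate_half` helper may pair dimensions differently, so treat this as illustrative:

import jax.numpy as jnp

def rope_sketch(x: jnp.ndarray, base_exponent: int = 10000) -> jnp.ndarray:
    # x: [batch, seq_len, num_heads, dim] with dim even. Each feature pair is
    # rotated by an angle that grows with position and shrinks geometrically
    # with the feature index.
    _, seq_len, _, dim = x.shape
    half = dim // 2
    freqs = 1.0 / (base_exponent ** (jnp.arange(half, dtype=jnp.float32) / half))
    angles = jnp.arange(seq_len, dtype=jnp.float32)[:, None] * freqs[None, :]
    sin = jnp.sin(angles)[None, :, None, :]
    cos = jnp.cos(angles)[None, :, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)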
@@ -726,6 +931,22 @@ class MultiHeadAttention(hk.Module):
kv_memory: Optional[KVMemory] = None,
mesh: Any = None,
) -> MHAOutput:
"""
Computes the multi-head attention over the input queries, keys, and values.
Args:
query (jax.Array): Query vectors, shaped as [batch_size, seq_length, model_dim].
key (Optional[jax.Array]): Key vectors. If None, uses query as key.
value (Optional[jax.Array]): Value vectors. If None, uses query as value.
mask (Optional[jax.Array]): An optional mask to prevent attention to certain positions,
shaped as [batch_size, 1, seq_length, seq_length].
kv_memory (Optional[KVMemory]): Optional memory for keys and values to support efficient
autoregressive decoding.
mesh (Any): The SPMD mesh for parallel computation, if applicable.
Returns:
MHAOutput: A named tuple containing the output embeddings and updated memory.
"""
# In shape hints below, we suppress the leading dims [...] for brevity.
# Hence e.g. [A, B] should be read in every case as [..., A, B].
sequence_length = query.shape[1]
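The core of the computation documented above can be sketched as follows (grouped key/value heads, memory updates, and any logit capping are omitted; shapes are illustrative):

import jax
import jax.numpy as jnp

def attention_core_sketch(q, k, v, mask, attn_output_multiplier: float):
    # q: [B, T, H, d]; k, v: [B, S, H, d]; mask: [B, 1, T, S] (True = attend).
    logits = jnp.einsum("bthd,bshd->bhts", q, k).astype(jnp.float32)
    logits = logits * attn_output_multiplier
    if mask is not None:
        logits = jnp.where(mask, logits, -1e30)  # mask out disallowed positions
    weights = jax.nn.softmax(logits, axis=-1).astype(v.dtype)
    return jnp.einsum("bhts,bshd->bthd", weights, v)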
@@ -1034,7 +1255,25 @@ class DecoderLayer(hk.Module):
padding_mask: Optional[jax.Array],
layer_memory: Optional[KVMemory],
) -> DecoderOutput:
"""
Transforms input embedding sequences to output embedding sequences.
Processes input embeddings through a single layer of the decoder.
This method applies multi-head attention followed by position-wise feed-forward networks,
including any necessary normalization and skip connections, as per the transformer architecture.
Args:
inputs (jax.Array): Input embeddings, shaped [batch_size, seq_length, model_dim].
mask (jax.Array): Attention mask, shaped [batch_size, 1, seq_length, seq_length], used to prevent
attention to future positions.
padding_mask (Optional[jax.Array]): Mask indicating which positions are padding tokens,
to exclude them from attention calculations.
layer_memory (Optional[KVMemory]): Memory state for storing past key/value pairs for efficient
autoregressive decoding.
Returns:
DecoderOutput: Named tuple containing output embeddings and updated memory state.
"""
def layer_norm(x):
return hk_rms_norm(x)
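The control flow the docstring describes follows the familiar pre-normalisation residual pattern, sketched below; the real layer also normalises sub-block outputs, threads the attention mask, and carries KV memory:

def decoder_block_sketch(x, attn_fn, ffn_fn, norm_fn):
    # Attention sub-block with a skip connection, then the feed-forward
    # (or MoE) sub-block with another skip connection.
    h = x + attn_fn(norm_fn(x))
    return h + ffn_fn(norm_fn(h))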
@@ -1145,7 +1384,25 @@ class InOutEmbed(hk.Embed):
@dataclass
class LanguageModelConfig:
"""
Configuration class for an autoregressive language model based on the Transformer architecture.
Attributes:
model (TransformerConfig): The transformer model configuration.
vocab_size (int): The size of the vocabulary.
pad_token (int): The token used for padding sequences.
eos_token (int): The end-of-sentence token.
sequence_len (int): The maximum sequence length the model can handle.
model_size (int): The dimensionality of the model embeddings.
embedding_init_scale (float): Initial scale for embedding parameter initialization.
embedding_multiplier_scale (float): Multiplier for scaling the embedding vectors.
output_multiplier_scale (float): Multiplier for scaling the output logits.
name (Optional[str]): Name of the language model configuration.
fprop_dtype (Any): Data type for forward propagation computations.
model_type (Optional[str]): Type of the model, if applicable.
init_scale_override (Optional[float]): Override for the initial scale of parameters, if needed.
shard_embeddings (bool): Whether to shard embeddings across the specified axes.
"""
model: Optional[TransformerConfig]
vocab_size: int
@@ -1200,7 +1457,20 @@ def layer_norm(x, model):
@dataclass
class LanguageModel(hk.Module):
"""
A transformer-based language model for generating or evaluating sequences of tokens.
This class encapsulates a transformer model and provides methods for its initialization,
running the model forward to generate logits, and handling memory states for efficient
autoregressive generation.
Attributes:
model (Transformer): The underlying transformer model.
config (LanguageModelConfig): Configuration for the language model.
fprop_dtype (Any): Data type for forward propagation computations.
name (Optional[str]): Optional name for the module.
mesh (Any): The SPMD mesh for parallel computation, if applicable.
"""
model: "Transformer" model: "Transformer"
config: LanguageModelConfig config: LanguageModelConfig
@@ -1217,7 +1487,22 @@ class LanguageModel(hk.Module):
last_hid_only: bool = False,
length: Optional[jax.Array] = None,
) -> LanguageModelOutput:
"""
Forward pass, producing a sequence of logits.
Generates logits for the next token predictions based on input tokens and optional memory state.
Args:
tokens (jax.Array): Input tokens to the language model, shaped as [batch_size, seq_length].
memory (Optional[Memory]): Optional memory state from previous steps, for autoregressive generation.
batch (Dict[str, jax.Array]): Additional batch information, unused here.
last_hid_only (bool): If True, returns only the last hidden state instead of logits.
length (Optional[jax.Array]): Specifies the length of each sequence in the batch for processing
only up to those lengths.
Returns:
LanguageModelOutput: A named tuple containing the logits for next token predictions and the
updated memory state.
"""
del batch  # Unused.
config = self.config
@@ -1290,7 +1575,30 @@ class LanguageModel(hk.Module):
@dataclass
class Transformer(hk.Module):
"""
Core transformer model class implementing a stack of transformer layers.
This class is designed to be flexible and configurable, supporting features like
multi-head attention, feed-forward networks, and optional mixture of experts layers.
It is capable of processing sequences of embeddings and returning transformed sequences
of embeddings, along with updated memory states for autoregressive tasks.
Attributes:
num_q_heads (int): Number of query heads for multi-head attention.
num_kv_heads (int): Number of key/value pairs per attention head.
key_size (int): Size of keys in the attention mechanism.
widening_factor (float): Widening factor for the dimensionality of the feed-forward network.
init_scale (float): Scale for parameter initialization.
mesh (Any): The SPMD mesh for parallel computation.
attn_output_multiplier (float): Multiplier for the output of the attention mechanism.
shard_activations (bool): Whether to shard activations across specified axes.
num_layers (int): Number of transformer layers.
num_experts (int): Number of experts for mixture of experts layers, if applicable.
num_selected_experts (int): Number of experts selected per token, for MoE layers.
name (Optional[str]): Name of the transformer model.
data_axis (Union[str, Tuple[str, ...]]): Axis names for data sharding.
model_axis (Union[str, Tuple[str, ...]]): Axis names for model parameter sharding.
"""
num_q_heads: int
num_kv_heads: int
@@ -1311,6 +1619,20 @@ class Transformer(hk.Module):
model_axis: Union[str, Tuple[str, ...]] = "model"
def init_memory(self, batch_size: int, sequence_len: int, dtype=jnp.bfloat16):
"""
Initializes the memory state for the transformer model.
This is particularly useful for autoregressive tasks where past key and value pairs are cached
to improve efficiency in generating sequences.
Args:
batch_size (int): The batch size for which to initialize memory states.
sequence_len (int): The sequence length for initializing the size of memory buffers.
dtype (Any): The data type for the memory arrays, typically jnp.bfloat16 for efficiency.
Returns:
Memory: A named tuple representing the initialized memory state for each layer.
"""
return Memory(
layers=init_layer_memories(
batch_size,
@@ -1329,7 +1651,22 @@ class Transformer(hk.Module):
mask: jax.Array,  # [B, T]
memory: Optional[Memory],
) -> TransformerOutput:
"""
Processes input embeddings through the transformer model.
Transforms input embedding sequences to output embedding sequences.
Args:
embeddings (jax.Array): Input embeddings to be processed by the transformer, shaped as
[batch_size, seq_length, model_dim].
mask (jax.Array): Mask indicating valid positions within the input, to control which positions
are allowed to attend to each other, shaped as [batch_size, seq_length].
memory (Optional[Memory]): Optional memory state for the transformer to support autoregressive
decoding or similar use cases.
Returns:
TransformerOutput: A named tuple containing the transformed embeddings and the final state
of the memory after processing.
"""
fprop_dtype = embeddings.dtype
_, seq_len, model_size = embeddings.shape