mirror of https://github.com/xai-org/grok-1.git
synced 2024-11-27 05:59:52 +03:00

Added docstrings to multiple modules and methods

This commit is contained in:
parent d6d9447e2d
commit 429a83e5d9

model.py (363 lines changed)
@@ -36,6 +36,13 @@ rank_logger = logging.getLogger("rank")
 @dataclass
 class QuantizedWeight8bit:
+    """
+    Represents an 8-bit quantized weight for neural network parameters.
+
+    Attributes:
+        weight (jnp.array): The quantized weights.
+        scales (jnp.array): The scale factors used for quantization.
+    """
     weight: jnp.array
     scales: jnp.array
 
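Note: as a sketch of how such a quantized weight might be used (an illustration, not code from this commit; elementwise scaling is an assumption), dequantization would look like:

import jax.numpy as jnp

# Hypothetical helper: recover an approximate full-precision weight by
# multiplying the quantized values with their broadcastable scale factors.
def dequantize(weight: jnp.ndarray, scales: jnp.ndarray) -> jnp.ndarray:
    return weight.astype(jnp.bfloat16) * scales.astype(jnp.bfloat16)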
@@ -52,13 +59,26 @@ tree_util.register_pytree_node(
 
 
 class TrainingState(NamedTuple):
-    """Container for the training state."""
+    """Container for the training state, encapsulating model parameters.
+
+    Attributes:
+        params (hk.Params): The parameters of the model.
+    """
 
     params: hk.Params
 
 
 def _match(qs, ks):
-    """Return True if regexes in qs match any window of strings in tuple ks."""
+    """
+    Checks if regex patterns in qs match any contiguous subsequence of strings in ks.
+
+    Args:
+        qs (Sequence[str]): A sequence of regex patterns to match.
+        ks (Tuple[str, ...]): A tuple of strings against which the patterns are matched.
+
+    Returns:
+        bool: True if there's a match for all patterns in a contiguous subsequence of ks, False otherwise.
+    """
     # compile regexes and force complete match
     qts = tuple(map(lambda x: re.compile(x + "$"), qs))
     for i in range(len(ks) - len(qs) + 1):
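Note: a quick illustration of the windowed matching the new docstring describes, assuming `_match` as defined in model.py:

# Each pattern must fully match the corresponding string in some
# contiguous window of ks; any matching window suffices.
_match(("transformer.*", "linear"), ("decoder", "transformer_0", "linear", "w"))  # True
_match(("linear", "w"), ("transformer_0", "linear", "b"))                         # False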
@@ -69,6 +89,16 @@ def _match(qs, ks):
 
 
 def with_sharding_constraint(x, constraint):
+    """
+    Applies a sharding constraint to a JAX array, if a physical mesh is available.
+
+    Args:
+        x (jax.Array): The array to apply the sharding constraint to.
+        constraint (PartitionSpec): The sharding constraint to apply.
+
+    Returns:
+        jax.Array: The array with the sharding constraint applied if a physical mesh is present.
+    """
     if jax.experimental.maps.thread_resources.env.physical_mesh.empty:
         return x
     else:
@@ -76,6 +106,15 @@ def with_sharding_constraint(x, constraint):
 
 
 def cast_bfloat16(x):
+    """
+    Casts the input to bfloat16 type if it is of floating-point type.
+
+    Args:
+        x (jax.Array): The input array.
+
+    Returns:
+        jax.Array: The input array cast to bfloat16 if it was floating-point, unchanged otherwise.
+    """
     if x.dtype.kind == "f":
         return x.astype(jnp.bfloat16)
     else:
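Note: a small usage sketch of the dtype dispatch documented above, assuming `cast_bfloat16` as defined in model.py:

import jax.numpy as jnp

f = jnp.ones((2,), dtype=jnp.float32)
i = jnp.ones((2,), dtype=jnp.int32)
assert cast_bfloat16(f).dtype == jnp.bfloat16  # floating-point input is downcast
assert cast_bfloat16(i).dtype == jnp.int32     # non-float input passes through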
@@ -83,6 +122,18 @@ def cast_bfloat16(x):
 
 
 def ffn_size(emb_size, widening_factor):
+    """
+    Calculates the size of the feed-forward network (FFN) based on the embedding size and a widening factor.
+
+    The calculated FFN size is adjusted to be a multiple of 8 for efficiency in hardware implementations.
+
+    Args:
+        emb_size (int): The size of the embeddings.
+        widening_factor (float): The factor by which to widen the FFN relative to the embedding size.
+
+    Returns:
+        int: The adjusted size of the FFN.
+    """
     _ffn_size = int(widening_factor * emb_size) * 2 // 3
     _ffn_size = _ffn_size + (8 - _ffn_size) % 8  # ensure it's a multiple of 8
     logger.debug(f"emd_size: {emb_size} adjusted ffn_size: {_ffn_size}")
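Note: a worked example of the rounding described in the docstring, assuming `ffn_size` as defined in model.py:

# emb_size=100, widening_factor=4:
#   int(4 * 100) * 2 // 3 = 266
#   266 + (8 - 266) % 8   = 266 + 6 = 272, the next multiple of 8
assert ffn_size(100, 4) == 272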
@@ -90,6 +141,20 @@ def ffn_size(emb_size, widening_factor):
 
 
 def apply_rules(rules):
+    """
+    Constructs a function to apply a set of sharding rules for transformer parameters.
+
+    This function is used to determine the sharding specifications for model parameters based on their roles
+    and positions within the model architecture.
+
+    Args:
+        rules (List[Tuple[Sequence[str], PartitionSpec]]): A list of tuples where each tuple contains a sequence
+            of strings representing the parameter path and the corresponding `PartitionSpec` to apply.
+
+    Returns:
+        Callable: A function that takes a parameter path and returns the appropriate `PartitionSpec` based
+            on the provided rules.
+    """
     def _apply_rules(path, value):
         del value  # Unused.
 
@@ -190,6 +255,21 @@ def init_layer_memories(
     step: Optional[jax.Array] = None,
     dtype=jnp.bfloat16,
 ):
+    """
+    Initializes memory slots for each layer in the transformer model.
+
+    Args:
+        batch_size (int): The size of the batch.
+        sequence_len (int): The length of the sequence.
+        num_kv_heads (int): The number of key/value heads.
+        key_size (int): The size of each key.
+        num_layers (int): The number of layers in the transformer.
+        step (Optional[jax.Array]): The initial step for the memory, defaults to None.
+        dtype (Any): The data type of the memory arrays, defaults to jnp.bfloat16.
+
+    Returns:
+        List[KVMemory]: A list of initialized KVMemory instances for each layer.
+    """
     return [
         KVMemory(
             k=jnp.zeros((batch_size, sequence_len, num_kv_heads, key_size), dtype=dtype),
@@ -206,6 +286,17 @@ class Memory(NamedTuple):
 
 
 class Router(hk.Module):
+    """
+    A module for routing inputs to experts in a Mixture of Experts (MoE) layer.
+
+    Attributes:
+        num_selected_experts (int): Number of experts to select for each input.
+        data_axis (str | Tuple[str, ...]): The name(s) of the data axis for sharding.
+        model_axis (str | Tuple[str, ...]): The name(s) of the model axis for sharding.
+        shard_activations (bool): If True, shard activations according to the data and model axes.
+        mesh (Any): The SPMD mesh for parallel computation.
+        name (str): The name of the module.
+    """
     def __init__(
         self,
         num_selected_experts: int,
@@ -234,6 +325,19 @@ class Router(hk.Module):
         padding_mask: Optional[jax.Array],
         num_experts: int,
     ):
+        """
+        Computes the routing probabilities for directing inputs to the appropriate experts.
+
+        Args:
+            inputs (jax.Array): Input data to be routed, shaped as [batch_size, ..., input_dim].
+            padding_mask (Optional[jax.Array]): An optional mask indicating padded elements in the input,
+                shaped as [batch_size, seq_length], where padded positions are False.
+            num_experts (int): The total number of experts available for routing.
+
+        Returns:
+            A tuple containing routing probabilities, routing logits, and a dummy value for compatibility,
+            shaped as ([batch_size, seq_length, num_experts], [batch_size, seq_length, num_experts], int).
+        """
         # Using fp32 for the routing prob computation.
         inputs = jax.lax.convert_element_type(inputs, jnp.float32)
 
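Note: the docstring documents the interface only. As a rough sketch of what top-k expert routing of this shape typically computes (illustrative, not this repository's exact implementation):

import jax
import jax.numpy as jnp

def toy_routing(inputs, w_router, num_selected_experts):
    # inputs: [batch, seq, dim]; w_router: [dim, num_experts] (hypothetical weight)
    routing_logits = jnp.einsum("bsd,de->bse", inputs.astype(jnp.float32), w_router)
    routing_probs = jax.nn.softmax(routing_logits, axis=-1)
    # select the k highest-probability experts per token
    top_probs, top_experts = jax.lax.top_k(routing_probs, num_selected_experts)
    return top_probs, top_experts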
@@ -270,6 +374,19 @@ class Router(hk.Module):
 
 
 class MoELayer(hk.Module):
+    """
+    A module implementing a Mixture of Experts (MoE) layer.
+
+    Attributes:
+        num_experts (int): The number of experts in the MoE layer.
+        layer_fn (Callable): The function to be applied by each expert.
+        router (Router): The router that routes inputs to experts.
+        mesh (Any): The SPMD mesh for parallel computation.
+        shard_activations (bool): If True, shard activations across data and model axes.
+        data_axis (str | Tuple[str, ...]): The name(s) of the data axis for sharding.
+        model_axis (str | Tuple[str, ...]): The name(s) of the model axis for sharding.
+        name (Optional[str]): The name of the module.
+    """
     def __init__(
         self,
         num_experts: int,
@@ -292,6 +409,18 @@ class MoELayer(hk.Module):
 
     @hk.transparent
    def _inference_call(self, inputs: jax.Array, padding_mask: Optional[jax.Array] = None):
+        """
+        Handles the inference call to the MoE layer, distributing inputs to selected experts based on routing.
+
+        Args:
+            inputs (jax.Array): Input data to be processed, shaped as [batch_size, seq_length, input_dim].
+            padding_mask (Optional[jax.Array]): An optional mask for the inputs, where False indicates
+                positions that should not be processed (e.g., padding), shaped as [batch_size, seq_length].
+
+        Returns:
+            jax.Array: The processed outputs after passing through the selected experts, shaped as
+            [batch_size, seq_length, output_dim].
+        """
         routing_probs, _, _ = self.router.compute_routing_prob(
             inputs, padding_mask, self.num_experts
         )
@@ -401,24 +530,65 @@ class MoELayer(hk.Module):
 
 
 class MHAOutput(NamedTuple):
-    """Outputs of the multi-head attention operation."""
+    """
+    Represents the output of the Multi-Head Attention (MHA) operation.
+
+    Attributes:
+        embeddings (jax.Array): The output embeddings from the MHA layer.
+        memory (Any): The updated memory state post-attention operation.
+    """
+
     embeddings: jax.Array
     memory: Any
 
 
 class DecoderOutput(NamedTuple):
+    """
+    Encapsulates the output of a decoder layer within the transformer model.
+
+    Attributes:
+        embeddings (jax.Array): The embeddings produced by the decoder layer.
+        memory (Any): The updated memory state after processing by the decoder layer.
+    """
     embeddings: jax.Array
     memory: Any
 
 
 class TransformerOutput(NamedTuple):
+    """
+    Represents the final output of the transformer model.
+
+    Attributes:
+        embeddings (jax.Array): The final output embeddings from the transformer.
+        memory (Any): The final state of the memory after passing through the transformer layers.
+    """
     embeddings: jax.Array
     memory: Any
 
 
 @dataclass
 class TransformerConfig:
+    """
+    Configuration class for a Transformer model specifying the model's architecture and settings.
+
+    Attributes:
+        emb_size (int): Embedding size used in the transformer.
+        key_size (int): Size of the keys in the multi-head attention mechanism.
+        num_q_heads (int): Number of query heads in the multi-head attention.
+        num_kv_heads (int): Number of key/value heads.
+        num_layers (int): Number of layers in the transformer model.
+        vocab_size (int): The size of the vocabulary.
+        widening_factor (float): Factor to widen the feedforward network dimension relative to emb_size.
+        attn_output_multiplier (float): Multiplier for the output of the attention mechanism.
+        name (Optional[str]): Name of the transformer configuration.
+        num_experts (int): Number of experts in a mixture of experts layer.
+        capacity_factor (float): Capacity factor for routing in MoE layers.
+        num_selected_experts (int): Number of experts selected in each MoE layer.
+        init_scale (float): Initial scale for parameter initialization.
+        shard_activations (bool): If True, activations will be sharded across the specified axes.
+        data_axis (Union[str, Tuple[str, ...]]): Axis names over which data is sharded.
+        model_axis (Union[str, Tuple[str, ...]]): Axis names over which model parameters are sharded.
+    """
     emb_size: int
     key_size: int
     num_q_heads: int
@@ -523,6 +693,20 @@ def make_attention_mask(
 
 
 class Linear(hk.Linear):
+    """
+    Extends Haiku's Linear layer with optional sharding for use in distributed settings.
+
+    This class allows specifying a `PartitionSpec` to shard the linear layer's weights across devices,
+    which can be beneficial in large-scale models processed over multiple devices or nodes.
+
+    Args:
+        output_size (int): The size of the output dimension.
+        with_bias (bool, optional): Whether to include a bias term. Defaults to True.
+        sharding (Optional[P], optional): The sharding specification for distributing the layer's parameters.
+        mesh (Any, optional): The SPMD mesh for parallel computation. Defaults to None.
+        name (Optional[str], optional): An optional name for this module. Defaults to None.
+        shard_axis (int, optional): The axis along which to shard the input data. Defaults to 0.
+    """
     def __init__(
         self,
         output_size: int,
@@ -585,7 +769,21 @@ class Linear(hk.Linear):
 
 
 class RMSNorm(hk.RMSNorm):
+    """
+    Implements Root Mean Square Layer Normalization.
+
+    This variant of layer normalization scales inputs by the root mean square of their elements, optionally
+    including a learnable scaling factor. It supports specifying a `PartitionSpec` for sharding the scale
+    parameters across devices in distributed settings.
+
+    Args:
+        axis (Union[int, Sequence[int], slice]): The dimensions to normalize over.
+        eps (float, optional): A small constant added to the denominator to improve numerical stability.
+            Defaults to 1e-5.
+        name (Optional[str], optional): An optional name for this module. Defaults to None.
+        create_scale (bool, optional): Whether to include a learnable scaling factor. Defaults to True.
+        sharding (Optional[P], optional): The sharding specification for the scale parameter. Defaults to None.
+    """
     def __init__(
         self,
         axis: Union[int, Sequence[int], slice],
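Note: the normalization described above is, in essence (a reference sketch, not Haiku's implementation):

import jax.numpy as jnp

def rms_norm_reference(x, scale=None, eps=1e-5, axis=-1):
    # divide by the root mean square of the elements; eps guards against
    # division by zero for all-zero inputs
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=axis, keepdims=True) + eps)
    out = x / rms
    return out * scale if scale is not None else out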
@@ -633,12 +831,19 @@ def rotate_half(
 
 
 class RotaryEmbedding(hk.Module):
-    """Applies rotary embeddings (RoPE) to the input sequence tensor,
+    """
+    Applies Rotary Position Embedding (RoPE) to the input sequence tensor,
     as described in https://arxiv.org/abs/2104.09864.
 
-    Attributes:
-        dim (int): Dimensionality of the feature vectors
-        base_exponent (int): Base exponent to compute embeddings from
+    RoPE encodes positional information by applying a rotation to the input embeddings based on
+    their position in the sequence. This approach preserves relative positional information across
+    different sequence lengths and tasks.
+
+    Args:
+        dim (int): The dimensionality of the embeddings to be rotated, must be even.
+        name (Optional[str], optional): An optional name for this module. Defaults to None.
+        base_exponent (int, optional): The base of the exponent used to calculate rotary frequencies.
+            Defaults to 10000.
     """
 
     def __init__(
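Note: as context for the `base_exponent` argument, the rotary frequencies in standard RoPE are derived roughly as follows (an illustrative sketch, not this file's exact code):

import jax.numpy as jnp

def rope_inverse_frequencies(dim: int, base_exponent: int = 10000):
    # one inverse frequency per pair of embedding dimensions, which is why
    # dim must be even; lower dimensions rotate faster than higher ones
    exponents = jnp.arange(0, dim, 2, dtype=jnp.float32)
    return 1.0 / (base_exponent ** (exponents / dim))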
@@ -726,6 +931,22 @@ class MultiHeadAttention(hk.Module):
         kv_memory: Optional[KVMemory] = None,
         mesh: Any = None,
     ) -> MHAOutput:
+        """
+        Computes the multi-head attention over the input queries, keys, and values.
+
+        Args:
+            query (jax.Array): Query vectors, shaped as [batch_size, seq_length, model_dim].
+            key (Optional[jax.Array]): Key vectors. If None, uses query as key.
+            value (Optional[jax.Array]): Value vectors. If None, uses query as value.
+            mask (Optional[jax.Array]): An optional mask to prevent attention to certain positions,
+                shaped as [batch_size, 1, seq_length, seq_length].
+            kv_memory (Optional[KVMemory]): Optional memory for keys and values to support efficient
+                autoregressive decoding.
+            mesh (Any): The SPMD mesh for parallel computation, if applicable.
+
+        Returns:
+            MHAOutput: A named tuple containing the output embeddings and updated memory.
+        """
         # In shape hints below, we suppress the leading dims [...] for brevity.
         # Hence e.g. [A, B] should be read in every case as [..., A, B].
         sequence_length = query.shape[1]
@@ -1034,7 +1255,25 @@ class DecoderLayer(hk.Module):
         padding_mask: Optional[jax.Array],
         layer_memory: Optional[KVMemory],
     ) -> DecoderOutput:
-        """Transforms input embedding sequences to output embedding sequences."""
+        """
+        Transforms input embedding sequences to output embedding sequences.
+        Processes the input embeddings through a single decoder layer.
+
+        This method applies multi-head attention followed by a position-wise feed-forward network,
+        including the normalization and skip connections of the transformer architecture.
+
+        Args:
+            inputs (jax.Array): Input embeddings, shaped [batch_size, seq_length, model_dim].
+            mask (jax.Array): Attention mask, shaped [batch_size, 1, seq_length, seq_length], used to prevent
+                attention to future positions.
+            padding_mask (Optional[jax.Array]): Mask indicating which positions are padding tokens,
+                to exclude them from attention calculations.
+            layer_memory (Optional[KVMemory]): Memory state for storing past key/value pairs for efficient
+                autoregressive decoding.
+
+        Returns:
+            DecoderOutput: Named tuple containing output embeddings and updated memory state.
+        """
 
         def layer_norm(x):
             return hk_rms_norm(x)
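Note: the attention-then-FFN flow the docstring describes follows the usual normalized residual pattern, sketched here with hypothetical `attn`, `ffn`, and `rms_norm` callables:

def decoder_layer_sketch(x, attn, ffn, rms_norm):
    # multi-head attention with a residual (skip) connection
    h = x + attn(rms_norm(x))
    # position-wise feed-forward network with a second residual connection
    return h + ffn(rms_norm(h))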
@@ -1145,7 +1384,25 @@ class InOutEmbed(hk.Embed):
 
 @dataclass
 class LanguageModelConfig:
-    """An autoregressive transformer-based language model."""
+    """
+    Configuration class for an autoregressive language model based on the Transformer architecture.
+
+    Attributes:
+        model (TransformerConfig): The transformer model configuration.
+        vocab_size (int): The size of the vocabulary.
+        pad_token (int): The token used for padding sequences.
+        eos_token (int): The end-of-sequence token.
+        sequence_len (int): The maximum sequence length the model can handle.
+        model_size (int): The dimensionality of the model embeddings.
+        embedding_init_scale (float): Initial scale for embedding parameter initialization.
+        embedding_multiplier_scale (float): Multiplier for scaling the embedding vectors.
+        output_multiplier_scale (float): Multiplier for scaling the output logits.
+        name (Optional[str]): Name of the language model configuration.
+        fprop_dtype (Any): Data type for forward propagation computations.
+        model_type (Optional[str]): Type of the model, if applicable.
+        init_scale_override (Optional[float]): Override for the initial scale of parameters, if needed.
+        shard_embeddings (bool): Whether to shard embeddings across the specified axes.
+    """
 
     model: Optional[TransformerConfig]
     vocab_size: int
@@ -1200,7 +1457,20 @@ def layer_norm(x, model):
 
 @dataclass
 class LanguageModel(hk.Module):
-    """An autoregressive transformer-based language model."""
+    """
+    A transformer-based language model for generating or evaluating sequences of tokens.
+
+    This class encapsulates a transformer model and provides methods for its initialization,
+    running the model forward to generate logits, and handling memory states for efficient
+    autoregressive generation.
+
+    Attributes:
+        model (Transformer): The underlying transformer model.
+        config (LanguageModelConfig): Configuration for the language model.
+        fprop_dtype (Any): Data type for forward propagation computations.
+        name (Optional[str]): Optional name for the module.
+        mesh (Any): The SPMD mesh for parallel computation, if applicable.
+    """
 
     model: "Transformer"
     config: LanguageModelConfig
@@ -1217,7 +1487,22 @@ class LanguageModel(hk.Module):
         last_hid_only: bool = False,
         length: Optional[jax.Array] = None,
     ) -> LanguageModelOutput:
-        """Forward pass, producing a sequence of logits."""
+        """
+        Forward pass, producing a sequence of logits.
+        Generates logits for next-token prediction from the input tokens and an optional memory state.
+
+        Args:
+            tokens (jax.Array): Input tokens to the language model, shaped as [batch_size, seq_length].
+            memory (Optional[Memory]): Optional memory state from previous steps, for autoregressive generation.
+            batch (Dict[str, jax.Array]): Additional batch information, unused here.
+            last_hid_only (bool): If True, returns only the last hidden state instead of logits.
+            length (Optional[jax.Array]): Specifies the length of each sequence in the batch for processing
+                only up to those lengths.
+
+        Returns:
+            LanguageModelOutput: A named tuple containing the logits for next-token predictions and the
+            updated memory state.
+        """
         del batch  # Unused.
 
         config = self.config
@@ -1290,7 +1575,30 @@ class LanguageModel(hk.Module):
 
 @dataclass
 class Transformer(hk.Module):
-    """A transformer stack."""
+    """
+    Core transformer model class implementing a stack of transformer layers.
+
+    This class is designed to be flexible and configurable, supporting features like
+    multi-head attention, feed-forward networks, and optional mixture of experts layers.
+    It processes sequences of embeddings and returns transformed sequences of embeddings,
+    along with updated memory states for autoregressive tasks.
+
+    Attributes:
+        num_q_heads (int): Number of query heads for multi-head attention.
+        num_kv_heads (int): Number of key/value heads.
+        key_size (int): Size of keys in the attention mechanism.
+        widening_factor (float): Widening factor for the dimensionality of the feed-forward network.
+        init_scale (float): Scale for parameter initialization.
+        mesh (Any): The SPMD mesh for parallel computation.
+        attn_output_multiplier (float): Multiplier for the output of the attention mechanism.
+        shard_activations (bool): Whether to shard activations across specified axes.
+        num_layers (int): Number of transformer layers.
+        num_experts (int): Number of experts for mixture of experts layers, if applicable.
+        num_selected_experts (int): Number of experts selected per token, for MoE layers.
+        name (Optional[str]): Name of the transformer model.
+        data_axis (Union[str, Tuple[str, ...]]): Axis names for data sharding.
+        model_axis (Union[str, Tuple[str, ...]]): Axis names for model parameter sharding.
+    """
 
     num_q_heads: int
     num_kv_heads: int
@@ -1311,6 +1619,20 @@ class Transformer(hk.Module):
     model_axis: Union[str, Tuple[str, ...]] = "model"
 
     def init_memory(self, batch_size: int, sequence_len: int, dtype=jnp.bfloat16):
+        """
+        Initializes the memory state for the transformer model.
+
+        This is particularly useful for autoregressive tasks where past key and value pairs are cached
+        to improve efficiency in generating sequences.
+
+        Args:
+            batch_size (int): The batch size for which to initialize memory states.
+            sequence_len (int): The sequence length for initializing the size of memory buffers.
+            dtype (Any): The data type for the memory arrays, typically jnp.bfloat16 for efficiency.
+
+        Returns:
+            Memory: A named tuple representing the initialized memory state for each layer.
+        """
         return Memory(
             layers=init_layer_memories(
                 batch_size,
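Note: for readers of init_memory and init_layer_memories, the per-layer KV cache buffers have the shapes below (illustrative sizes; the step counter layout is an assumption, not taken from this diff):

import jax.numpy as jnp

batch_size, sequence_len, num_kv_heads, key_size = 1, 8192, 8, 128
k = jnp.zeros((batch_size, sequence_len, num_kv_heads, key_size), dtype=jnp.bfloat16)
v = jnp.zeros_like(k)                              # values mirror the key shape
step = jnp.zeros((batch_size,), dtype=jnp.int32)   # positions filled so far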
@@ -1329,7 +1651,22 @@ class Transformer(hk.Module):
         mask: jax.Array,  # [B, T]
         memory: Optional[Memory],
     ) -> TransformerOutput:
-        """Transforms input embedding sequences to output embedding sequences."""
+        """
+        Transforms input embedding sequences to output embedding sequences.
+        Processes the input embeddings through the full transformer stack.
+
+        Args:
+            embeddings (jax.Array): Input embeddings to be processed by the transformer, shaped as
+                [batch_size, seq_length, model_dim].
+            mask (jax.Array): Mask indicating valid positions within the input, to control which positions
+                are allowed to attend to each other, shaped as [batch_size, seq_length].
+            memory (Optional[Memory]): Optional memory state for the transformer to support autoregressive
+                decoding or similar use cases.
+
+        Returns:
+            TransformerOutput: A named tuple containing the transformed embeddings and the final state
+            of the memory after processing.
+        """
 
         fprop_dtype = embeddings.dtype
         _, seq_len, model_size = embeddings.shape