mirror of https://github.com/xai-org/grok-1.git
synced 2024-11-27 05:59:52 +03:00

Added docstrings to multiple modules and methods

This commit is contained in:
parent d6d9447e2d
commit 429a83e5d9

model.py (363 changed lines)
@@ -36,6 +36,13 @@ rank_logger = logging.getLogger("rank")
 
 @dataclass
 class QuantizedWeight8bit:
+    """
+    Represents an 8-bit quantized weight for neural network parameters.
+
+    Attributes:
+        weight (jnp.array): The quantized weights.
+        scales (jnp.array): The scale factors used for quantization.
+    """
     weight: jnp.array
     scales: jnp.array
 
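Note: to make the quantization scheme concrete, here is a minimal dequantization sketch. The blockwise broadcasting of scales against the weight array is an assumption for illustration; it is not part of this commit.

import jax.numpy as jnp

def dequantize(weight: jnp.ndarray, scales: jnp.ndarray) -> jnp.ndarray:
    # Recover approximate full-precision weights: each int8 value is
    # rescaled by the scale factor of the block it belongs to.
    return weight.astype(jnp.bfloat16) * scales.astype(jnp.bfloat16)

w8 = jnp.array([[127, -64], [32, 1]], dtype=jnp.int8)
s = jnp.array([[0.01], [0.02]], dtype=jnp.float32)  # one scale per row (illustrative)
w = dequantize(w8, s)  # shape (2, 2), dtype bfloat16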
@@ -52,13 +59,26 @@ tree_util.register_pytree_node(
 
 
 class TrainingState(NamedTuple):
-    """Container for the training state."""
+    """Container for the training state, encapsulating model parameters.
+
+    Attributes:
+        params (hk.Params): The parameters of the model.
+    """
 
     params: hk.Params
 
 
 def _match(qs, ks):
-    """Return True if regexes in qs match any window of strings in tuple ks."""
+    """
+    Checks if regex patterns in qs match any contiguous subsequence of strings in ks.
+
+    Args:
+        qs (Sequence[str]): A sequence of regex patterns to match.
+        ks (Tuple[str, ...]): A tuple of strings against which the patterns are matched.
+
+    Returns:
+        bool: True if there's a match for all patterns in a contiguous subsequence of ks, False otherwise.
+    """
     # compile regexes and force complete match
     qts = tuple(map(lambda x: re.compile(x + "$"), qs))
     for i in range(len(ks) - len(qs) + 1):
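Note: the matcher's sliding-window behavior, as the new docstring describes it, can be reproduced with this stand-alone sketch (the loop body is truncated in the diff, so this is a behavioral reconstruction, not the repository's code):

import re
from typing import Sequence, Tuple

def match_window(qs: Sequence[str], ks: Tuple[str, ...]) -> bool:
    # Anchor each pattern so it must match a whole path component.
    qts = [re.compile(q + "$") for q in qs]
    # Slide a window of len(qs) over ks; succeed if every pattern matches
    # its aligned component in some contiguous window.
    return any(
        all(q.match(k) for q, k in zip(qts, ks[i:]))
        for i in range(len(ks) - len(qs) + 1)
    )

path = ("transformer", "decoder_layer_7", "moe", "w")  # hypothetical parameter path
assert match_window(("decoder_layer_[0-9]+", "moe"), path)
assert not match_window(("moe", "decoder_layer_[0-9]+"), path)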
@@ -69,6 +89,16 @@ def _match(qs, ks):
 
 
 def with_sharding_constraint(x, constraint):
+    """
+    Applies a sharding constraint to a JAX array, if a physical mesh is available.
+
+    Args:
+        x (jax.Array): The array to apply the sharding constraint to.
+        constraint (PartitionSpec): The sharding constraint to apply.
+
+    Returns:
+        jax.Array: The array with the sharding constraint applied if a physical mesh is present.
+    """
     if jax.experimental.maps.thread_resources.env.physical_mesh.empty:
         return x
     else:
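Note: a usage illustration of the mesh guard (hypothetical call site, not from this diff):

# Inside a mesh context such as `with mesh: ...`, the constraint pins the
# layout of intermediate values, e.g.:
#
#     h = with_sharding_constraint(h, P("data", None, "model"))
#
# With no physical mesh configured, `physical_mesh.empty` is True and the
# call returns `h` unchanged, so the same code also runs on a single device.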
@@ -76,6 +106,15 @@ def with_sharding_constraint(x, constraint):
 
 
 def cast_bfloat16(x):
+    """
+    Casts the input to bfloat16 type if it is of floating-point type.
+
+    Args:
+        x (jax.Array): The input array.
+
+    Returns:
+        jax.Array: The input array cast to bfloat16 if it was floating-point, unchanged otherwise.
+    """
     if x.dtype.kind == "f":
         return x.astype(jnp.bfloat16)
     else:
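Note: this helper is typically mapped over a parameter pytree; a minimal sketch (the truncated else-branch is assumed to return x unchanged):

import jax
import jax.numpy as jnp

def cast_bfloat16(x):
    # Floating-point leaves are cast down; integer leaves pass through.
    return x.astype(jnp.bfloat16) if x.dtype.kind == "f" else x

params = {"w": jnp.ones((2, 2), jnp.float32), "step": jnp.array(3, jnp.int32)}
params = jax.tree_util.tree_map(cast_bfloat16, params)
assert params["w"].dtype == jnp.bfloat16
assert params["step"].dtype == jnp.int32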
@@ -83,6 +122,18 @@ def cast_bfloat16(x):
 
 
 def ffn_size(emb_size, widening_factor):
+    """
+    Calculates the size of the feed-forward network (FFN) based on the embedding size and a widening factor.
+
+    The calculated FFN size is adjusted to be a multiple of 8 for efficiency in hardware implementations.
+
+    Args:
+        emb_size (int): The size of the embeddings.
+        widening_factor (float): The factor by which to widen the FFN relative to the embedding size.
+
+    Returns:
+        int: The adjusted size of the FFN.
+    """
     _ffn_size = int(widening_factor * emb_size) * 2 // 3
     _ffn_size = _ffn_size + (8 - _ffn_size) % 8  # ensure it's a multiple of 8
     logger.debug(f"emd_size: {emb_size} adjusted ffn_size: {_ffn_size}")
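Note: the 2/3 factor is the usual adjustment when the FFN uses a gated activation (two input projections instead of one), and the final step rounds up to a multiple of 8. A worked example with illustrative values:

# emb_size = 6144, widening_factor = 8:
#   int(8 * 6144) * 2 // 3    -> 32768
#   32768 + (8 - 32768) % 8   -> 32768   (already a multiple of 8)
#
# emb_size = 100, widening_factor = 4:
#   int(4 * 100) * 2 // 3     -> 266
#   266 + (8 - 266) % 8       -> 272     (rounded up to a multiple of 8)
assert (int(8 * 6144) * 2 // 3) == 32768
assert 266 + (8 - 266) % 8 == 272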
@@ -90,6 +141,20 @@ def ffn_size(emb_size, widening_factor):
 
 
 def apply_rules(rules):
+    """
+    Constructs a function to apply a set of sharding rules for transformer parameters.
+
+    This function is used to determine the sharding specifications for model parameters based on their roles
+    and positions within the model architecture.
+
+    Args:
+        rules (List[Tuple[Sequence[str], PartitionSpec]]): A list of tuples where each tuple contains a sequence
+            of strings representing the parameter path and the corresponding `PartitionSpec` to apply.
+
+    Returns:
+        Callable: A function that takes a parameter path and returns the appropriate `PartitionSpec` based
+            on the provided rules.
+    """
     def _apply_rules(path, value):
         del value  # Unused.
 
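Note: a hypothetical rule set in the shape the docstring describes (patterns and specs invented for illustration):

from jax.sharding import PartitionSpec as P

rules = [
    # (path patterns, spec): attention projection weights split over both axes.
    (("multi_head_attention", "(query|key|value)", "w"), P("data", "model")),
    # MoE expert weights carry a leading experts dimension.
    (("moe", "linear_v", "w"), P(None, "data", "model")),
    # Norm scales are small; keep them replicated.
    (("rms_norm", "scale"), P(None)),
]
# apply_rules(rules) returns a callable that maps a parameter path such as
# ("transformer", "moe", "linear_v", "w") to P(None, "data", "model").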
@@ -190,6 +255,21 @@ def init_layer_memories(
     step: Optional[jax.Array] = None,
     dtype=jnp.bfloat16,
 ):
+    """
+    Initializes memory slots for each layer in the transformer model.
+
+    Args:
+        batch_size (int): The size of the batch.
+        sequence_len (int): The length of the sequence.
+        num_kv_heads (int): The number of key/value heads.
+        key_size (int): The size of each key.
+        num_layers (int): The number of layers in the transformer.
+        step (Optional[jax.Array]): The initial step for the memory, defaults to None.
+        dtype (Any): The data type of the memory arrays, defaults to jnp.bfloat16.
+
+    Returns:
+        List[KVMemory]: A list of initialized KVMemory instances for each layer.
+    """
     return [
         KVMemory(
             k=jnp.zeros((batch_size, sequence_len, num_kv_heads, key_size), dtype=dtype),
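Note: a usage sketch with illustrative sizes (values invented for the example, not read from the repository's configuration; assumes the module context above):

memories = init_layer_memories(
    batch_size=1,
    sequence_len=8192,
    num_kv_heads=8,
    key_size=128,
    num_layers=64,
)
# memories is a list of 64 KVMemory entries, each with zeroed buffers:
# memories[0].k.shape == (1, 8192, 8, 128)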
@@ -206,6 +286,17 @@ class Memory(NamedTuple):
 
 
 class Router(hk.Module):
+    """
+    A module for routing inputs to experts in a Mixture of Experts (MoE) layer.
+
+    Attributes:
+        num_selected_experts (int): Number of experts to select for each input.
+        data_axis (str | Tuple[str, ...]): The name(s) of the data axis for sharding.
+        model_axis (str | Tuple[str, ...]): The name(s) of the model axis for sharding.
+        shard_activations (bool): If True, shard activations according to the data and model axes.
+        mesh (Any): The SPMD mesh for parallel computation.
+        name (str): The name of the module.
+    """
     def __init__(
         self,
         num_selected_experts: int,
@@ -234,6 +325,19 @@ class Router(hk.Module):
         padding_mask: Optional[jax.Array],
         num_experts: int,
     ):
+        """
+        Computes the routing probabilities for directing inputs to the appropriate experts.
+
+        Args:
+            inputs (jax.Array): Input data to be routed, shaped as [batch_size, ..., input_dim].
+            padding_mask (Optional[jax.Array]): An optional mask indicating padded elements in the input,
+                shaped as [batch_size, seq_length], where padded positions are False.
+            num_experts (int): The total number of experts available for routing.
+
+        Returns:
+            A tuple containing routing probabilities, routing logits, and a dummy value for compatibility,
+            shaped as ([batch_size, seq_length, num_experts], [batch_size, seq_length, num_experts], int).
+        """
         # Using fp32 for the routing prob computation.
         inputs = jax.lax.convert_element_type(inputs, jnp.float32)
 
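Note: downstream of these probabilities, MoE routing typically keeps the top-k experts per token; a minimal sketch of that step (the commit documents only the interface, so this is not the repository's code):

import jax
import jax.numpy as jnp

def select_experts(routing_probs: jnp.ndarray, num_selected_experts: int):
    # routing_probs: [batch, seq, num_experts]; keep the k most probable
    # experts per token along with their gate weights.
    gate_weights, expert_ids = jax.lax.top_k(routing_probs, num_selected_experts)
    return gate_weights, expert_ids  # each [batch, seq, num_selected_experts]

logits = jnp.ones((1, 4, 8)) * jnp.arange(8)   # illustrative logits
probs = jax.nn.softmax(logits, axis=-1)
w, ids = select_experts(probs, 2)              # two experts per token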
@@ -270,6 +374,19 @@ class Router(hk.Module):
 
 
 class MoELayer(hk.Module):
+    """
+    A module implementing a Mixture of Experts (MoE) layer.
+
+    Attributes:
+        num_experts (int): The number of experts in the MoE layer.
+        layer_fn (Callable): The function to be applied by each expert.
+        router (Router): The router that routes inputs to experts.
+        mesh (Any): The SPMD mesh for parallel computation.
+        shard_activations (bool): If True, shard activations across data and model axes.
+        data_axis (str | Tuple[str, ...]): The name(s) of the data axis for sharding.
+        model_axis (str | Tuple[str, ...]): The name(s) of the model axis for sharding.
+        name (Optional[str]): The name of the module.
+    """
     def __init__(
         self,
        num_experts: int,
@@ -292,6 +409,18 @@ class MoELayer(hk.Module):
 
     @hk.transparent
     def _inference_call(self, inputs: jax.Array, padding_mask: Optional[jax.Array] = None):
+        """
+        Handles the inference call to the MoE layer, distributing inputs to selected experts based on routing.
+
+        Args:
+            inputs (jax.Array): Input data to be processed, shaped as [batch_size, seq_length, input_dim].
+            padding_mask (Optional[jax.Array]): An optional mask for the inputs, where False indicates
+                positions that should not be processed (e.g., padding), shaped as [batch_size, seq_length].
+
+        Returns:
+            jax.Array: The processed outputs after passing through the selected experts, shaped as
+                [batch_size, seq_length, output_dim].
+        """
         routing_probs, _, _ = self.router.compute_routing_prob(
             inputs, padding_mask, self.num_experts
         )
@@ -401,24 +530,65 @@ class MoELayer(hk.Module):
 
 
 class MHAOutput(NamedTuple):
-    """Outputs of the multi-head attention operation."""
+    """
+    Represents the output of the Multi-Head Attention (MHA) operation.
+
+    Attributes:
+        embeddings (jax.Array): The output embeddings from the MHA layer.
+        memory (Any): The updated memory state after the attention operation.
+    """
 
     embeddings: jax.Array
     memory: Any
 
 
 class DecoderOutput(NamedTuple):
+    """
+    Encapsulates the output of a decoder layer within the transformer model.
+
+    Attributes:
+        embeddings (jax.Array): The embeddings produced by the decoder layer.
+        memory (Any): The updated memory state after processing by the decoder layer.
+    """
     embeddings: jax.Array
     memory: Any
 
 
 class TransformerOutput(NamedTuple):
+    """
+    Represents the final output of the transformer model.
+
+    Attributes:
+        embeddings (jax.Array): The final output embeddings from the transformer.
+        memory (Any): The final state of the memory after passing through the transformer layers.
+    """
     embeddings: jax.Array
     memory: Any
 
 
 @dataclass
 class TransformerConfig:
+    """
+    Configuration class for a Transformer model, specifying the model's architecture and settings.
+
+    Attributes:
+        emb_size (int): Embedding size used in the transformer.
+        key_size (int): Size of the keys in the multi-head attention mechanism.
+        num_q_heads (int): Number of query heads in the multi-head attention.
+        num_kv_heads (int): Number of key/value heads.
+        num_layers (int): Number of layers in the transformer model.
+        vocab_size (int): The size of the vocabulary.
+        widening_factor (float): Factor to widen the feed-forward network dimension relative to emb_size.
+        attn_output_multiplier (float): Multiplier for the output of the attention mechanism.
+        name (Optional[str]): Name of the transformer configuration.
+        num_experts (int): Number of experts in a mixture of experts layer.
+        capacity_factor (float): Capacity factor for routing in MoE layers.
+        num_selected_experts (int): Number of experts selected in each MoE layer.
+        init_scale (float): Initial scale for parameter initialization.
+        shard_activations (bool): If True, activations will be sharded across the specified axes.
+        data_axis (Union[str, Tuple[str, ...]]): Axis names over which data is sharded.
+        model_axis (Union[str, Tuple[str, ...]]): Axis names over which model parameters are sharded.
+    """
     emb_size: int
     key_size: int
     num_q_heads: int
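Note: an illustrative instantiation, assuming the documented attributes are the dataclass's constructor keywords (values invented for the example, not read from the repository's checkpoint):

config = TransformerConfig(
    emb_size=6144,
    key_size=128,
    num_q_heads=48,
    num_kv_heads=8,
    num_layers=64,
    vocab_size=131072,
    widening_factor=8.0,
    attn_output_multiplier=0.125,
    name=None,
    num_experts=8,
    capacity_factor=1.0,
    num_selected_experts=2,
    init_scale=1.0,
    shard_activations=True,
    data_axis="data",
    model_axis="model",
)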
@@ -523,6 +693,20 @@ def make_attention_mask(
 
 
 class Linear(hk.Linear):
+    """
+    Extends Haiku's Linear layer with optional sharding for use in distributed settings.
+
+    This class allows specifying a `PartitionSpec` to shard the linear layer's weights across devices,
+    which can be beneficial in large-scale models processed over multiple devices or nodes.
+
+    Args:
+        output_size (int): The size of the output dimension.
+        with_bias (bool, optional): Whether to include a bias term. Defaults to True.
+        sharding (Optional[P], optional): The sharding specification for distributing the layer's parameters.
+        mesh (Any, optional): The SPMD mesh for parallel computation. Defaults to None.
+        name (Optional[str], optional): An optional name for this module. Defaults to None.
+        shard_axis (int, optional): The axis along which to shard the input data. Defaults to 0.
+    """
     def __init__(
         self,
         output_size: int,
@@ -585,7 +769,21 @@ class Linear(hk.Linear):
 
 
 class RMSNorm(hk.RMSNorm):
+    """
+    Implements Root Mean Square Layer Normalization.
+
+    This variant of layer normalization scales inputs by the root mean square of their elements, optionally
+    including a learnable scaling factor. It supports specifying a `PartitionSpec` for sharding the scale
+    parameters across devices in distributed settings.
+
+    Args:
+        axis (Union[int, Sequence[int], slice]): The dimensions to normalize over.
+        eps (float, optional): A small constant added to the denominator to improve numerical stability.
+            Defaults to 1e-5.
+        name (Optional[str], optional): An optional name for this module. Defaults to None.
+        create_scale (bool, optional): Whether to include a learnable scaling factor. Defaults to True.
+        sharding (Optional[P], optional): The sharding specification for the scale parameter. Defaults to None.
+    """
     def __init__(
         self,
         axis: Union[int, Sequence[int], slice],
@@ -633,12 +831,19 @@ def rotate_half(
 
 
 class RotaryEmbedding(hk.Module):
-    """Applies rotary embeddings (RoPE) to the input sequence tensor,
+    """
+    Applies Rotary Position Embedding (RoPE) to the input sequence tensor,
     as described in https://arxiv.org/abs/2104.09864.
 
-    Attributes:
-        dim (int): Dimensionality of the feature vectors
-        base_exponent (int): Base exponent to compute embeddings from
+    RoPE encodes positional information dynamically by applying a rotation to the input embeddings based on their
+    position in the sequence. This approach is designed to preserve the relative positional information across
+    different sequence lengths and tasks.
+
+    Args:
+        dim (int): The dimensionality of the embeddings to be rotated; must be even.
+        name (Optional[str], optional): An optional name for this module. Defaults to None.
+        base_exponent (int, optional): The base of the exponent used to calculate rotary frequencies.
+            Defaults to 10000.
     """
 
     def __init__(
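Note: a compact, self-contained RoPE sketch in the rotate-half convention the surrounding code suggests (an illustration of the technique, not the module's implementation):

import jax.numpy as jnp

def rope(x: jnp.ndarray, base_exponent: int = 10000) -> jnp.ndarray:
    # x: [seq_len, dim], dim even. Frequencies decay geometrically across
    # the feature dimension; each feature pair is rotated by an angle
    # proportional to its sequence position.
    seq_len, dim = x.shape
    inv_freq = 1.0 / (base_exponent ** (jnp.arange(0, dim, 2) / dim))
    t = jnp.arange(seq_len)[:, None] * inv_freq[None, :]   # [seq_len, dim // 2]
    sin, cos = jnp.sin(t), jnp.cos(t)
    x1, x2 = x[:, : dim // 2], x[:, dim // 2 :]
    return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

out = rope(jnp.ones((16, 8)))  # rotation preserves norms per feature pair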
@@ -726,6 +931,22 @@ class MultiHeadAttention(hk.Module):
         kv_memory: Optional[KVMemory] = None,
         mesh: Any = None,
     ) -> MHAOutput:
+        """
+        Computes the multi-head attention over the input queries, keys, and values.
+
+        Args:
+            query (jax.Array): Query vectors, shaped as [batch_size, seq_length, model_dim].
+            key (Optional[jax.Array]): Key vectors. If None, uses query as key.
+            value (Optional[jax.Array]): Value vectors. If None, uses query as value.
+            mask (Optional[jax.Array]): An optional mask to prevent attention to certain positions,
+                shaped as [batch_size, 1, seq_length, seq_length].
+            kv_memory (Optional[KVMemory]): Optional memory for keys and values to support efficient
+                autoregressive decoding.
+            mesh (Any): The SPMD mesh for parallel computation, if applicable.
+
+        Returns:
+            MHAOutput: A named tuple containing the output embeddings and updated memory.
+        """
         # In shape hints below, we suppress the leading dims [...] for brevity.
         # Hence e.g. [A, B] should be read in every case as [..., A, B].
         sequence_length = query.shape[1]
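Note: the core score computation implied by attn_output_multiplier can be sketched as follows (a simplified illustration; the repository's exact masking and scaling details are not shown in this diff):

import jax
import jax.numpy as jnp

def attention_weights(q, k, attn_output_multiplier: float, mask=None):
    # q, k: [batch, seq, heads, key_size]; logits: [batch, heads, seq, seq].
    logits = jnp.einsum("bthd,bshd->bhts", q, k) * attn_output_multiplier
    if mask is not None:
        logits = jnp.where(mask, logits, -1e30)  # block disallowed positions
    return jax.nn.softmax(logits, axis=-1)

q = k = jnp.ones((1, 4, 2, 8))
w = attention_weights(q, k, attn_output_multiplier=0.125)  # uniform weights here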
@@ -1034,7 +1255,25 @@ class DecoderLayer(hk.Module):
         padding_mask: Optional[jax.Array],
         layer_memory: Optional[KVMemory],
     ) -> DecoderOutput:
-        """Transforms input embedding sequences to output embedding sequences."""
+        """
+        Processes input embeddings through a single decoder layer, transforming input embedding
+        sequences to output embedding sequences.
+
+        This method applies multi-head attention followed by position-wise feed-forward networks,
+        including any necessary normalization and skip connections, as per the transformer architecture.
+
+        Args:
+            inputs (jax.Array): Input embeddings, shaped [batch_size, seq_length, model_dim].
+            mask (jax.Array): Attention mask, shaped [batch_size, 1, seq_length, seq_length], used to prevent
+                attention to future positions.
+            padding_mask (Optional[jax.Array]): Mask indicating which positions are padding tokens,
+                to exclude them from attention calculations.
+            layer_memory (Optional[KVMemory]): Memory state for storing past key/value pairs for efficient
+                autoregressive decoding.
+
+        Returns:
+            DecoderOutput: Named tuple containing output embeddings and updated memory state.
+        """
 
         def layer_norm(x):
             return hk_rms_norm(x)
@@ -1145,7 +1384,25 @@ class InOutEmbed(hk.Embed):
 
 @dataclass
 class LanguageModelConfig:
-    """An autoregressive transformer-based language model."""
+    """
+    Configuration class for an autoregressive language model based on the Transformer architecture.
+
+    Attributes:
+        model (Optional[TransformerConfig]): The transformer model configuration.
+        vocab_size (int): The size of the vocabulary.
+        pad_token (int): The token used for padding sequences.
+        eos_token (int): The end-of-sentence token.
+        sequence_len (int): The maximum sequence length the model can handle.
+        model_size (int): The dimensionality of the model embeddings.
+        embedding_init_scale (float): Initial scale for embedding parameter initialization.
+        embedding_multiplier_scale (float): Multiplier for scaling the embedding vectors.
+        output_multiplier_scale (float): Multiplier for scaling the output logits.
+        name (Optional[str]): Name of the language model configuration.
+        fprop_dtype (Any): Data type for forward propagation computations.
+        model_type (Optional[str]): Type of the model, if applicable.
+        init_scale_override (Optional[float]): Override for the initial scale of parameters, if needed.
+        shard_embeddings (bool): Whether to shard embeddings across the specified axes.
+    """
 
     model: Optional[TransformerConfig]
     vocab_size: int
@@ -1200,7 +1457,20 @@ def layer_norm(x, model):
 
 @dataclass
 class LanguageModel(hk.Module):
-    """An autoregressive transformer-based language model."""
+    """
+    A transformer-based language model for generating or evaluating sequences of tokens.
+
+    This class encapsulates a transformer model and provides methods for its initialization,
+    running the model forward to generate logits, and handling memory states for efficient
+    autoregressive generation.
+
+    Attributes:
+        model (Transformer): The underlying transformer model.
+        config (LanguageModelConfig): Configuration for the language model.
+        fprop_dtype (Any): Data type for forward propagation computations.
+        name (Optional[str]): Optional name for the module.
+        mesh (Any): The SPMD mesh for parallel computation, if applicable.
+    """
 
     model: "Transformer"
     config: LanguageModelConfig
@@ -1217,7 +1487,22 @@ class LanguageModel(hk.Module):
         last_hid_only: bool = False,
         length: Optional[jax.Array] = None,
     ) -> LanguageModelOutput:
-        """Forward pass, producing a sequence of logits."""
+        """
+        Forward pass, producing a sequence of logits for next-token prediction based on the
+        input tokens and an optional memory state.
+
+        Args:
+            tokens (jax.Array): Input tokens to the language model, shaped as [batch_size, seq_length].
+            memory (Optional[Memory]): Optional memory state from previous steps, for autoregressive generation.
+            batch (Dict[str, jax.Array]): Additional batch information, unused here.
+            last_hid_only (bool): If True, returns only the last hidden state instead of logits.
+            length (Optional[jax.Array]): Specifies the length of each sequence in the batch, so that
+                processing runs only up to those lengths.
+
+        Returns:
+            LanguageModelOutput: A named tuple containing the logits for next-token predictions and the
+                updated memory state.
+        """
         del batch  # Unused.
 
         config = self.config
@@ -1290,7 +1575,30 @@ class LanguageModel(hk.Module):
 
 @dataclass
 class Transformer(hk.Module):
-    """A transformer stack."""
+    """
+    Core transformer model class implementing a stack of transformer layers.
+
+    This class is designed to be flexible and configurable, supporting features like
+    multi-head attention, feed-forward networks, and optional mixture of experts layers.
+    It is capable of processing sequences of embeddings and returning transformed sequences
+    of embeddings, along with updated memory states for autoregressive tasks.
+
+    Attributes:
+        num_q_heads (int): Number of query heads for multi-head attention.
+        num_kv_heads (int): Number of key/value heads.
+        key_size (int): Size of keys in the attention mechanism.
+        widening_factor (float): Widening factor for the dimensionality of the feed-forward network.
+        init_scale (float): Scale for parameter initialization.
+        mesh (Any): The SPMD mesh for parallel computation.
+        attn_output_multiplier (float): Multiplier for the output of the attention mechanism.
+        shard_activations (bool): Whether to shard activations across specified axes.
+        num_layers (int): Number of transformer layers.
+        num_experts (int): Number of experts for mixture of experts layers, if applicable.
+        num_selected_experts (int): Number of experts selected per token, for MoE layers.
+        name (Optional[str]): Name of the transformer model.
+        data_axis (Union[str, Tuple[str, ...]]): Axis names for data sharding.
+        model_axis (Union[str, Tuple[str, ...]]): Axis names for model parameter sharding.
+    """
 
     num_q_heads: int
     num_kv_heads: int
@@ -1311,6 +1619,20 @@ class Transformer(hk.Module):
     model_axis: Union[str, Tuple[str, ...]] = "model"
 
     def init_memory(self, batch_size: int, sequence_len: int, dtype=jnp.bfloat16):
+        """
+        Initializes the memory state for the transformer model.
+
+        This is particularly useful for autoregressive tasks where past key and value pairs are cached
+        to improve efficiency in generating sequences.
+
+        Args:
+            batch_size (int): The batch size for which to initialize memory states.
+            sequence_len (int): The sequence length for initializing the size of memory buffers.
+            dtype (Any): The data type for the memory arrays, typically jnp.bfloat16 for efficiency.
+
+        Returns:
+            Memory: A named tuple representing the initialized memory state for each layer.
+        """
         return Memory(
             layers=init_layer_memories(
                 batch_size,
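Note: a usage sketch (hypothetical call site, not part of the diff):

# Inside an hk.transform-ed function, priming a decode loop might look like:
#
#     memory = transformer.init_memory(batch_size=1, sequence_len=8192)
#
# memory.layers then holds one KVMemory per layer, with zeroed key/value
# buffers shaped [1, 8192, num_kv_heads, key_size].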
@@ -1329,7 +1651,22 @@ class Transformer(hk.Module):
         mask: jax.Array,  # [B, T]
         memory: Optional[Memory],
     ) -> TransformerOutput:
-        """Transforms input embedding sequences to output embedding sequences."""
+        """
+        Processes input embeddings through the transformer stack, transforming input embedding
+        sequences to output embedding sequences.
+
+        Args:
+            embeddings (jax.Array): Input embeddings to be processed by the transformer, shaped as
+                [batch_size, seq_length, model_dim].
+            mask (jax.Array): Mask indicating valid positions within the input, to control which positions
+                are allowed to attend to each other, shaped as [batch_size, seq_length].
+            memory (Optional[Memory]): Optional memory state for the transformer to support autoregressive
+                decoding or similar use cases.
+
+        Returns:
+            TransformerOutput: A named tuple containing the transformed embeddings and the final state
+                of the memory after processing.
+        """
 
         fprop_dtype = embeddings.dtype
         _, seq_len, model_size = embeddings.shape
 