Update readme and add break downs for each class

This commit is contained in:
pouya samie 2024-03-18 13:42:09 +03:30
parent 6eef3b537f
commit 6f7832706b

106
README.md
View File

@ -33,104 +33,108 @@ repository and the model weights of Grok-1.
# Table of contents # Table of contents
This repository contains Python code for grok-1. Below is a breakdown of the main components and features provided by this codebase: This repository contains Python code for grok-1. Below is a breakdown of the main components and features provided by this codebase:
# :book: **model.py** # :book: model.py
i will try to breakdown each classes in detail i will try to breakdown each classes in detail
## :page_with_curl: functions, varibales and constants ## :page_with_curl: functions, varibales and constants
**ffn_size**: This function computes the size (number of units) for the feed-forward network (FFN) layer in the transformer architecture. The FFN size is typically larger than the embedding size to increase the model's expressive power. The function takes two arguments: - `ffn_size`: This function computes the size (number of units) for the feed-forward network (FFN) layer in the transformer architecture. The FFN size is typically larger than the embedding size to increase the model's expressive power. The function takes two arguments:
**emb_size**: The size of the input embeddings. - `emb_size`: The size of the input embeddings.
**widening_factor**: A multiplier used to determine the FFN size relative to the embedding size. - `widening_factor`: A multiplier used to determine the FFN size relative to the embedding size.
The function first calculates the `FFN size` as `int(widening_factor * emb_size) * 2 // 3`. The * 2 // 3 part is a heuristic to reduce the `FFN size` slightly while maintaining good performance. Then, it ensures that the `FFN size` is a multiple of 8 by adding the smallest positive number needed to make it divisible by 8. This is likely done for efficient computations on certain hardware architectures. The function first calculates the `FFN size` as `int(widening_factor * emb_size) * 2 // 3`. The * 2 // 3 part is a heuristic to reduce the `FFN size` slightly while maintaining good performance. Then, it ensures that the `FFN size` is a multiple of 8 by adding the smallest positive number needed to make it divisible by 8. This is likely done for efficient computations on certain hardware architectures.
**apply_rules**: This function returns a closure (inner function) that applies a set of rules to reshape (reshard) the parameters of a neural network model based on regular expression patterns. The rules are provided as a list of tuples, where each tuple contains a list of regular expression patterns and a corresponding reshape specification (PartitionSpec). - `apply_rules`: This function returns a closure (inner function) that applies a set of rules to reshape (reshard) the parameters of a neural network model based on regular expression patterns. The rules are provided as a list of tuples, where each tuple contains a list of regular expression patterns and a corresponding reshape specification (PartitionSpec).
The inner function _apply_rules takes a path (list of keys) and a value (tensor) as input. It flattens the path and checks if any of the provided rules (regular expressions) match the flattened path. If a match is found, the corresponding PartitionSpec is applied to the tensor, effectively reshaping it according to the specified partitioning scheme. This function is likely used to optimize the model for distributed training across multiple devices or accelerators. The inner function `_apply_rules` takes a path (list of keys) and a value (tensor) as input. It flattens the path and checks if any of the provided rules (regular expressions) match the flattened path. If a match is found, the corresponding PartitionSpec is applied to the tensor, effectively reshaping it according to the specified partitioning scheme. This function is likely used to optimize the model for distributed training across multiple devices or accelerators.
**cast_bfloat16**: This is a simple utility function that casts the input tensor (x) to the bfloat16 data type if the tensor's data type is a floating-point type. The bfloat16 data type is a truncated 16-bit floating-point format that can provide higher computational performance on certain hardware architectures, such as Google's TPUs, while maintaining reasonable precision. If the input tensor is not a floating-point type, the function returns the tensor unchanged. - `cast_bfloat16`: This is a simple utility function that casts the input tensor (x) to the bfloat16 data type if the tensor's data type is a floating-point type. The bfloat16 data type is a truncated 16-bit floating-point format that can provide higher computational performance on certain hardware architectures, such as Google's TPUs, while maintaining reasonable precision. If the input tensor is not a floating-point type, the function returns the tensor unchanged.
**with_sharding_constraint**: This function applies a sharding constraint to the input tensor (x). Sharding is a technique used in distributed training to split the model parameters and computations across multiple devices or accelerators. The sharding constraint specifies how the tensor should be partitioned or reshaped for efficient parallel computations. The function takes two arguments: - `with_sharding_constraint`: This function applies a sharding constraint to the input tensor (x). Sharding is a technique used in distributed training to split the model parameters and computations across multiple devices or accelerators. The sharding constraint specifies how the tensor should be partitioned or reshaped for efficient parallel computations. The function takes two arguments:
x: The input tensor to be reshaped.
constraint: A PartitionSpec object that specifies the desired reshaping or partitioning scheme.
If the current environment does not support distributed training (i.e., jax.experimental.maps.thread_resources.env.physical_mesh.empty is True), the function returns the input tensor unchanged. Otherwise, it applies the specified sharding constraint to the tensor using the pjit_sharding_constraint function from JAX. - `x`: The input tensor to be reshaped.
- `constraint`: A PartitionSpec object that specifies the desired reshaping or partitioning scheme.
If the current environment does not support distributed training (i.e., jax.experimental.maps.thread_resources.env.physical_mesh.empty is True), the function returns the input tensor unchanged. Otherwise, it applies the specified sharding constraint to the tensor using the pjit_sharding_constraint function from JAX.
**match**: This is a helper function used by apply_rules. It takes two arguments: - `match`: This is a helper function used by `apply_rules`.
It takes two arguments:
`qs`: A tuple of regular expression patterns. `qs`: A tuple of regular expression patterns.
`ks`: A tuple of strings representing the flattened path. `ks`: A tuple of strings representing the flattened path.
The function compiles the regular expressions in qs and checks if any window (consecutive sublist) of strings in ks matches all the compiled regular expressions simultaneously. If a match is found, it returns True; otherwise, it returns False. This function is likely used by apply_rules to determine if a specific set of rules (regular expressions) should be applied to a given path in the neural network model. The function compiles the regular expressions in qs and checks if any window (consecutive sublist) of strings in ks matches all the compiled regular expressions simultaneously. If a match is found, it returns True; otherwise, it returns False. This function is likely used by apply_rules to determine if a specific set of rules (regular expressions) should be applied to a given path in the neural network model.
**init_layer_memories**: - `init_layer_memories`:
**hk_rms_norm**: - `hk_rms_norm`:
**make_attention_mask**: - `make_attention_mask`:
**rotate_half**: - `rotate_half`:
**layer_norm**: - `layer_norm`:
### TRANSFORMER_PARTITION_RULES: - ### `TRANSFORMER_PARTITION_RULES`:
`TRANSFORMER_PARTITION_RULES` is a list of tuples that define the partitioning rules for the parameters of a transformer model. These rules are used by the `apply_rules` function to reshape (reshard) the model parameters for efficient distributed training across multiple devices or accelerators. is a list of tuples that define the partitioning rules for the parameters of a transformer model. These rules are used by the `apply_rules` function to reshape (reshard) the model parameters for efficient distributed training across multiple devices or accelerators.
Each tuple in the TRANSFORMER_PARTITION_RULES list consists of two elements: Each tuple in the TRANSFORMER_PARTITION_RULES list consists of two elements:
1. A tuple of regular expression patterns that match the parameter names or paths in the model. 1. A tuple of regular expression patterns that match the parameter names or paths in the model.
2. A `PartitionSpec` object that specifies how the matched parameters should be partitioned or reshaped. 2. A `PartitionSpec` object that specifies how the matched parameters should be partitioned or reshaped.
let's dive in some of the partitioning rules defined in `TRANSFORMER_PARTITION_RULES`: let's dive in some of the partitioning rules defined in `TRANSFORMER_PARTITION_RULES`:
- #### `(("multi_head_attention", "(query|key|value)", "w"), P("data", "model"))`: - #### `(("multi_head_attention", "(query|key|value)", "w"), P("data", "model"))`:
This rule matches the weight tensors (w) of the query, key, and value projections in the multi-head attention module. It specifies that these weights should be partitioned along the "data" and "model" dimensions, which means they will be split across multiple devices or accelerators along those dimensions. This rule matches the weight tensors (w) of the query, key, and value projections in the multi-head attention module. It specifies that these weights should be partitioned along the "data" and "model" dimensions, which means they will be split across multiple devices or accelerators along those dimensions.
- #### `(("multi_head_attention", "(query|key|value)", "b"), P(None))`: - #### `(("multi_head_attention", "(query|key|value)", "b"), P(None))`:
This rule matches the bias tensors (b) of the query, key, and value projections in the multi-head attention module. It specifies that these biases should not be partitioned (indicated by P(None)), meaning they will be replicated across all devices. This rule matches the bias tensors (b) of the query, key, and value projections in the multi-head attention module. It specifies that these biases should not be partitioned (indicated by P(None)), meaning they will be replicated across all devices.
- #### `((r"decoder_layer_[0-9]+", "linear", "w"), P("data", "model"))`: - #### `((r"decoder_layer_[0-9]+", "linear", "w"), P("data", "model"))`:
This rule matches the weight tensors (w) of the linear projections in the decoder layers of the transformer model. The regular expression r"decoder_layer_[0-9]+" matches any parameter path containing "decoder_layer_" followed by a number. These weights are partitioned along the "data" and "model" dimensions. This rule matches the weight tensors (w) of the linear projections in the decoder layers of the transformer model. The regular expression r"decoder_layer_[0-9]+" matches any parameter path containing "decoder_layer_" followed by a number. These weights are partitioned along the "data" and "model" dimensions.
- #### `((r"decoder_layer_[0-9]+", "linear", "b"), P(None))`: - #### `((r"decoder_layer_[0-9]+", "linear", "b"), P(None))`:
Similar to the previous rule, but it matches the bias tensors (b) of the linear projections in the decoder layers, and these biases are not partitioned. Similar to the previous rule, but it matches the bias tensors (b) of the linear projections in the decoder layers, and these biases are not partitioned.
- Rules for partitioning the parameters of layer normalization (layer_norm, rms_norm) and router (router) modules are also included. - Rules for partitioning the parameters of layer normalization (layer_norm, rms_norm) and router (router) modules are also included.
- Rules for partitioning the parameters of the Mixture of Experts (MoE) module, including the `linear projections (linear, linear_v, linear_1)` and normalization layers `(layer_norm, rms_norm)`. - Rules for partitioning the parameters of the Mixture of Experts (MoE) module, including the `linear projections (linear, linear_v, linear_1)` and normalization layers `(layer_norm, rms_norm)`.
These partitioning rules aim to distribute the computationally intensive operations, such as matrix multiplications, across multiple devices, while replicating smaller tensors (like biases) to reduce communication overhead. These partitioning rules aim to distribute the computationally intensive operations, such as matrix multiplications, across multiple devices, while replicating smaller tensors (like biases) to reduce communication overhead.
By applying these partitioning rules, the model can take advantage of the combined memory and computational resources of multiple devices, enabling training of larger models or processing of larger batch sizes. By applying these partitioning rules, the model can take advantage of the combined memory and computational resources of multiple devices, enabling training of larger models or processing of larger batch sizes.
### LM_PARTITION_RULES - ### `LM_PARTITION_RULES`
`LM_PARTITION_RULES` is a list of tuples that define the partitioning rules for the parameters of the language model component in the codebase. These rules are used to specify how the parameters of the language model should be partitioned (reshaped) across multiple devices or accelerators for efficient distributed training. `LM_PARTITION_RULES` is a list of tuples that define the partitioning rules for the parameters of the language model component in the codebase. These rules are used to specify how the parameters of the language model should be partitioned (reshaped) across multiple devices or accelerators for efficient distributed training.
The `LM_PARTITION_RULES` list contains the following rules: The `LM_PARTITION_RULES` list contains the following rules:
- #### `(("language_model", "positional_embeddings"), P(None, ("data", "model")))`: - #### `(("language_model", "positional_embeddings"), P(None, ("data", "model")))`:
This rule matches the positional embeddings tensor in the language model module. The PartitionSpec `P(None, ("data", "model"))` specifies that this tensor should be partitioned along the "data" and "model" dimensions, but not partitioned along the leading dimension (represented by None). This means that the positional embeddings will be split across multiple devices along the "data" and "model" dimensions, but replicated along the leading dimension (e.g., batch dimension). This rule matches the positional embeddings tensor in the language model module. The PartitionSpec `P(None, ("data", "model"))` specifies that this tensor should be partitioned along the "data" and "model" dimensions, but not partitioned along the leading dimension (represented by None). This means that the positional embeddings will be split across multiple devices along the "data" and "model" dimensions, but replicated along the leading dimension (e.g., batch dimension).
- #### `(("language_model", "in_out_embed", "embeddings"), P(None, ("data", "model")))`: - #### `(("language_model", "in_out_embed", "embeddings"), P(None, ("data", "model")))`:
This rule matches the embeddings tensor of the InOutEmbed module (used for input and output embeddings) in the language model. Similar to the previous rule, it specifies that this tensor should be partitioned along the "data" and "model" dimensions, while being replicated along the leading dimension. This rule matches the embeddings tensor of the InOutEmbed module (used for input and output embeddings) in the language model. Similar to the previous rule, it specifies that this tensor should be partitioned along the "data" and "model" dimensions, while being replicated along the leading dimension.
- #### `(("language_model", "rms_norm"), P(None))`: - #### `(("language_model", "rms_norm"), P(None))`:
This rule matches the parameters of the RMSNorm layer in the language model. The PartitionSpec P(None) indicates that these parameters should not be partitioned at all and should be replicated across all devices. This rule matches the parameters of the RMSNorm layer in the language model. The PartitionSpec P(None) indicates that these parameters should not be partitioned at all and should be replicated across all devices.
By applying these partitioning rules, the language model's parameters are reshaped and distributed across multiple devices in a way that aims to balance the computational load and memory usage. The input and output embeddings, which are typically large tensors, are partitioned along the "data" and "model" dimensions to distribute their storage and computations. At the same time, smaller tensors like the normalization layer parameters are replicated across all devices to minimize communication overhead. By applying these partitioning rules, the language model's parameters are reshaped and distributed across multiple devices in a way that aims to balance the computational load and memory usage. The input and output embeddings, which are typically large tensors, are partitioned along the "data" and "model" dimensions to distribute their storage and computations. At the same time, smaller tensors like the normalization layer parameters are replicated across all devices to minimize communication overhead.
## :page_with_curl: QuantizedWeight8bit ## :page_with_curl: QuantizedWeight8bit
@ -138,9 +142,9 @@ The QuantizedWeight8bit class is a data structure that represents quantized weig
The QuantizedWeight8bit class has two main attributes: The QuantizedWeight8bit class has two main attributes:
**weight** : This is a NumPy array that holds the quantized weight values, represented as 8-bit integers. weight : This is a NumPy array that holds the quantized weight values, represented as 8-bit integers.
**scales** : This is a NumPy array that holds the scaling factors associated with each quantized weight. scales : This is a NumPy array that holds the scaling factors associated with each quantized weight.
During the model initialization or loading phase, the original 32-bit floating-point weights are quantized to 8-bit integers and packed into QuantizedWeight8bit instances. During the model initialization or loading phase, the original 32-bit floating-point weights are quantized to 8-bit integers and packed into QuantizedWeight8bit instances.
When performing computations in the neural network, such as linear transformations or convolutions, the quantized weights are used instead of the original 32-bit weights. This is done by applying the scaling factors stored in the scales attribute to recover approximate values of the original weights. When performing computations in the neural network, such as linear transformations or convolutions, the quantized weights are used instead of the original 32-bit weights. This is done by applying the scaling factors stored in the scales attribute to recover approximate values of the original weights.