Compare commits

..

2 Commits

Author SHA1 Message Date
48094e5c68 Update instruction 2024-03-18 09:09:40 -07:00
1dfcf10e9d Add detailed download instruction 2024-03-18 09:08:14 -07:00
3 changed files with 25 additions and 39 deletions

2
.gitignore vendored
View File

@ -1,2 +0,0 @@
checkpoints/*
!checkpoints/README.md

View File

@ -2,7 +2,7 @@
This repository contains JAX example code for loading and running the Grok-1 open-weights model. This repository contains JAX example code for loading and running the Grok-1 open-weights model.
Make sure to download the checkpoint and place the `ckpt-0` directory in `checkpoints` - see [Downloading the weights](#downloading-the-weights) Make sure to download the checkpoint and place `ckpt-0` directory in `checkpoint` - see [Downloading the weights](Downloading-the-weights)
Then, run Then, run
@ -18,31 +18,14 @@ The script loads the checkpoint and samples from the model on a test input.
Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code. Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code.
The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model. The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model.
# Model Specifications
Grok-1 is currently designed with the following specifications:
- **Parameters:** 314B
- **Architecture:** Mixture of 8 Experts (MoE)
- **Experts Utilization:** 2 experts used per token
- **Layers:** 64
- **Attention Heads:** 48 for queries, 8 for keys/values
- **Embedding Size:** 6,144
- **Tokenization:** SentencePiece tokenizer with 131,072 tokens
- **Additional Features:**
- Rotary embeddings (RoPE)
- Supports activation sharding and 8-bit quantization
- **Maximum Sequence Length (context):** 8,192 tokens
# Downloading the weights # Downloading the weights
You can download the weights using a torrent client and this magnet link: You can download the weights using a torrent client and this magnet link:
``` ```
magnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce magnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce
``` ```
or directly using [HuggingFace 🤗 Hub](https://huggingface.co/xai-org/grok-1): or directly using HuggingFace:
``` ```
git clone https://github.com/xai-org/grok-1.git && cd grok-1 git clone https://github.com/xai-org/grok-1.git && cd grok-1
pip install huggingface_hub[hf_transfer] pip install huggingface_hub[hf_transfer]

View File

@ -1002,7 +1002,7 @@ class DenseBlock(hk.Module):
sharding=P("model", "data"), sharding=P("model", "data"),
mesh=self.mesh, mesh=self.mesh,
shard_axis=1, shard_axis=1,
)(h_w1 * h_v) # TODO: Document why this isn't sequential and whether it should be. )(h_w1 * h_v)
return h_dense return h_dense
@ -1036,10 +1036,13 @@ class DecoderLayer(hk.Module):
) -> DecoderOutput: ) -> DecoderOutput:
"""Transforms input embedding sequences to output embedding sequences.""" """Transforms input embedding sequences to output embedding sequences."""
sharding = P(self.data_axis, None) def layer_norm(x):
return hk_rms_norm(x)
if self.shard_activations: if self.shard_activations:
sharding = P(self.data_axis, None, self.model_axis) sharding = P(self.data_axis, None, self.model_axis)
else:
sharding = P(self.data_axis, None)
h = with_sharding_constraint(inputs, sharding) h = with_sharding_constraint(inputs, sharding)
attn_output = MHABlock( attn_output = MHABlock(
@ -1051,8 +1054,8 @@ class DecoderLayer(hk.Module):
data_axis=self.data_axis, data_axis=self.data_axis,
model_axis=self.model_axis, model_axis=self.model_axis,
)(layer_norm(h), mask, layer_memory) )(layer_norm(h), mask, layer_memory)
h_attn = attn_output.embeddings h_attn = attn_output.embeddings
h_attn = layer_norm(h_attn) h_attn = layer_norm(h_attn)
h += h_attn h += h_attn
h = with_sharding_constraint(h, sharding) h = with_sharding_constraint(h, sharding)
@ -1162,17 +1165,15 @@ class LanguageModelConfig:
_initialized = False _initialized = False
def initialize(self): def initialize(self):
# We cannot specify [] as a default value (it is mutable), hence None.
model_config = self.model model_config = self.model
assert self.init_scale_override is None, (
if self.model is None: # We cannot specify [] as a default value (it is mutable), hence None. "Overriding model initialize scale is supported only for predefined models."
raise ValueError("Model configuration is not set.") )
if self.init_scale_override is not None:
raise ValueError("Overriding model initialize scale is supported only for predefined models.")
if self.model_size == 0: if self.model_size == 0:
self.model_size = model_config.emb_size self.model_size = model_config.emb_size
assert self.model is not None, "Model could not be initialized."
self._initialized = True self._initialized = True
return self return self
def make(self, *args, **kwargs): def make(self, *args, **kwargs):
@ -1193,7 +1194,7 @@ class LanguageModelConfig:
return LM_PARTITION_RULES + self.model.partition_rules() return LM_PARTITION_RULES + self.model.partition_rules()
def layer_norm(x): def layer_norm(x, model):
return hk_rms_norm(x) return hk_rms_norm(x)
@ -1212,12 +1213,17 @@ class LanguageModel(hk.Module):
tokens: jax.Array, tokens: jax.Array,
memory: Optional[Memory] = None, memory: Optional[Memory] = None,
*, *,
batch: Dict[str, jax.Array] = {},
last_hid_only: bool = False, last_hid_only: bool = False,
length: Optional[jax.Array] = None, length: Optional[jax.Array] = None,
) -> LanguageModelOutput: ) -> LanguageModelOutput:
"""Forward pass, producing a sequence of logits.""" """Forward pass, producing a sequence of logits."""
del batch # Unused.
config = self.config config = self.config
input_mask = jnp.greater(tokens, config.pad_token)
# Embed the input tokens and positions. # Embed the input tokens and positions.
in_out_embed = InOutEmbed( in_out_embed = InOutEmbed(
self.config.vocab_size, self.config.vocab_size,
@ -1229,7 +1235,6 @@ class LanguageModel(hk.Module):
input_embeddings, P("data", None, self.model.model_axis) input_embeddings, P("data", None, self.model.model_axis)
) )
input_embeddings *= config.embedding_multiplier_scale input_embeddings *= config.embedding_multiplier_scale
input_mask = jnp.not_equal(tokens, config.pad_token)
model_output = self.model( model_output = self.model(
input_embeddings, input_embeddings,
@ -1237,15 +1242,15 @@ class LanguageModel(hk.Module):
memory=memory, memory=memory,
) # [B, T, D] ) # [B, T, D]
embeddings, model_state = model_output.embeddings, model_output.memory embeddings, model_state = model_output.embeddings, model_output.memory
if embeddings.dtype != self.fprop_dtype:
raise ValueError(f"Expected forward propagation dtype {self.fprop_dtype} but got {embeddings.dtype} in embeddings.")
if self.model.shard_activations: if self.model.shard_activations:
embeddings = with_sharding_constraint(embeddings, P("data", None, self.model.model_axis)) embeddings = with_sharding_constraint(
embeddings, P("data", None, self.model.model_axis)
)
else: else:
embeddings = with_sharding_constraint(embeddings, P("data", None)) embeddings = with_sharding_constraint(embeddings, P("data", None))
rank_logger.debug(f"Final embedding shape: {embeddings.shape}") rank_logger.debug(f"Final embedding shape: {embeddings.shape}")
embeddings = layer_norm(embeddings) embeddings = layer_norm(embeddings, self.model)
assert embeddings.dtype == self.fprop_dtype
if last_hid_only: if last_hid_only:
last_step = jnp.maximum(jnp.sum(input_mask.astype(jnp.int32), axis=1) - 1, 0) last_step = jnp.maximum(jnp.sum(input_mask.astype(jnp.int32), axis=1) - 1, 0)