mirror of
https://github.com/xai-org/grok-1.git
synced 2025-09-21 18:59:23 +03:00
Compare commits
8 Commits
download-i
...
03df01959d
Author | SHA1 | Date | |
---|---|---|---|
03df01959d | |||
7050ed204b | |||
1a0ba385eb | |||
d6d9447e2d | |||
7207216386 | |||
310e19eee2 | |||
1ff4435d25 | |||
b0e77734fe |
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
checkpoints/*
|
||||||
|
!checkpoints/README.md
|
21
README.md
21
README.md
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
This repository contains JAX example code for loading and running the Grok-1 open-weights model.
|
This repository contains JAX example code for loading and running the Grok-1 open-weights model.
|
||||||
|
|
||||||
Make sure to download the checkpoint and place `ckpt-0` directory in `checkpoint` - see [Downloading the weights](Downloading-the-weights)
|
Make sure to download the checkpoint and place the `ckpt-0` directory in `checkpoints` - see [Downloading the weights](#downloading-the-weights)
|
||||||
|
|
||||||
Then, run
|
Then, run
|
||||||
|
|
||||||
@ -18,14 +18,31 @@ The script loads the checkpoint and samples from the model on a test input.
|
|||||||
Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code.
|
Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code.
|
||||||
The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model.
|
The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model.
|
||||||
|
|
||||||
|
# Model Specifications
|
||||||
|
|
||||||
|
Grok-1 is currently designed with the following specifications:
|
||||||
|
|
||||||
|
- **Parameters:** 314B
|
||||||
|
- **Architecture:** Mixture of 8 Experts (MoE)
|
||||||
|
- **Experts Utilization:** 2 experts used per token
|
||||||
|
- **Layers:** 64
|
||||||
|
- **Attention Heads:** 48 for queries, 8 for keys/values
|
||||||
|
- **Embedding Size:** 6,144
|
||||||
|
- **Tokenization:** SentencePiece tokenizer with 131,072 tokens
|
||||||
|
- **Additional Features:**
|
||||||
|
- Rotary embeddings (RoPE)
|
||||||
|
- Supports activation sharding and 8-bit quantization
|
||||||
|
- **Maximum Sequence Length (context):** 8,192 tokens
|
||||||
|
|
||||||
# Downloading the weights
|
# Downloading the weights
|
||||||
|
|
||||||
You can download the weights using a torrent client and this magnet link:
|
You can download the weights using a torrent client and this magnet link:
|
||||||
|
|
||||||
```
|
```
|
||||||
magnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce
|
magnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce
|
||||||
```
|
```
|
||||||
|
|
||||||
or directly using HuggingFace:
|
or directly using [HuggingFace 🤗 Hub](https://huggingface.co/xai-org/grok-1):
|
||||||
```
|
```
|
||||||
git clone https://github.com/xai-org/grok-1.git && cd grok-1
|
git clone https://github.com/xai-org/grok-1.git && cd grok-1
|
||||||
pip install huggingface_hub[hf_transfer]
|
pip install huggingface_hub[hf_transfer]
|
||||||
|
41
model.py
41
model.py
@ -1002,7 +1002,7 @@ class DenseBlock(hk.Module):
|
|||||||
sharding=P("model", "data"),
|
sharding=P("model", "data"),
|
||||||
mesh=self.mesh,
|
mesh=self.mesh,
|
||||||
shard_axis=1,
|
shard_axis=1,
|
||||||
)(h_w1 * h_v)
|
)(h_w1 * h_v) # TODO: Document why this isn't sequential and whether it should be.
|
||||||
|
|
||||||
return h_dense
|
return h_dense
|
||||||
|
|
||||||
@ -1036,13 +1036,10 @@ class DecoderLayer(hk.Module):
|
|||||||
) -> DecoderOutput:
|
) -> DecoderOutput:
|
||||||
"""Transforms input embedding sequences to output embedding sequences."""
|
"""Transforms input embedding sequences to output embedding sequences."""
|
||||||
|
|
||||||
def layer_norm(x):
|
sharding = P(self.data_axis, None)
|
||||||
return hk_rms_norm(x)
|
|
||||||
|
|
||||||
if self.shard_activations:
|
if self.shard_activations:
|
||||||
sharding = P(self.data_axis, None, self.model_axis)
|
sharding = P(self.data_axis, None, self.model_axis)
|
||||||
else:
|
|
||||||
sharding = P(self.data_axis, None)
|
|
||||||
h = with_sharding_constraint(inputs, sharding)
|
h = with_sharding_constraint(inputs, sharding)
|
||||||
|
|
||||||
attn_output = MHABlock(
|
attn_output = MHABlock(
|
||||||
@ -1054,8 +1051,8 @@ class DecoderLayer(hk.Module):
|
|||||||
data_axis=self.data_axis,
|
data_axis=self.data_axis,
|
||||||
model_axis=self.model_axis,
|
model_axis=self.model_axis,
|
||||||
)(layer_norm(h), mask, layer_memory)
|
)(layer_norm(h), mask, layer_memory)
|
||||||
h_attn = attn_output.embeddings
|
|
||||||
|
|
||||||
|
h_attn = attn_output.embeddings
|
||||||
h_attn = layer_norm(h_attn)
|
h_attn = layer_norm(h_attn)
|
||||||
h += h_attn
|
h += h_attn
|
||||||
h = with_sharding_constraint(h, sharding)
|
h = with_sharding_constraint(h, sharding)
|
||||||
@ -1165,15 +1162,17 @@ class LanguageModelConfig:
|
|||||||
_initialized = False
|
_initialized = False
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
# We cannot specify [] as a default value (it is mutable), hence None.
|
|
||||||
model_config = self.model
|
model_config = self.model
|
||||||
assert self.init_scale_override is None, (
|
|
||||||
"Overriding model initialize scale is supported only for predefined models."
|
if self.model is None: # We cannot specify [] as a default value (it is mutable), hence None.
|
||||||
)
|
raise ValueError("Model configuration is not set.")
|
||||||
|
if self.init_scale_override is not None:
|
||||||
|
raise ValueError("Overriding model initialize scale is supported only for predefined models.")
|
||||||
|
|
||||||
if self.model_size == 0:
|
if self.model_size == 0:
|
||||||
self.model_size = model_config.emb_size
|
self.model_size = model_config.emb_size
|
||||||
assert self.model is not None, "Model could not be initialized."
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def make(self, *args, **kwargs):
|
def make(self, *args, **kwargs):
|
||||||
@ -1194,7 +1193,7 @@ class LanguageModelConfig:
|
|||||||
return LM_PARTITION_RULES + self.model.partition_rules()
|
return LM_PARTITION_RULES + self.model.partition_rules()
|
||||||
|
|
||||||
|
|
||||||
def layer_norm(x, model):
|
def layer_norm(x):
|
||||||
return hk_rms_norm(x)
|
return hk_rms_norm(x)
|
||||||
|
|
||||||
|
|
||||||
@ -1213,17 +1212,12 @@ class LanguageModel(hk.Module):
|
|||||||
tokens: jax.Array,
|
tokens: jax.Array,
|
||||||
memory: Optional[Memory] = None,
|
memory: Optional[Memory] = None,
|
||||||
*,
|
*,
|
||||||
batch: Dict[str, jax.Array] = {},
|
|
||||||
last_hid_only: bool = False,
|
last_hid_only: bool = False,
|
||||||
length: Optional[jax.Array] = None,
|
length: Optional[jax.Array] = None,
|
||||||
) -> LanguageModelOutput:
|
) -> LanguageModelOutput:
|
||||||
"""Forward pass, producing a sequence of logits."""
|
"""Forward pass, producing a sequence of logits."""
|
||||||
del batch # Unused.
|
|
||||||
|
|
||||||
config = self.config
|
config = self.config
|
||||||
|
|
||||||
input_mask = jnp.greater(tokens, config.pad_token)
|
|
||||||
|
|
||||||
# Embed the input tokens and positions.
|
# Embed the input tokens and positions.
|
||||||
in_out_embed = InOutEmbed(
|
in_out_embed = InOutEmbed(
|
||||||
self.config.vocab_size,
|
self.config.vocab_size,
|
||||||
@ -1235,6 +1229,7 @@ class LanguageModel(hk.Module):
|
|||||||
input_embeddings, P("data", None, self.model.model_axis)
|
input_embeddings, P("data", None, self.model.model_axis)
|
||||||
)
|
)
|
||||||
input_embeddings *= config.embedding_multiplier_scale
|
input_embeddings *= config.embedding_multiplier_scale
|
||||||
|
input_mask = jnp.not_equal(tokens, config.pad_token)
|
||||||
|
|
||||||
model_output = self.model(
|
model_output = self.model(
|
||||||
input_embeddings,
|
input_embeddings,
|
||||||
@ -1242,15 +1237,15 @@ class LanguageModel(hk.Module):
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
) # [B, T, D]
|
) # [B, T, D]
|
||||||
embeddings, model_state = model_output.embeddings, model_output.memory
|
embeddings, model_state = model_output.embeddings, model_output.memory
|
||||||
|
if embeddings.dtype != self.fprop_dtype:
|
||||||
|
raise ValueError(f"Expected forward propagation dtype {self.fprop_dtype} but got {embeddings.dtype} in embeddings.")
|
||||||
|
|
||||||
if self.model.shard_activations:
|
if self.model.shard_activations:
|
||||||
embeddings = with_sharding_constraint(
|
embeddings = with_sharding_constraint(embeddings, P("data", None, self.model.model_axis))
|
||||||
embeddings, P("data", None, self.model.model_axis)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
embeddings = with_sharding_constraint(embeddings, P("data", None))
|
embeddings = with_sharding_constraint(embeddings, P("data", None))
|
||||||
rank_logger.debug(f"Final embedding shape: {embeddings.shape}")
|
rank_logger.debug(f"Final embedding shape: {embeddings.shape}")
|
||||||
embeddings = layer_norm(embeddings, self.model)
|
embeddings = layer_norm(embeddings)
|
||||||
assert embeddings.dtype == self.fprop_dtype
|
|
||||||
|
|
||||||
if last_hid_only:
|
if last_hid_only:
|
||||||
last_step = jnp.maximum(jnp.sum(input_mask.astype(jnp.int32), axis=1) - 1, 0)
|
last_step = jnp.maximum(jnp.sum(input_mask.astype(jnp.int32), axis=1) - 1, 0)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
dm_haiku==0.0.12
|
dm_haiku==0.0.12
|
||||||
jax[cuda12_pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||||
numpy==1.26.4
|
numpy==1.26.4
|
||||||
sentencepiece==0.2.0
|
sentencepiece==0.2.0
|
||||||
|
Reference in New Issue
Block a user