mirror of
https://github.com/xai-org/grok-1.git
synced 2025-07-07 00:04:59 +03:00
Compare commits
10 Commits
download-i
...
0b0576a0a8
Author | SHA1 | Date | |
---|---|---|---|
0b0576a0a8 | |||
104affb51a | |||
5fc82399bf | |||
3fd4e7c4d7 | |||
7050ed204b | |||
d6d9447e2d | |||
7207216386 | |||
310e19eee2 | |||
1ff4435d25 | |||
b0e77734fe |
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
checkpoints/*
|
||||||
|
!checkpoints/README.md
|
27
README.md
27
README.md
@ -2,7 +2,8 @@
|
|||||||
|
|
||||||
This repository contains JAX example code for loading and running the Grok-1 open-weights model.
|
This repository contains JAX example code for loading and running the Grok-1 open-weights model.
|
||||||
|
|
||||||
Make sure to download the checkpoint and place `ckpt-0` directory in `checkpoint`.
|
Make sure to download the checkpoint and place the `ckpt-0` directory in `checkpoints` - see [Downloading the weights](#downloading-the-weights)
|
||||||
|
|
||||||
Then, run
|
Then, run
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
@ -17,13 +18,37 @@ The script loads the checkpoint and samples from the model on a test input.
|
|||||||
Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code.
|
Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code.
|
||||||
The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model.
|
The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model.
|
||||||
|
|
||||||
|
# Model Specifications
|
||||||
|
|
||||||
|
Grok-1 is currently designed with the following specifications:
|
||||||
|
|
||||||
|
- **Parameters:** 314B
|
||||||
|
- **Architecture:** Mixture of 8 Experts (MoE)
|
||||||
|
- **Experts Utilization:** 2 experts used per token
|
||||||
|
- **Layers:** 64
|
||||||
|
- **Attention Heads:** 48 for queries, 8 for keys/values
|
||||||
|
- **Embedding Size:** 6,144
|
||||||
|
- **Tokenization:** SentencePiece tokenizer with 131,072 tokens
|
||||||
|
- **Additional Features:**
|
||||||
|
- Rotary embeddings (RoPE)
|
||||||
|
- Supports activation sharding and 8-bit quantization
|
||||||
|
- **Maximum Sequence Length (context):** 8,192 tokens
|
||||||
|
|
||||||
# Downloading the weights
|
# Downloading the weights
|
||||||
|
|
||||||
You can download the weights using a torrent client and this magnet link:
|
You can download the weights using a torrent client and this magnet link:
|
||||||
|
|
||||||
```
|
```
|
||||||
magnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce
|
magnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce
|
||||||
```
|
```
|
||||||
|
|
||||||
|
or directly using [HuggingFace 🤗 Hub](https://huggingface.co/xai-org/grok-1):
|
||||||
|
```
|
||||||
|
git clone https://github.com/xai-org/grok-1.git && cd grok-1
|
||||||
|
pip install huggingface_hub[hf_transfer]
|
||||||
|
huggingface-cli download xai-org/grok-1 --repo-type model --include ckpt-0/* --local-dir checkpoints --local-dir-use-symlinks False
|
||||||
|
```
|
||||||
|
|
||||||
# License
|
# License
|
||||||
|
|
||||||
The code and associated Grok-1 weights in this release are licensed under the
|
The code and associated Grok-1 weights in this release are licensed under the
|
||||||
|
@ -26,6 +26,10 @@ import tempfile
|
|||||||
from concurrent.futures import ThreadPoolExecutor, wait
|
from concurrent.futures import ThreadPoolExecutor, wait
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
# For get_load_path_str
|
||||||
|
# A simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from jax.experimental import multihost_utils
|
from jax.experimental import multihost_utils
|
||||||
@ -104,7 +108,21 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
|
|||||||
else:
|
else:
|
||||||
fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
|
fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
|
||||||
wait(fs)
|
wait(fs)
|
||||||
return [f.result() for f in fs]
|
|
||||||
|
# return [f.result() for f in fs]
|
||||||
|
"""
|
||||||
|
Improve error reporting in load_tensors by catching exceptions within the futures-
|
||||||
|
and logging detailed information about the failure.
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for future in fs:
|
||||||
|
try:
|
||||||
|
result = future.result()
|
||||||
|
results.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load tensor: {e}")
|
||||||
|
raise
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
def path_tuple_to_string(path: tuple) -> str:
|
def path_tuple_to_string(path: tuple) -> str:
|
||||||
@ -119,6 +137,22 @@ def path_tuple_to_string(path: tuple) -> str:
|
|||||||
return "/".join(pieces)
|
return "/".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
For get_load_path_str(),
|
||||||
|
introducing a simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
|
||||||
|
"""
|
||||||
|
@lru_cache(maxsize=32) # Set maxsize to 32MB
|
||||||
|
def get_load_path_str_cached(
|
||||||
|
init_path_str: str,
|
||||||
|
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
||||||
|
load_exclude_rules: Optional[list[str]] = None,
|
||||||
|
) -> Optional[str]:
|
||||||
|
return get_load_path_str(
|
||||||
|
init_path_str,
|
||||||
|
load_rename_rules,
|
||||||
|
load_exclude_rules
|
||||||
|
)
|
||||||
|
|
||||||
def get_load_path_str(
|
def get_load_path_str(
|
||||||
init_path_str: str,
|
init_path_str: str,
|
||||||
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
||||||
@ -157,7 +191,7 @@ def replace_with_load_state(
|
|||||||
data_model_shards = math.prod(mesh_config)
|
data_model_shards = math.prod(mesh_config)
|
||||||
for i, (init_path, tensor) in enumerate(flatten_init):
|
for i, (init_path, tensor) in enumerate(flatten_init):
|
||||||
init_path_str = path_tuple_to_string(init_path)
|
init_path_str = path_tuple_to_string(init_path)
|
||||||
load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules)
|
load_path_str = get_load_path_str_cached(init_path_str, load_rename_rules, load_exclude_rules)
|
||||||
if load_path_str is None:
|
if load_path_str is None:
|
||||||
rank_logger.info(f"Excluded from restore: {init_path_str}.")
|
rank_logger.info(f"Excluded from restore: {init_path_str}.")
|
||||||
replaced.append(tensor)
|
replaced.append(tensor)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
dm_haiku==0.0.12
|
dm_haiku==0.0.12
|
||||||
jax[cuda12_pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||||
numpy==1.26.4
|
numpy==1.26.4
|
||||||
sentencepiece==0.2.0
|
sentencepiece==0.2.0
|
||||||
|
Reference in New Issue
Block a user