mirror of
https://github.com/xai-org/grok-1.git
synced 2025-07-07 08:15:00 +03:00
Compare commits
2 Commits
0b0576a0a8
...
download-i
Author | SHA1 | Date | |
---|---|---|---|
48094e5c68 | |||
1dfcf10e9d |
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,2 +0,0 @@
|
|||||||
checkpoints/*
|
|
||||||
!checkpoints/README.md
|
|
21
README.md
21
README.md
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
This repository contains JAX example code for loading and running the Grok-1 open-weights model.
|
This repository contains JAX example code for loading and running the Grok-1 open-weights model.
|
||||||
|
|
||||||
Make sure to download the checkpoint and place the `ckpt-0` directory in `checkpoints` - see [Downloading the weights](#downloading-the-weights)
|
Make sure to download the checkpoint and place `ckpt-0` directory in `checkpoint` - see [Downloading the weights](Downloading-the-weights)
|
||||||
|
|
||||||
Then, run
|
Then, run
|
||||||
|
|
||||||
@ -18,31 +18,14 @@ The script loads the checkpoint and samples from the model on a test input.
|
|||||||
Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code.
|
Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code.
|
||||||
The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model.
|
The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model.
|
||||||
|
|
||||||
# Model Specifications
|
|
||||||
|
|
||||||
Grok-1 is currently designed with the following specifications:
|
|
||||||
|
|
||||||
- **Parameters:** 314B
|
|
||||||
- **Architecture:** Mixture of 8 Experts (MoE)
|
|
||||||
- **Experts Utilization:** 2 experts used per token
|
|
||||||
- **Layers:** 64
|
|
||||||
- **Attention Heads:** 48 for queries, 8 for keys/values
|
|
||||||
- **Embedding Size:** 6,144
|
|
||||||
- **Tokenization:** SentencePiece tokenizer with 131,072 tokens
|
|
||||||
- **Additional Features:**
|
|
||||||
- Rotary embeddings (RoPE)
|
|
||||||
- Supports activation sharding and 8-bit quantization
|
|
||||||
- **Maximum Sequence Length (context):** 8,192 tokens
|
|
||||||
|
|
||||||
# Downloading the weights
|
# Downloading the weights
|
||||||
|
|
||||||
You can download the weights using a torrent client and this magnet link:
|
You can download the weights using a torrent client and this magnet link:
|
||||||
|
|
||||||
```
|
```
|
||||||
magnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce
|
magnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce
|
||||||
```
|
```
|
||||||
|
|
||||||
or directly using [HuggingFace 🤗 Hub](https://huggingface.co/xai-org/grok-1):
|
or directly using HuggingFace:
|
||||||
```
|
```
|
||||||
git clone https://github.com/xai-org/grok-1.git && cd grok-1
|
git clone https://github.com/xai-org/grok-1.git && cd grok-1
|
||||||
pip install huggingface_hub[hf_transfer]
|
pip install huggingface_hub[hf_transfer]
|
||||||
|
@ -26,10 +26,6 @@ import tempfile
|
|||||||
from concurrent.futures import ThreadPoolExecutor, wait
|
from concurrent.futures import ThreadPoolExecutor, wait
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
# For get_load_path_str
|
|
||||||
# A simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
|
|
||||||
from functools import lru_cache
|
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from jax.experimental import multihost_utils
|
from jax.experimental import multihost_utils
|
||||||
@ -108,21 +104,7 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
|
|||||||
else:
|
else:
|
||||||
fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
|
fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
|
||||||
wait(fs)
|
wait(fs)
|
||||||
|
return [f.result() for f in fs]
|
||||||
# return [f.result() for f in fs]
|
|
||||||
"""
|
|
||||||
Improve error reporting in load_tensors by catching exceptions raised by the futures
|
|
||||||
and logging detailed information about the failure.
|
|
||||||
"""
|
|
||||||
results = []
|
|
||||||
for future in fs:
|
|
||||||
try:
|
|
||||||
result = future.result()
|
|
||||||
results.append(result)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to load tensor: {e}")
|
|
||||||
raise
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def path_tuple_to_string(path: tuple) -> str:
|
def path_tuple_to_string(path: tuple) -> str:
|
||||||
@ -137,22 +119,6 @@ def path_tuple_to_string(path: tuple) -> str:
|
|||||||
return "/".join(pieces)
|
return "/".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
Cached variant of get_load_path_str():
|
|
||||||
a simple caching mechanism that avoids recomputing regex matches for paths that have already been processed.
|
|
||||||
"""
|
|
||||||
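@lru_cache(maxsize=32) # maxsize is an entry count (up to 32 cached results), not a size in MB. NOTE(review): lru_cache needs hashable arguments, so list-valued rules would raise TypeError - confirm callers pass tuples or None.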
|
|
||||||
def get_load_path_str_cached(
|
|
||||||
init_path_str: str,
|
|
||||||
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
|
||||||
load_exclude_rules: Optional[list[str]] = None,
|
|
||||||
) -> Optional[str]:
|
|
||||||
return get_load_path_str(
|
|
||||||
init_path_str,
|
|
||||||
load_rename_rules,
|
|
||||||
load_exclude_rules
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_load_path_str(
|
def get_load_path_str(
|
||||||
init_path_str: str,
|
init_path_str: str,
|
||||||
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
||||||
@ -191,7 +157,7 @@ def replace_with_load_state(
|
|||||||
data_model_shards = math.prod(mesh_config)
|
data_model_shards = math.prod(mesh_config)
|
||||||
for i, (init_path, tensor) in enumerate(flatten_init):
|
for i, (init_path, tensor) in enumerate(flatten_init):
|
||||||
init_path_str = path_tuple_to_string(init_path)
|
init_path_str = path_tuple_to_string(init_path)
|
||||||
load_path_str = get_load_path_str_cached(init_path_str, load_rename_rules, load_exclude_rules)
|
load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules)
|
||||||
if load_path_str is None:
|
if load_path_str is None:
|
||||||
rank_logger.info(f"Excluded from restore: {init_path_str}.")
|
rank_logger.info(f"Excluded from restore: {init_path_str}.")
|
||||||
replaced.append(tensor)
|
replaced.append(tensor)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
dm_haiku==0.0.12
|
dm_haiku==0.0.12
|
||||||
jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
jax[cuda12_pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||||
numpy==1.26.4
|
numpy==1.26.4
|
||||||
sentencepiece==0.2.0
|
sentencepiece==0.2.0
|
||||||
|
Reference in New Issue
Block a user