mirror of
https://github.com/xai-org/grok-1.git
synced 2025-07-06 07:44:58 +03:00
Compare commits
2 Commits
a105de1c16
...
download-i
Author | SHA1 | Date | |
---|---|---|---|
48094e5c68 | |||
1dfcf10e9d |
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,2 +0,0 @@
|
||||
checkpoints/*
|
||||
!checkpoints/README.md
|
21
README.md
21
README.md
@ -2,7 +2,7 @@
|
||||
|
||||
This repository contains JAX example code for loading and running the Grok-1 open-weights model.
|
||||
|
||||
Make sure to download the checkpoint and place the `ckpt-0` directory in `checkpoints` - see [Downloading the weights](#downloading-the-weights)
|
||||
Make sure to download the checkpoint and place `ckpt-0` directory in `checkpoint` - see [Downloading the weights](Downloading-the-weights)
|
||||
|
||||
Then, run
|
||||
|
||||
@ -18,31 +18,14 @@ The script loads the checkpoint and samples from the model on a test input.
|
||||
Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code.
|
||||
The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model.
|
||||
|
||||
# Model Specifications
|
||||
|
||||
Grok-1 is currently designed with the following specifications:
|
||||
|
||||
- **Parameters:** 314B
|
||||
- **Architecture:** Mixture of 8 Experts (MoE)
|
||||
- **Experts Utilization:** 2 experts used per token
|
||||
- **Layers:** 64
|
||||
- **Attention Heads:** 48 for queries, 8 for keys/values
|
||||
- **Embedding Size:** 6,144
|
||||
- **Tokenization:** SentencePiece tokenizer with 131,072 tokens
|
||||
- **Additional Features:**
|
||||
- Rotary embeddings (RoPE)
|
||||
- Supports activation sharding and 8-bit quantization
|
||||
- **Maximum Sequence Length (context):** 8,192 tokens
|
||||
|
||||
# Downloading the weights
|
||||
|
||||
You can download the weights using a torrent client and this magnet link:
|
||||
|
||||
```
|
||||
magnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce
|
||||
```
|
||||
|
||||
or directly using [HuggingFace 🤗 Hub](https://huggingface.co/xai-org/grok-1):
|
||||
or directly using HuggingFace:
|
||||
```
|
||||
git clone https://github.com/xai-org/grok-1.git && cd grok-1
|
||||
pip install huggingface_hub[hf_transfer]
|
||||
|
@ -26,10 +26,6 @@ import tempfile
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from typing import Any, Optional
|
||||
|
||||
# For get_load_path_str
|
||||
# A simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
|
||||
from functools import lru_cache
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
from jax.experimental import multihost_utils
|
||||
@ -108,21 +104,7 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
|
||||
else:
|
||||
fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
|
||||
wait(fs)
|
||||
|
||||
# return [f.result() for f in fs]
|
||||
"""
|
||||
Improve error reporting in load_tensors by catching exceptions within the futures-
|
||||
and logging detailed information about the failure.
|
||||
"""
|
||||
results = []
|
||||
for future in fs:
|
||||
try:
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load tensor: {e}")
|
||||
raise
|
||||
return results
|
||||
return [f.result() for f in fs]
|
||||
|
||||
|
||||
def path_tuple_to_string(path: tuple) -> str:
|
||||
@ -137,22 +119,6 @@ def path_tuple_to_string(path: tuple) -> str:
|
||||
return "/".join(pieces)
|
||||
|
||||
|
||||
"""
|
||||
For get_load_path_str(),
|
||||
introducing a simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
|
||||
"""
|
||||
@lru_cache(maxsize=None)
|
||||
def get_load_path_str_cached(
|
||||
init_path_str: str,
|
||||
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
||||
load_exclude_rules: Optional[list[str]] = None,
|
||||
) -> Optional[str]:
|
||||
return get_load_path_str(
|
||||
init_path_str,
|
||||
load_rename_rules,
|
||||
load_exclude_rules
|
||||
)
|
||||
|
||||
def get_load_path_str(
|
||||
init_path_str: str,
|
||||
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
||||
@ -191,7 +157,7 @@ def replace_with_load_state(
|
||||
data_model_shards = math.prod(mesh_config)
|
||||
for i, (init_path, tensor) in enumerate(flatten_init):
|
||||
init_path_str = path_tuple_to_string(init_path)
|
||||
load_path_str = get_load_path_str_cached(init_path_str, load_rename_rules, load_exclude_rules)
|
||||
load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules)
|
||||
if load_path_str is None:
|
||||
rank_logger.info(f"Excluded from restore: {init_path_str}.")
|
||||
replaced.append(tensor)
|
||||
|
@ -1,4 +1,4 @@
|
||||
dm_haiku==0.0.12
|
||||
jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
jax[cuda12_pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
numpy==1.26.4
|
||||
sentencepiece==0.2.0
|
||||
|
Reference in New Issue
Block a user