Compare commits

..

10 Commits

Author SHA1 Message Date
0b0576a0a8 Merge 104affb51a into 7050ed204b 2024-03-21 06:48:52 +05:30
104affb51a Set maxsize to 32MB in lru_cache() for reasonable memory usage 2024-03-21 06:48:22 +05:30
5fc82399bf get_load_path_str() -> get_load_path_str_cached(): Optimizing Regex Operations with Caching 2024-03-19 22:06:40 +05:30
3fd4e7c4d7 Enhanced Error Handling in load_tensors() 2024-03-19 21:58:37 +05:30
7050ed204b Corrected name of package "cuda12-pip" (#194)
The `cuda12-pip` package was wrongly named `cuda12_pip`
in requirements.txt
2024-03-19 08:48:22 -07:00
d6d9447e2d Update huggingface link 2024-03-18 11:40:01 -07:00
7207216386 Create .gitignore for checkpoints (#149)
ignore the checkpoints files
2024-03-18 11:01:17 -07:00
310e19eee2 Corrected checkpoint dir name, download section link 2024-03-18 09:39:02 -07:00
1ff4435d25 Update README with Model Specifications (#27)
Added an overview of the model as discussed in response to #14. 

Adding more info on the the model specs before they proceed to download
the checkpoints should help folks ensure they have the necessary
resources to effectively utilize Grok-1.
2024-03-18 09:36:24 -07:00
b0e77734fe Make download instruction more clear (#155) 2024-03-18 09:11:17 -07:00
4 changed files with 58 additions and 5 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
checkpoints/*
!checkpoints/README.md

View File

@ -2,7 +2,7 @@
This repository contains JAX example code for loading and running the Grok-1 open-weights model.
Make sure to download the checkpoint and place `ckpt-0` directory in `checkpoint` - see [Downloading the weights](Downloading-the-weights)
Make sure to download the checkpoint and place the `ckpt-0` directory in `checkpoints` - see [Downloading the weights](#downloading-the-weights)
Then, run
@ -18,14 +18,31 @@ The script loads the checkpoint and samples from the model on a test input.
Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code.
The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model.
# Model Specifications
Grok-1 is currently designed with the following specifications:
- **Parameters:** 314B
- **Architecture:** Mixture of 8 Experts (MoE)
- **Experts Utilization:** 2 experts used per token
- **Layers:** 64
- **Attention Heads:** 48 for queries, 8 for keys/values
- **Embedding Size:** 6,144
- **Tokenization:** SentencePiece tokenizer with 131,072 tokens
- **Additional Features:**
- Rotary embeddings (RoPE)
- Supports activation sharding and 8-bit quantization
- **Maximum Sequence Length (context):** 8,192 tokens
# Downloading the weights
You can download the weights using a torrent client and this magnet link:
```
magnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce
```
or directly using HuggingFace:
or directly using [HuggingFace 🤗 Hub](https://huggingface.co/xai-org/grok-1):
```
git clone https://github.com/xai-org/grok-1.git && cd grok-1
pip install huggingface_hub[hf_transfer]

View File

@ -26,6 +26,10 @@ import tempfile
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Optional
# For get_load_path_str
# A simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
from functools import lru_cache
import jax
import numpy as np
from jax.experimental import multihost_utils
@ -104,7 +108,21 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
else:
fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
wait(fs)
return [f.result() for f in fs]
# return [f.result() for f in fs]
"""
Improve error reporting in load_tensors by catching exceptions within the futures-
and logging detailed information about the failure.
"""
results = []
for future in fs:
try:
result = future.result()
results.append(result)
except Exception as e:
logger.error(f"Failed to load tensor: {e}")
raise
return results
def path_tuple_to_string(path: tuple) -> str:
@ -119,6 +137,22 @@ def path_tuple_to_string(path: tuple) -> str:
return "/".join(pieces)
"""
For get_load_path_str(),
introducing a simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
"""
@lru_cache(maxsize=32) # Set maxsize to 32MB
def get_load_path_str_cached(
init_path_str: str,
load_rename_rules: Optional[list[tuple[str, str]]] = None,
load_exclude_rules: Optional[list[str]] = None,
) -> Optional[str]:
return get_load_path_str(
init_path_str,
load_rename_rules,
load_exclude_rules
)
def get_load_path_str(
init_path_str: str,
load_rename_rules: Optional[list[tuple[str, str]]] = None,
@ -157,7 +191,7 @@ def replace_with_load_state(
data_model_shards = math.prod(mesh_config)
for i, (init_path, tensor) in enumerate(flatten_init):
init_path_str = path_tuple_to_string(init_path)
load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules)
load_path_str = get_load_path_str_cached(init_path_str, load_rename_rules, load_exclude_rules)
if load_path_str is None:
rank_logger.info(f"Excluded from restore: {init_path_str}.")
replaced.append(tensor)

View File

@ -1,4 +1,4 @@
dm_haiku==0.0.12
jax[cuda12_pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
numpy==1.26.4
sentencepiece==0.2.0