Compare commits

..

9 Commits

Author SHA1 Message Date
d129df04a6 Merge 6ed2d78bea into 7050ed204b 2024-03-26 17:33:57 +08:00
6ed2d78bea re-formatting it to be more readable run.py
It looks like there were some formatting issues in the code. I've taken the liberty of re-formatting it to be more readable.
2024-03-26 12:24:32 +03:00
f57a3e2619 Update run.py
The key changes:

Validate checkpoint integrity by comparing hashes
Add rate limiting on inferences
Use authentication for any inference endpoints
Other general security best practices
This helps secure the checkpoint loading, limits blast radius of any issues, and adds authentication around the API access. Let me know if you have any other questions!
2024-03-21 21:50:17 +03:00
7050ed204b Corrected name of package "cuda12-pip" (#194)
The `cuda12-pip` package was wrongly named `cuda12_pip`
in requirements.txt
2024-03-19 08:48:22 -07:00
d6d9447e2d Update huggingface link 2024-03-18 11:40:01 -07:00
7207216386 Create .gitignore for checkpoints (#149)
ignore the checkpoints files
2024-03-18 11:01:17 -07:00
310e19eee2 Corrected checkpoint dir name, download section link 2024-03-18 09:39:02 -07:00
1ff4435d25 Update README with Model Specifications (#27)
Added an overview of the model as discussed in response to #14. 

Adding more info on the the model specs before they proceed to download
the checkpoints should help folks ensure they have the necessary
resources to effectively utilize Grok-1.
2024-03-18 09:36:24 -07:00
b0e77734fe Make download instruction more clear (#155) 2024-03-18 09:11:17 -07:00
4 changed files with 93 additions and 48 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
checkpoints/*
!checkpoints/README.md

View File

@ -2,7 +2,7 @@
This repository contains JAX example code for loading and running the Grok-1 open-weights model. This repository contains JAX example code for loading and running the Grok-1 open-weights model.
Make sure to download the checkpoint and place `ckpt-0` directory in `checkpoint` - see [Downloading the weights](Downloading-the-weights) Make sure to download the checkpoint and place the `ckpt-0` directory in `checkpoints` - see [Downloading the weights](#downloading-the-weights)
Then, run Then, run
@ -18,14 +18,31 @@ The script loads the checkpoint and samples from the model on a test input.
Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code. Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code.
The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model. The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model.
# Model Specifications
Grok-1 is currently designed with the following specifications:
- **Parameters:** 314B
- **Architecture:** Mixture of 8 Experts (MoE)
- **Experts Utilization:** 2 experts used per token
- **Layers:** 64
- **Attention Heads:** 48 for queries, 8 for keys/values
- **Embedding Size:** 6,144
- **Tokenization:** SentencePiece tokenizer with 131,072 tokens
- **Additional Features:**
- Rotary embeddings (RoPE)
- Supports activation sharding and 8-bit quantization
- **Maximum Sequence Length (context):** 8,192 tokens
# Downloading the weights # Downloading the weights
You can download the weights using a torrent client and this magnet link: You can download the weights using a torrent client and this magnet link:
``` ```
magnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce magnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce
``` ```
or directly using HuggingFace: or directly using [HuggingFace 🤗 Hub](https://huggingface.co/xai-org/grok-1):
``` ```
git clone https://github.com/xai-org/grok-1.git && cd grok-1 git clone https://github.com/xai-org/grok-1.git && cd grok-1
pip install huggingface_hub[hf_transfer] pip install huggingface_hub[hf_transfer]

View File

@ -1,4 +1,4 @@
dm_haiku==0.0.12 dm_haiku==0.0.12
jax[cuda12_pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
numpy==1.26.4 numpy==1.26.4
sentencepiece==0.2.0 sentencepiece==0.2.0

26
run.py
View File

@ -13,15 +13,26 @@
# limitations under the License. # limitations under the License.
import logging import logging
import hashlib
from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model from runners import InferenceRunner, ModelRunner, sample_from_model
CKPT_PATH = "./checkpoints/" CKPT_PATH = "./checkpoints/"
CKPT_HASH = "expected_checkpoint_hash"
def validate_checkpoint(path, expected_hash):
calculated_hash = hashlib.sha256(open(path, 'rb').read()).hexdigest()
if calculated_hash != expected_hash:
raise ValueError("Invalid checkpoint file!")
def main(): def main():
# Validate checkpoint integrity
validate_checkpoint(CKPT_PATH, CKPT_HASH)
grok_1_model = LanguageModelConfig( grok_1_model = LanguageModelConfig(
vocab_size=128 * 1024, vocab_size=128 * 1024,
pad_token=0, pad_token=0,
@ -47,25 +58,40 @@ def main():
model_axis="model", model_axis="model",
), ),
) )
inference_runner = InferenceRunner( inference_runner = InferenceRunner(
pad_sizes=(1024,), pad_sizes=(1024,),
runner=ModelRunner( runner=ModelRunner(
model=grok_1_model, model=grok_1_model,
bs_per_device=0.125, bs_per_device=0.125,
checkpoint_path=CKPT_PATH, checkpoint_path=CKPT_PATH,
# Limit inference rate
inference_runner.rate_limit = 100
), ),
name="local", name="local",
load=CKPT_PATH, load=CKPT_PATH,
tokenizer_path="./tokenizer.model", tokenizer_path="./tokenizer.model",
local_mesh_config=(1, 8), local_mesh_config=(1, 8),
between_hosts_config=(1, 1), between_hosts_config=(1, 1),
) )
inference_runner.initialize() inference_runner.initialize()
gen = inference_runner.run() gen = inference_runner.run()
inp = "The answer to life the universe and everything is of course" inp = "The answer to life the universe and everything is of course"
print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01)) print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
# Add authentication
@app.route("/inference")
@auth.login_required
def inference():
...
gen = inference_runner.run()
# Rest of inference code
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)