This commit is contained in:
Dionei Beilke dos Santos 2024-03-19 13:38:18 -03:00 committed by GitHub
commit e4c3d3f814
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 8 deletions

View File

@ -4,21 +4,27 @@ This repository contains JAX example code for loading and running the Grok-1 ope
Make sure to download the checkpoint and place the `ckpt-0` directory in `checkpoints` - see [Downloading the weights](#downloading-the-weights) Make sure to download the checkpoint and place the `ckpt-0` directory in `checkpoints` - see [Downloading the weights](#downloading-the-weights)
Then, run ## 1. Installation
```shell 1. Install the project dependencies
```bash
pip install -r requirements.txt pip install -r requirements.txt
```
2. Run the project
```bash
python run.py python run.py
``` ```
to test the code.
The script loads the checkpoint and samples from the model on a test input. The script loads the checkpoint and samples from the model on a test input.
Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code. Due to the large size of the model (314 Billion parameters), a machine with enough GPU memory is required to test the model with the example code.
The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model. The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model.
# Model Specifications ## 2. Model Specifications
Grok-1 is currently designed with the following specifications: Grok-1 is currently designed with the following specifications:
@ -33,8 +39,9 @@ Grok-1 is currently designed with the following specifications:
- Rotary embeddings (RoPE) - Rotary embeddings (RoPE)
- Supports activation sharding and 8-bit quantization - Supports activation sharding and 8-bit quantization
- **Maximum Sequence Length (context):** 8,192 tokens - **Maximum Sequence Length (context):** 8,192 tokens
- **TPU/GPU:** NVIDIA/AMD supported only
# Downloading the weights ## 3. Downloading the weights
You can download the weights using a torrent client and this magnet link: You can download the weights using a torrent client and this magnet link:

View File

@ -1,4 +1,4 @@
dm_haiku==0.0.12 dm-haiku==0.0.12
jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
numpy==1.26.4 numpy==1.26.4
sentencepiece==0.2.0 sentencepiece==0.2.0