diff --git a/README.md b/README.md index f501a07..46aa03b 100644 --- a/README.md +++ b/README.md @@ -4,21 +4,27 @@ This repository contains JAX example code for loading and running the Grok-1 ope Make sure to download the checkpoint and place the `ckpt-0` directory in `checkpoints` - see [Downloading the weights](#downloading-the-weights) -Then, run +## 1. Installation -```shell +1. Install the project dependencies + +```bash pip install -r requirements.txt +``` + +2. Run the project + +```bash python run.py ``` -to test the code. - The script loads the checkpoint and samples from the model on a test input. -Due to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code. +Due to the large size of the model (314 Billion parameters), a machine with enough GPU memory is required to test the model with the example code. + The implementation of the MoE layer in this repository is not efficient. The implementation was chosen to avoid the need for custom kernels to validate the correctness of the model. -# Model Specifications +## 2. Model Specifications Grok-1 is currently designed with the following specifications: @@ -33,8 +39,9 @@ Grok-1 is currently designed with the following specifications: - Rotary embeddings (RoPE) - Supports activation sharding and 8-bit quantization - **Maximum Sequence Length (context):** 8,192 tokens +- **TPU/GPU:** NVIDIA/AMD supported only -# Downloading the weights +## 3. Downloading the weights You can download the weights using a torrent client and this magnet link: diff --git a/requirements.txt b/requirements.txt index f6d124e..09e9e15 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -dm_haiku==0.0.12 -jax[cuda12_pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +dm-haiku==0.0.12 +jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html numpy==1.26.4 sentencepiece==0.2.0