diff --git a/checkpoint.py b/checkpoint.py
index 07aafab..5068635 100644
--- a/checkpoint.py
+++ b/checkpoint.py
@@ -23,7 +23,6 @@ import re
 import shutil
 import tempfile
 from concurrent.futures import ThreadPoolExecutor, wait
-from functools import partial
 from typing import Any, Optional
 
 import jax
@@ -143,7 +142,6 @@ def replace_with_load_state(
     load_exclude_rules: Optional[list[str]] = None,
     mesh_config: tuple = (1, 1),
 ) -> Any:
-    # re_pattern, re_substitue
     flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
     flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
     load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
@@ -177,7 +175,6 @@ def restore(
     checkpoint_path: str,
     state_shapes: Any,
     mesh,
-    local_mesh_config,
     between_hosts_config,
     params_only,
     state_sharding,
diff --git a/model.py b/model.py
index 45936e1..4401f44 100644
--- a/model.py
+++ b/model.py
@@ -28,7 +28,6 @@ from jax.lax import with_sharding_constraint as pjit_sharding_constraint
 from jax.sharding import PartitionSpec
 from jax.sharding import PartitionSpec as P
 
-
 config.update("jax_spmd_mode", "allow_all")
 
 logger = logging.getLogger(__name__)
@@ -226,7 +225,6 @@ class Router(hk.Module):
     def compute_routing_prob(
         self, inputs: jax.Array, padding_mask: Optional[jax.Array], num_experts: int
     ):
-        """During inference, we sort the inputs depending on which expert they get routed to."""
        return self._compute_routing_prob(inputs, padding_mask, num_experts)
 
     @hk.transparent
@@ -1171,7 +1169,7 @@ class LanguageModelConfig:
         # We cannot specify [] as a default value (it is mutable), hence None.
         model_config = self.model
         assert self.init_scale_override is None, (
-            "Overriding model " + "initialize scale is supported only for predefined models."
+            "Overriding model initialize scale is supported only for predefined models."
         )
         if self.model_size == 0:
             self.model_size = model_config.emb_size
@@ -1226,8 +1224,6 @@ class LanguageModel(hk.Module):
         config = self.config
         input_mask = jnp.greater(tokens, config.pad_token)
 
-        # Shift right by 1:
-        pad_width = ((0, 0), (1, 0))
 
         # Embed the input tokens and positions.
         in_out_embed = InOutEmbed(
diff --git a/pyproject.toml b/pyproject.toml
index dcff38c..aa55016 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,6 @@ ignore = [
     "E731",
     "E741",
     "F405",
-    # Re-enable these later on:
     "E402",
     "F403",
 ]
diff --git a/requirements.txt b/requirements.txt
index ea28d14..02e1ce6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 dm_haiku==0.0.12
-jax==0.4.25
+-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+jax[cuda12_pip]==0.4.25
 numpy==1.26.4
 sentencepiece==0.2.0
diff --git a/run.py b/run.py
index bc0f4e0..2825dcd 100644
--- a/run.py
+++ b/run.py
@@ -14,12 +14,14 @@
 
 import logging
 
-from model import LanguageModelConfig, TransformerConfig
-from runners import InferenceRunner, ModelRunner
-
+from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
+from runners import InferenceRunner, ModelRunner, sample_from_model
 
 CKPT_PATH = "./checkpoints/"
 
+# This is required in order for unpickling to work.
+QuantizedWeight8bit = QW8Bit
+
 
 def main():
     grok_1_model = LanguageModelConfig(
diff --git a/runners.py b/runners.py
index 7f2c125..452c142 100644
--- a/runners.py
+++ b/runners.py
@@ -44,7 +44,6 @@ from model import (
 
 logger = logging.getLogger(__name__)
 rank_logger = logging.getLogger("rank")
-
 
 TOP_K = 8
 
@@ -71,7 +70,8 @@ def insert_slice(memory: Memory, slice, length, i):
         ],
     )
 
-    return jax.tree_map(lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0), memory, slice)
+    return jax.tree_map(lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0),
+                        memory, slice)
 
 
 def pad_to_size(x, size):
@@ -239,7 +239,6 @@ class ModelRunner:
                 checkpoint_path=self.checkpoint_path,
                 state_shapes=state_shapes,
                 mesh=self.mesh,
-                local_mesh_config=self.local_mesh_config,
                 between_hosts_config=self.between_hosts_config,
                 state_sharding=self.state_sharding,
                 init_state=init_state,
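
A note on the `QuantizedWeight8bit = QW8Bit` alias added to run.py: Python's `pickle` serializes a class by its module path and qualified name, then re-resolves that name at load time, so the checkpoint can only be deserialized if `QuantizedWeight8bit` is bound at module level in the namespace it was pickled under. The sketch below is a hypothetical illustration of that mechanism, not code from the repo; the `QuantizedWeight8bit` class here is a stand-in for the real one in model.py.

```python
# Hypothetical sketch of the pickle behavior behind the run.py alias.
import pickle


class QuantizedWeight8bit:  # stand-in for model.QuantizedWeight8bit
    def __init__(self, weight=None, scales=None):
        self.weight = weight
        self.scales = scales


# pickle records the class as (module, qualname), not by value:
blob = pickle.dumps(QuantizedWeight8bit())

# At load time pickle re-imports that module and looks the name up again.
# If the module no longer binds "QuantizedWeight8bit", loading fails with:
#   AttributeError: Can't get attribute 'QuantizedWeight8bit'
# Re-binding the name at module level, as run.py does, keeps the lookup
# working even though the class is actually defined in model.py.
restored = pickle.loads(blob)
assert isinstance(restored, QuantizedWeight8bit)
```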
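The reflowed `insert_slice` call in runners.py hinges on `jax.lax.dynamic_update_index_in_dim`, which writes a size-1 update into one axis of an array. A minimal, self-contained illustration follows; the shapes are invented for the example and are not the model's real cache shapes.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for one leaf of the Memory pytree: 4 batch slots of 3 values.
memory = jnp.zeros((4, 3))
update = jnp.ones((3,))  # new entry for a single batch slot

# Write `update` at index 2 along axis 0; a rank-(n-1) update is expanded
# to a size-1 slice automatically.
out = jax.lax.dynamic_update_index_in_dim(memory, update, 2, axis=0)
# out[2] is now all ones. insert_slice performs this same per-leaf update
# across the whole Memory pytree via jax.tree_map.
```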