Minor fixes and clean-up

Toby Pohlen 2024-03-17 04:29:18 -07:00
parent ce81a1668e
commit e288ce35cb
6 changed files with 10 additions and 16 deletions

View File

@@ -23,7 +23,6 @@ import re
import shutil
import tempfile
from concurrent.futures import ThreadPoolExecutor, wait
from functools import partial
from typing import Any, Optional
import jax
@@ -143,7 +142,6 @@ def replace_with_load_state(
load_exclude_rules: Optional[list[str]] = None,
mesh_config: tuple = (1, 1),
) -> Any:
# re_pattern, re_substitute
flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
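
The hunk above flattens both the loaded checkpoint and the freshly initialised state with their pytree key paths and builds a path-string-to-tensor lookup table. A minimal sketch of that pattern follows; the path_tuple_to_string helper here is an illustrative stand-in and may differ from the one defined in this file.

import jax

def path_tuple_to_string(path) -> str:
    # Join a pytree key path into a flat string such as "params/w".
    parts = []
    for entry in path:
        if isinstance(entry, jax.tree_util.DictKey):
            parts.append(str(entry.key))
        elif isinstance(entry, jax.tree_util.GetAttrKey):
            parts.append(entry.name)
        else:
            parts.append(str(entry))
    return "/".join(parts)

load_state = {"params": {"w": 1.0, "b": 2.0}}
flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
# load_map now maps "params/b" and "params/w" to their leaves, so each entry can be
# matched against the initialised state (optionally after applying rename rules).
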
@@ -177,7 +175,6 @@ def restore(
checkpoint_path: str,
state_shapes: Any,
mesh,
local_mesh_config,
between_hosts_config,
params_only,
state_sharding,

View File

@@ -28,7 +28,6 @@ from jax.lax import with_sharding_constraint as pjit_sharding_constraint
from jax.sharding import PartitionSpec
from jax.sharding import PartitionSpec as P
config.update("jax_spmd_mode", "allow_all")
logger = logging.getLogger(__name__)
@@ -226,7 +225,6 @@ class Router(hk.Module):
def compute_routing_prob(
self, inputs: jax.Array, padding_mask: Optional[jax.Array], num_experts: int
):
"""During inference, we sort the inputs depending on which expert they get routed to."""
return self._compute_routing_prob(inputs, padding_mask, num_experts)
@hk.transparent
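
The docstring removed above sat on a thin wrapper that simply forwards to _compute_routing_prob. For readers unfamiliar with MoE routers, here is a generic sketch of what a routing-probability computation typically looks like; it is illustrative only and is not claimed to match the model's actual _compute_routing_prob.

import haiku as hk
import jax
import jax.numpy as jnp

def routing_prob_sketch(inputs, padding_mask, num_experts):
    # One logit per expert for every token, normalised with a softmax.
    logits = hk.Linear(num_experts, with_bias=False)(inputs)
    probs = jax.nn.softmax(logits.astype(jnp.float32), axis=-1)
    if padding_mask is not None:
        # Remove routing mass from padded positions.
        probs = probs * padding_mask[..., None]
    return probs

forward = hk.transform(lambda x: routing_prob_sketch(x, None, num_experts=8))
params = forward.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 16)))
probs = forward.apply(params, None, jnp.ones((1, 4, 16)))  # shape (1, 4, 8)
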
@@ -1171,7 +1169,7 @@ class LanguageModelConfig:
# We cannot specify [] as a default value (it is mutable), hence None.
model_config = self.model
assert self.init_scale_override is None, (
"Overriding model " + "initialize scale is supported only for predefined models."
"Overriding model initialize scale is supported only for predefined models."
)
if self.model_size == 0:
self.model_size = model_config.emb_size
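
The change above only drops a redundant + between two string literals: adjacent literals are concatenated by the parser, so the assertion message is identical before and after. A tiny check:

# Adjacent string literals are joined at parse time, so these two are equal:
msg_a = "Overriding model " + "initialize scale is supported only for predefined models."
msg_b = "Overriding model initialize scale is supported only for predefined models."
assert msg_a == msg_b
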
@@ -1226,8 +1224,6 @@ class LanguageModel(hk.Module):
config = self.config
input_mask = jnp.greater(tokens, config.pad_token)
# Shift right by 1:
pad_width = ((0, 0), (1, 0))
# Embed the input tokens and positions.
in_out_embed = InOutEmbed(
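
This hunk drops what looks like an unused pad_width constant next to the "Shift right by 1" comment. As a generic illustration of the shift-right idiom it referred to (not the repository's actual code, and assuming 0 is the pad id):

import jax.numpy as jnp

tokens = jnp.array([[5, 6, 7, 8]])
pad_width = ((0, 0), (1, 0))            # no padding on the batch axis, one pad slot on the left
shifted = jnp.pad(tokens, pad_width)[:, :-1]
# shifted == [[0, 5, 6, 7]]: each position now only sees earlier tokens.
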

View File

@@ -8,7 +8,6 @@ ignore = [
"E731",
"E741",
"F405",
# Re-enable these later on:
"E402",
"F403",
]
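
For reference, the code patterns flagged by E402 and F403, two of the rules in the ignore list, look like this (illustrative snippet, not taken from the repository):

import logging

logging.basicConfig(level=logging.INFO)  # executable statement before the remaining imports

import json                  # E402: module-level import not at top of file
from os.path import *        # F403: star import, so the names it introduces are not explicit
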

View File

@@ -1,4 +1,5 @@
dm_haiku==0.0.12
jax==0.4.25
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12_pip]==0.4.25
numpy==1.26.4
sentencepiece==0.2.0

run.py
View File

@@ -14,12 +14,14 @@
import logging
from model import LanguageModelConfig, TransformerConfig
from runners import InferenceRunner, ModelRunner
from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model
CKPT_PATH = "./checkpoints/"
# This is required in order for unpickling to work.
QuantizedWeight8bit = QW8Bit
def main():
grok_1_model = LanguageModelConfig(
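
The new QuantizedWeight8bit = QW8Bit line matters because pickle looks a class up by the module and attribute name recorded when the object was serialized. Presumably the checkpoint records the class under the name QuantizedWeight8bit in the running script's namespace, so run.py must bind that exact name even though it imports the class under the alias QW8Bit. A minimal, self-contained sketch of the failure mode and the fix (the class below is a stand-in, not the real model.QuantizedWeight8bit):

import pickle
import sys

class QuantizedWeight8bit:                         # stand-in for model.QuantizedWeight8bit
    pass

blob = pickle.dumps(QuantizedWeight8bit())         # records ("__main__", "QuantizedWeight8bit")

QW8Bit = QuantizedWeight8bit                       # keep the class only under an alias...
del sys.modules["__main__"].QuantizedWeight8bit    # ...so the recorded name is no longer bound

try:
    pickle.loads(blob)                             # AttributeError: cannot resolve the class
except AttributeError as err:
    print(err)

QuantizedWeight8bit = QW8Bit                       # same trick as run.py: restore the name
obj = pickle.loads(blob)                           # unpickling succeeds again
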

View File

@@ -44,7 +44,6 @@ from model import (
logger = logging.getLogger(__name__)
rank_logger = logging.getLogger("rank")
TOP_K = 8
@@ -71,7 +70,8 @@ def insert_slice(memory: Memory, slice, length, i):
],
)
return jax.tree_map(lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0), memory, slice)
return jax.tree_map(lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0),
memory, slice)
def pad_to_size(x, size):
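
insert_slice writes a single sample's state into every memory buffer at batch position i, combining jax.tree_map over a pytree of buffers with jax.lax.dynamic_update_index_in_dim. A minimal sketch of those two pieces (the buffer shapes below are made up for illustration):

import jax
import jax.numpy as jnp

# Toy "memory": one buffer per layer, laid out as (batch, seq, dim).
memory = {
    "layer_0": jnp.zeros((2, 4, 3)),
    "layer_1": jnp.zeros((2, 4, 3)),
}
# A slice holding one sample's state, with a leading axis of size 1.
update = {
    "layer_0": jnp.ones((1, 4, 3)),
    "layer_1": jnp.ones((1, 4, 3)),
}

i = 1  # batch row to overwrite
new_memory = jax.tree_map(
    lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0),
    memory,
    update,
)
# new_memory["layer_0"][1] is now all ones; row 0 is untouched.
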
@@ -239,7 +239,6 @@ class ModelRunner:
checkpoint_path=self.checkpoint_path,
state_shapes=state_shapes,
mesh=self.mesh,
local_mesh_config=self.local_mesh_config,
between_hosts_config=self.between_hosts_config,
state_sharding=self.state_sharding,
init_state=init_state,