Minor fixes and clean-up

Toby Pohlen 2024-03-17 04:29:18 -07:00
parent ce81a1668e
commit e288ce35cb
6 changed files with 10 additions and 16 deletions


@@ -23,7 +23,6 @@ import re
 import shutil
 import tempfile
 from concurrent.futures import ThreadPoolExecutor, wait
-from functools import partial
 from typing import Any, Optional

 import jax
@@ -143,7 +142,6 @@ def replace_with_load_state(
 load_exclude_rules: Optional[list[str]] = None,
 mesh_config: tuple = (1, 1),
 ) -> Any:
-# re_pattern, re_substitue
 flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
 flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
 load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
@@ -177,7 +175,6 @@ def restore(
 checkpoint_path: str,
 state_shapes: Any,
 mesh,
-local_mesh_config,
 between_hosts_config,
 params_only,
 state_sharding,
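
Aside (not part of the commit): the replace_with_load_state hunk above relies on jax.tree_util.tree_flatten_with_path, which yields (path, leaf) pairs that the code turns into a string-keyed lookup table. A toy version of that pattern, using jax.tree_util.keystr in place of the repository's path_tuple_to_string helper (that substitution is an assumption):

# Toy version of the path-keyed lookup built in replace_with_load_state.
import jax

load_state = {"decoder": {"b": 0.0, "w": 1.0}}
flat, _ = jax.tree_util.tree_flatten_with_path(load_state)
load_map = {jax.tree_util.keystr(path): leaf for path, leaf in flat}
# load_map == {"['decoder']['b']": 0.0, "['decoder']['w']": 1.0}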


@@ -28,7 +28,6 @@ from jax.lax import with_sharding_constraint as pjit_sharding_constraint
 from jax.sharding import PartitionSpec
 from jax.sharding import PartitionSpec as P

 config.update("jax_spmd_mode", "allow_all")

 logger = logging.getLogger(__name__)
@@ -226,7 +225,6 @@ class Router(hk.Module):
 def compute_routing_prob(
 self, inputs: jax.Array, padding_mask: Optional[jax.Array], num_experts: int
 ):
-"""During inference, we sort the inputs depending on which expert they get routed to."""
 return self._compute_routing_prob(inputs, padding_mask, num_experts)

 @hk.transparent
@@ -1171,7 +1169,7 @@ class LanguageModelConfig:
 # We cannot specify [] as a default value (it is mutable), hence None.
 model_config = self.model
 assert self.init_scale_override is None, (
-"Overriding model " + "initialize scale is supported only for predefined models."
+"Overriding model initialize scale is supported only for predefined models."
 )
 if self.model_size == 0:
 self.model_size = model_config.emb_size
@@ -1226,8 +1224,6 @@ class LanguageModel(hk.Module):
 config = self.config
 input_mask = jnp.greater(tokens, config.pad_token)

-# Shift right by 1:
-pad_width = ((0, 0), (1, 0))

 # Embed the input tokens and positions.
 in_out_embed = InOutEmbed(
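
Aside (a reading of the deleted lines, not stated in the commit): pad_width = ((0, 0), (1, 0)) is the argument jnp.pad expects for prepending one position along the sequence axis, i.e. shifting the tokens right by one; it appears to have gone unused here, which is presumably why it could be dropped. A toy example of the shift it describes:

# Toy example of a right-shift by one using the deleted pad_width value.
import jax.numpy as jnp

tokens = jnp.array([[5, 6, 7]])               # (batch, seq)
pad_width = ((0, 0), (1, 0))                  # pad one position at the front of the sequence
shifted = jnp.pad(tokens, pad_width)[:, :-1]  # -> [[0, 5, 6]]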


@@ -8,7 +8,6 @@ ignore = [
 "E731",
 "E741",
 "F405",
-# Re-enable these later on:
 "E402",
 "F403",
 ]
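
Aside (from ruff's published rule descriptions, not from this repository): E731 flags assigning a lambda to a name, E741 flags ambiguous single-character names such as l or I, F405 flags names that may come from a wildcard import, E402 flags module-level imports that are not at the top of the file, and F403 flags the wildcard imports themselves. The deleted comment had singled out the last two as temporary.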


@@ -1,4 +1,5 @@
 dm_haiku==0.0.12
-jax==0.4.25
+-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+jax[cuda12_pip]==0.4.25
 numpy==1.26.4
 sentencepiece==0.2.0
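
Aside (an inference from pip's documented behaviour, not from the commit message): the new -f / --find-links line points pip at Google's JAX release index, which hosts the CUDA-enabled jaxlib wheels required by the jax[cuda12_pip] extra, so a plain pip install -r requirements.txt now resolves the GPU-enabled build rather than the CPU-only jax==0.4.25 package.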

run.py

@@ -14,12 +14,14 @@
 import logging

-from model import LanguageModelConfig, TransformerConfig
-from runners import InferenceRunner, ModelRunner
+from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
+from runners import InferenceRunner, ModelRunner, sample_from_model

 CKPT_PATH = "./checkpoints/"

+# This is required in order for unpickling to work.
+QuantizedWeight8bit = QW8Bit

 def main():
 grok_1_model = LanguageModelConfig(
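
Aside (a reading of the new comment, with hypothetical names): pickle records each object's class as a module-plus-attribute reference and looks that attribute up again when loading, so checkpoints pickled against a name like QuantizedWeight8bit can only be restored if the loading module exposes exactly that name, which is what the new alias provides. A self-contained sketch of the failure mode (meant to be run as a script):

# Toy sketch (hypothetical names): pickle stores classes by reference,
# here ("__main__", "QuantizedWeight8bit"), and looks the attribute up
# again at load time, so the loading module must expose that exact name.
import __main__
import pickle

class QuantizedWeight8bit:   # stand-in class, not the repository's
    pass

QW8Bit = QuantizedWeight8bit            # the handle mirroring the import in run.py
blob = pickle.dumps(QuantizedWeight8bit())

del __main__.QuantizedWeight8bit        # simulate the original name being absent on load
try:
    pickle.loads(blob)                  # AttributeError: can't get attribute ...
except AttributeError:
    __main__.QuantizedWeight8bit = QW8Bit   # the run.py-style alias restores the lookup
    obj = pickle.loads(blob)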


@@ -44,7 +44,6 @@ from model import (
 logger = logging.getLogger(__name__)
 rank_logger = logging.getLogger("rank")

 TOP_K = 8
@@ -71,7 +70,8 @@ def insert_slice(memory: Memory, slice, length, i):
 ],
 )

-return jax.tree_map(lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0), memory, slice)
+return jax.tree_map(lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0),
+memory, slice)


 def pad_to_size(x, size):
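
Aside (not part of the commit): the reflowed return combines jax.lax.dynamic_update_index_in_dim, which writes a slice into an array at a given index along one axis, with jax.tree_map, which applies that update to every leaf of the memory pytree. A toy illustration with invented shapes and names:

# Toy illustration: write one entry into every leaf of a pytree at index i.
import jax
import jax.numpy as jnp

memory = {"k": jnp.zeros((4, 2)), "v": jnp.zeros((4, 2))}     # 4 slots per leaf
update = {"k": jnp.ones((1, 2)), "v": jnp.full((1, 2), 2.0)}  # one new slot each
i = 2
out = jax.tree_map(
    lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0),
    memory, update,
)
# out["k"][2] -> [1., 1.]   out["v"][2] -> [2., 2.]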
@@ -239,7 +239,6 @@ class ModelRunner:
 checkpoint_path=self.checkpoint_path,
 state_shapes=state_shapes,
 mesh=self.mesh,
-local_mesh_config=self.local_mesh_config,
 between_hosts_config=self.between_hosts_config,
 state_sharding=self.state_sharding,
 init_state=init_state,