Mirror of https://github.com/xai-org/grok-1.git, synced 2025-04-03 18:00:10 +03:00

Minor fixes and clean-up

parent ce81a1668e
commit e288ce35cb
checkpoint.py

```diff
@@ -23,7 +23,6 @@ import re
 import shutil
 import tempfile
 from concurrent.futures import ThreadPoolExecutor, wait
-from functools import partial
 from typing import Any, Optional
 
 import jax
```
```diff
@@ -143,7 +142,6 @@ def replace_with_load_state(
     load_exclude_rules: Optional[list[str]] = None,
     mesh_config: tuple = (1, 1),
 ) -> Any:
-    # re_pattern, re_substitue
     flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
     flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
     load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
```
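For context, `replace_with_load_state` matches tensors between a freshly initialized state and a loaded checkpoint by flattening both pytrees together with their paths. A minimal sketch of that pattern, using `jax.tree_util.keystr` as a simplified stand-in for the repo's `path_tuple_to_string` helper:

```python
import jax
import jax.numpy as jnp

# Toy "loaded" state; any nested pytree of arrays works.
load_state = {"decoder": {"w": jnp.ones((2, 2))}, "embed": jnp.zeros(3)}

# tree_flatten_with_path yields a (path, leaf) pair for every tensor.
flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)

# Key each tensor by its stringified path so entries can be looked up
# while walking the flattened init_state.
load_map = {jax.tree_util.keystr(path): tensor for path, tensor in flatten_load}
print(sorted(load_map))  # ["['decoder']['w']", "['embed']"]
```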
```diff
@@ -177,7 +175,6 @@ def restore(
     checkpoint_path: str,
     state_shapes: Any,
     mesh,
-    local_mesh_config,
     between_hosts_config,
     params_only,
     state_sharding,
```
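The evidently unused `local_mesh_config` parameter drops out of `restore`'s signature here; the matching call site in `ModelRunner` (runners.py, further down) is updated in the same commit, so the two hunks must land together.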
6
model.py
6
model.py
```diff
@@ -28,7 +28,6 @@ from jax.lax import with_sharding_constraint as pjit_sharding_constraint
 from jax.sharding import PartitionSpec
 from jax.sharding import PartitionSpec as P
 
-
 config.update("jax_spmd_mode", "allow_all")
 
 logger = logging.getLogger(__name__)
```
```diff
@@ -226,7 +225,6 @@ class Router(hk.Module):
     def compute_routing_prob(
         self, inputs: jax.Array, padding_mask: Optional[jax.Array], num_experts: int
     ):
-        """During inference, we sort the inputs depending on which expert they get routed to."""
         return self._compute_routing_prob(inputs, padding_mask, num_experts)
 
     @hk.transparent
```
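`compute_routing_prob` is a thin public wrapper around `_compute_routing_prob`, which this hunk does not show. As a rough illustration of what MoE routing probabilities look like, here is a self-contained sketch assuming a plain linear router followed by a softmax; the names, shapes, and masking behaviour are illustrative assumptions, not grok-1's exact implementation:

```python
import jax
import jax.numpy as jnp


def compute_routing_prob(inputs, padding_mask, router_weights):
    """inputs: [batch, seq, d_model]; router_weights: [d_model, num_experts]."""
    # Route in fp32 regardless of activation dtype (a common MoE choice).
    logits = jnp.einsum("bsd,de->bse", inputs.astype(jnp.float32), router_weights)
    probs = jax.nn.softmax(logits, axis=-1)
    # Zero the routing probability of padding positions so they attract no experts.
    if padding_mask is not None:
        probs = probs * padding_mask[..., None]
    return probs


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (2, 5, 16))  # batch=2, seq=5, d_model=16
w = jax.random.normal(k2, (16, 8))     # 8 experts, purely illustrative
mask = jnp.ones((2, 5))                # no padding in this toy input
print(compute_routing_prob(x, mask, w).shape)  # (2, 5, 8)
```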
```diff
@@ -1171,7 +1169,7 @@ class LanguageModelConfig:
         # We cannot specify [] as a default value (it is mutable), hence None.
         model_config = self.model
         assert self.init_scale_override is None, (
-            "Overriding model " + "initialize scale is supported only for predefined models."
+            "Overriding model initialize scale is supported only for predefined models."
         )
         if self.model_size == 0:
             self.model_size = model_config.emb_size
```
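This fix is purely cosmetic: adjacent string literals in Python are concatenated at compile time, so the explicit `+` between the two pieces of the assertion message added nothing.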
```diff
@@ -1226,8 +1224,6 @@ class LanguageModel(hk.Module):
         config = self.config
 
         input_mask = jnp.greater(tokens, config.pad_token)
-        # Shift right by 1:
-        pad_width = ((0, 0), (1, 0))
 
         # Embed the input tokens and positions.
         in_out_embed = InOutEmbed(
```
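The deleted `pad_width` tuple looks like a leftover from a shift-right operation that is no longer performed here, which is presumably why the clean-up drops it. The surviving `input_mask` line is the standard padding-mask idiom: token ids greater than the pad id are real input. A tiny sketch, assuming `pad_token = 0` for illustration:

```python
import jax.numpy as jnp

pad_token = 0  # illustrative; the real value comes from the model config
tokens = jnp.array([[5, 9, 2, 0, 0]])        # one right-padded sequence
input_mask = jnp.greater(tokens, pad_token)
print(input_mask)  # [[ True  True  True False False]]
```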
pyproject.toml

```diff
@@ -8,7 +8,6 @@ ignore = [
     "E731",
     "E741",
     "F405",
     # Re-enable these later on:
     "E402",
     "F403",
 ]
```
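For reference, these are standard Ruff rule codes: E731 flags lambda assignments, E741 ambiguous variable names like `l`, F405 names that may come from star imports, E402 module-level imports not at the top of the file, and F403 the star imports themselves.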
requirements.txt

```diff
@@ -1,4 +1,5 @@
 dm_haiku==0.0.12
-jax==0.4.25
+-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+jax[cuda12_pip]==0.4.25
 numpy==1.26.4
 sentencepiece==0.2.0
```
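The added `-f` line is pip's `--find-links` option in requirements-file form: it points the resolver at JAX's CUDA release index so that the `jax[cuda12_pip]` extra can fetch matching CUDA 12 wheels, which are not hosted on PyPI.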
run.py (8 changes)
```diff
@@ -14,12 +14,14 @@
 
 import logging
 
-from model import LanguageModelConfig, TransformerConfig
-from runners import InferenceRunner, ModelRunner
+from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
+from runners import InferenceRunner, ModelRunner, sample_from_model
 
 CKPT_PATH = "./checkpoints/"
 
+# This is required in order for unpickling to work.
+QuantizedWeight8bit = QW8Bit
+
 
 def main():
     grok_1_model = LanguageModelConfig(
```
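The `QuantizedWeight8bit = QW8Bit` rebinding matters because pickle records classes by module and qualified name and looks that name up again at load time; importing the class only under the alias `QW8Bit` would make the lookup fail. A minimal, self-contained sketch of the mechanism (the class body here is a stand-in, not grok-1's actual `QuantizedWeight8bit`):

```python
import pickle


class QuantizedWeight8bit:
    """Stand-in for the class referenced inside the pickled checkpoint."""

    def __init__(self, weight, scales):
        self.weight = weight
        self.scales = scales


# pickle.dumps records the instance's class as (module, qualname);
# pickle.loads then imports that module and does a getattr on the name.
# As long as the name `QuantizedWeight8bit` is bound in the module that
# pickle resolves to, loading succeeds. That is exactly what the
# rebinding in run.py guarantees.
blob = pickle.dumps(QuantizedWeight8bit(weight=[1.0], scales=[0.5]))
obj = pickle.loads(blob)
print(type(obj).__name__)  # QuantizedWeight8bit
```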
runners.py

```diff
@@ -44,7 +44,6 @@ from model import (
 logger = logging.getLogger(__name__)
 rank_logger = logging.getLogger("rank")
 
-
 TOP_K = 8
 
 
```
```diff
@@ -71,7 +70,8 @@ def insert_slice(memory: Memory, slice, length, i):
         ],
     )
 
-    return jax.tree_map(lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0), memory, slice)
+    return jax.tree_map(lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0),
+                        memory, slice)
 
 
 def pad_to_size(x, size):
```
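`insert_slice` writes one decoding step's state into a fixed-size memory at position `i`, applied leaf-by-leaf across the pytree. A runnable sketch of the same pattern on a toy two-array memory (kept as `jax.tree_map` to match the pinned jax==0.4.25):

```python
import jax
import jax.numpy as jnp

memory = {
    "k": jnp.zeros((4, 3)),  # 4 slots of 3 features each
    "v": jnp.zeros((4, 3)),
}
update = {
    "k": jnp.ones((1, 3)),        # a single new slot, leading axis of size 1
    "v": jnp.full((1, 3), 2.0),
}
i = 2  # slot to overwrite

# u[0] drops the leading batch axis; dynamic_update_index_in_dim then
# writes that row into m at index i along axis 0.
new_memory = jax.tree_map(
    lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0),
    memory,
    update,
)
print(new_memory["k"][2])  # [1. 1. 1.]
```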
```diff
@@ -239,7 +239,6 @@ class ModelRunner:
                 checkpoint_path=self.checkpoint_path,
                 state_shapes=state_shapes,
                 mesh=self.mesh,
-                local_mesh_config=self.local_mesh_config,
                 between_hosts_config=self.between_hosts_config,
                 state_sharding=self.state_sharding,
                 init_state=init_state,
```