mirror of
https://github.com/xai-org/grok-1.git
synced 2024-11-23 03:59:53 +03:00
Merge 1a0ba385eb
into 7050ed204b
This commit is contained in:
commit
03df01959d
41
model.py
41
model.py
@ -1002,7 +1002,7 @@ class DenseBlock(hk.Module):
|
||||
sharding=P("model", "data"),
|
||||
mesh=self.mesh,
|
||||
shard_axis=1,
|
||||
)(h_w1 * h_v)
|
||||
)(h_w1 * h_v) # TODO: Document why this isn't sequential and whether it should be.
|
||||
|
||||
return h_dense
|
||||
|
||||
@ -1036,13 +1036,10 @@ class DecoderLayer(hk.Module):
|
||||
) -> DecoderOutput:
|
||||
"""Transforms input embedding sequences to output embedding sequences."""
|
||||
|
||||
def layer_norm(x):
|
||||
return hk_rms_norm(x)
|
||||
|
||||
sharding = P(self.data_axis, None)
|
||||
if self.shard_activations:
|
||||
sharding = P(self.data_axis, None, self.model_axis)
|
||||
else:
|
||||
sharding = P(self.data_axis, None)
|
||||
|
||||
h = with_sharding_constraint(inputs, sharding)
|
||||
|
||||
attn_output = MHABlock(
|
||||
@ -1054,8 +1051,8 @@ class DecoderLayer(hk.Module):
|
||||
data_axis=self.data_axis,
|
||||
model_axis=self.model_axis,
|
||||
)(layer_norm(h), mask, layer_memory)
|
||||
h_attn = attn_output.embeddings
|
||||
|
||||
h_attn = attn_output.embeddings
|
||||
h_attn = layer_norm(h_attn)
|
||||
h += h_attn
|
||||
h = with_sharding_constraint(h, sharding)
|
||||
@ -1165,15 +1162,17 @@ class LanguageModelConfig:
|
||||
_initialized = False
|
||||
|
||||
def initialize(self):
|
||||
# We cannot specify [] as a default value (it is mutable), hence None.
|
||||
model_config = self.model
|
||||
assert self.init_scale_override is None, (
|
||||
"Overriding model initialize scale is supported only for predefined models."
|
||||
)
|
||||
|
||||
if self.model is None: # We cannot specify [] as a default value (it is mutable), hence None.
|
||||
raise ValueError("Model configuration is not set.")
|
||||
if self.init_scale_override is not None:
|
||||
raise ValueError("Overriding model initialize scale is supported only for predefined models.")
|
||||
|
||||
if self.model_size == 0:
|
||||
self.model_size = model_config.emb_size
|
||||
assert self.model is not None, "Model could not be initialized."
|
||||
self._initialized = True
|
||||
|
||||
return self
|
||||
|
||||
def make(self, *args, **kwargs):
|
||||
@ -1194,7 +1193,7 @@ class LanguageModelConfig:
|
||||
return LM_PARTITION_RULES + self.model.partition_rules()
|
||||
|
||||
|
||||
def layer_norm(x, model):
|
||||
def layer_norm(x):
|
||||
return hk_rms_norm(x)
|
||||
|
||||
|
||||
@ -1213,17 +1212,12 @@ class LanguageModel(hk.Module):
|
||||
tokens: jax.Array,
|
||||
memory: Optional[Memory] = None,
|
||||
*,
|
||||
batch: Dict[str, jax.Array] = {},
|
||||
last_hid_only: bool = False,
|
||||
length: Optional[jax.Array] = None,
|
||||
) -> LanguageModelOutput:
|
||||
"""Forward pass, producing a sequence of logits."""
|
||||
del batch # Unused.
|
||||
|
||||
config = self.config
|
||||
|
||||
input_mask = jnp.greater(tokens, config.pad_token)
|
||||
|
||||
# Embed the input tokens and positions.
|
||||
in_out_embed = InOutEmbed(
|
||||
self.config.vocab_size,
|
||||
@ -1235,6 +1229,7 @@ class LanguageModel(hk.Module):
|
||||
input_embeddings, P("data", None, self.model.model_axis)
|
||||
)
|
||||
input_embeddings *= config.embedding_multiplier_scale
|
||||
input_mask = jnp.not_equal(tokens, config.pad_token)
|
||||
|
||||
model_output = self.model(
|
||||
input_embeddings,
|
||||
@ -1242,15 +1237,15 @@ class LanguageModel(hk.Module):
|
||||
memory=memory,
|
||||
) # [B, T, D]
|
||||
embeddings, model_state = model_output.embeddings, model_output.memory
|
||||
if embeddings.dtype != self.fprop_dtype:
|
||||
raise ValueError(f"Expected forward propagation dtype {self.fprop_dtype} but got {embeddings.dtype} in embeddings.")
|
||||
|
||||
if self.model.shard_activations:
|
||||
embeddings = with_sharding_constraint(
|
||||
embeddings, P("data", None, self.model.model_axis)
|
||||
)
|
||||
embeddings = with_sharding_constraint(embeddings, P("data", None, self.model.model_axis))
|
||||
else:
|
||||
embeddings = with_sharding_constraint(embeddings, P("data", None))
|
||||
rank_logger.debug(f"Final embedding shape: {embeddings.shape}")
|
||||
embeddings = layer_norm(embeddings, self.model)
|
||||
assert embeddings.dtype == self.fprop_dtype
|
||||
embeddings = layer_norm(embeddings)
|
||||
|
||||
if last_hid_only:
|
||||
last_step = jnp.maximum(jnp.sum(input_mask.astype(jnp.int32), axis=1) - 1, 0)
|
||||
|
Loading…
Reference in New Issue
Block a user