From 1a0ba385eb824f0a1030fe7f4cbfc625f1addd2d Mon Sep 17 00:00:00 2001
From: devindkim
Date: Mon, 18 Mar 2024 18:40:37 -0400
Subject: [PATCH] Add exceptions to LanguageModel

---
 model.py | 41 ++++++++++++++++++-----------------------
 1 file changed, 18 insertions(+), 23 deletions(-)

diff --git a/model.py b/model.py
index 87d700d..8e5fd96 100644
--- a/model.py
+++ b/model.py
@@ -1002,7 +1002,7 @@ class DenseBlock(hk.Module):
             sharding=P("model", "data"),
             mesh=self.mesh,
             shard_axis=1,
-        )(h_w1 * h_v)
+        )(h_w1 * h_v)  # TODO: Document why this isn't sequential and whether it should be.
 
         return h_dense
 
@@ -1036,13 +1036,10 @@ class DecoderLayer(hk.Module):
     ) -> DecoderOutput:
         """Transforms input embedding sequences to output embedding sequences."""
 
-        def layer_norm(x):
-            return hk_rms_norm(x)
-
+        sharding = P(self.data_axis, None)
         if self.shard_activations:
             sharding = P(self.data_axis, None, self.model_axis)
-        else:
-            sharding = P(self.data_axis, None)
+
         h = with_sharding_constraint(inputs, sharding)
 
         attn_output = MHABlock(
@@ -1054,8 +1051,8 @@
             data_axis=self.data_axis,
             model_axis=self.model_axis,
         )(layer_norm(h), mask, layer_memory)
-        h_attn = attn_output.embeddings
 
+        h_attn = attn_output.embeddings
         h_attn = layer_norm(h_attn)
         h += h_attn
         h = with_sharding_constraint(h, sharding)
@@ -1165,15 +1162,17 @@ class LanguageModelConfig:
     _initialized = False
 
     def initialize(self):
-        # We cannot specify [] as a default value (it is mutable), hence None.
         model_config = self.model
-        assert self.init_scale_override is None, (
-            "Overriding model initialize scale is supported only for predefined models."
-        )
+
+        if self.model is None:  # We cannot specify [] as a default value (it is mutable), hence None.
+            raise ValueError("Model configuration is not set.")
+        if self.init_scale_override is not None:
+            raise ValueError("Overriding model initialize scale is supported only for predefined models.")
+
         if self.model_size == 0:
             self.model_size = model_config.emb_size
-        assert self.model is not None, "Model could not be initialized."
         self._initialized = True
+
         return self
 
     def make(self, *args, **kwargs):
@@ -1194,7 +1193,7 @@ class LanguageModelConfig:
         return LM_PARTITION_RULES + self.model.partition_rules()
 
 
-def layer_norm(x, model):
+def layer_norm(x):
     return hk_rms_norm(x)
 
 
@@ -1213,17 +1212,12 @@ class LanguageModel(hk.Module):
         tokens: jax.Array,
         memory: Optional[Memory] = None,
         *,
-        batch: Dict[str, jax.Array] = {},
         last_hid_only: bool = False,
         length: Optional[jax.Array] = None,
     ) -> LanguageModelOutput:
         """Forward pass, producing a sequence of logits."""
-        del batch  # Unused.
-
         config = self.config
 
-        input_mask = jnp.greater(tokens, config.pad_token)
-
         # Embed the input tokens and positions.
         in_out_embed = InOutEmbed(
             self.config.vocab_size,
@@ -1235,6 +1229,7 @@ class LanguageModel(hk.Module):
             input_embeddings, P("data", None, self.model.model_axis)
         )
         input_embeddings *= config.embedding_multiplier_scale
+        input_mask = jnp.not_equal(tokens, config.pad_token)
 
         model_output = self.model(
             input_embeddings,
@@ -1242,15 +1237,15 @@
             memory=memory,
         )  # [B, T, D]
         embeddings, model_state = model_output.embeddings, model_output.memory
+        if embeddings.dtype != self.fprop_dtype:
+            raise ValueError(f"Expected forward propagation dtype {self.fprop_dtype} but got {embeddings.dtype} in embeddings.")
+
         if self.model.shard_activations:
-            embeddings = with_sharding_constraint(
-                embeddings, P("data", None, self.model.model_axis)
-            )
+            embeddings = with_sharding_constraint(embeddings, P("data", None, self.model.model_axis))
         else:
             embeddings = with_sharding_constraint(embeddings, P("data", None))
         rank_logger.debug(f"Final embedding shape: {embeddings.shape}")
-        embeddings = layer_norm(embeddings, self.model)
-        assert embeddings.dtype == self.fprop_dtype
+        embeddings = layer_norm(embeddings)
 
         if last_hid_only:
             last_step = jnp.maximum(jnp.sum(input_mask.astype(jnp.int32), axis=1) - 1, 0)
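
Usage note: after this patch, a misconfigured LanguageModelConfig raises ValueError instead of tripping an assert, so the check still fires when Python runs with assertions stripped (python -O) and can be caught by callers. Below is a minimal sketch of catching the new error; it is not part of the diff, and the keyword fields shown are assumed from the Grok-1 settings in run.py, so adjust them to your own configuration:

    from model import LanguageModelConfig

    config = LanguageModelConfig(
        vocab_size=128 * 1024,
        pad_token=0,
        eos_token=2,
        sequence_len=8192,
        model=None,  # left unset on purpose to show the new failure mode
    )

    try:
        config.initialize()
    except ValueError as err:
        print(f"invalid config: {err}")  # "Model configuration is not set."

The dtype check in LanguageModel.__call__ behaves the same way: a mismatch between embeddings.dtype and fprop_dtype now surfaces as a ValueError naming both dtypes rather than a bare AssertionError.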