Compare commits


1 Commit

Author       SHA1        Message                           Date
JosefaOrtiz  e62978a0b6  Merge 9f0aaff9fd into 7050ed204b  2025-05-29 04:13:13 +00:00


@@ -1,3 +1,16 @@
+# Copyright 2024 X.AI Corp.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import functools
 import logging
@@ -193,7 +206,7 @@ class Memory(NamedTuple):
 class Router(hk.Module):
-    def __init___(
+    def __init__(
         self,
         num_selected_experts: int,
         data_axis: Union[str, Tuple[str, ...]] = "data",
@@ -202,7 +215,7 @@ class Router(hk.Module):
         mesh: Any = None,
         name: str = "router",
     ):
-        super().__init___(name)
+        super().__init__(name)
         self.shard_activations = shard_activations
         self.data_axis = data_axis
         self.model_axis = model_axis
@@ -257,7 +270,7 @@ class Router(hk.Module):
 class MoELayer(hk.Module):
-    def __init___(
+    def __init__(
         self,
         num_experts: int,
         layer_fn: Callable,
@@ -268,7 +281,7 @@ class MoELayer(hk.Module):
         model_axis: Union[str, Tuple[str, ...]] = "model",
         name: Optional[str] = "moe",
     ):
-        super().__init___(name)
+        super().__init__(name)
         self.num_experts = num_experts
         self.layer_fn = layer_fn
         self.router = router
@@ -510,7 +523,7 @@ def make_attention_mask(
 class Linear(hk.Linear):
-    def __init___(
+    def __init__(
         self,
         output_size: int,
         with_bias: bool = True,
@@ -519,7 +532,7 @@ class Linear(hk.Linear):
         name: Optional[str] = None,
         shard_axis: int = 0,
     ):
-        super().__init___(
+        super().__init__(
             output_size=output_size,
             with_bias=with_bias,
             name=name,
@@ -573,7 +586,7 @@ class Linear(hk.Linear):
 class RMSNorm(hk.RMSNorm):
-    def __init___(
+    def __init__(
         self,
         axis: Union[int, Sequence[int], slice],
         eps: float = 1e-5,
@@ -581,7 +594,7 @@ class RMSNorm(hk.RMSNorm):
         create_scale: bool = True,
         sharding: Optional[P] = None,
     ):
-        super().__init___(axis, eps, create_scale=create_scale, name=name)
+        super().__init__(axis, eps, create_scale=create_scale, name=name)
         self.sharding = sharding
 
     def __call__(self, inputs: jax.Array):
@@ -621,20 +634,20 @@ def rotate_half(
 class RotaryEmbedding(hk.Module):
     """Applies rotary embeddings (RoPE) to the input sequence tensor,
-    as described in io/abs/2104.09864.
+    as described in https://arxiv.org/abs/2104.09864.
 
     Attributes:
         dim (int): Dimensionality of the feature vectors
         base_exponent (int): Base exponent to compute embeddings from
     """
 
-    def __init___(
+    def __init__(
         self,
         dim: int,
         name: Optional[str] = None,
         base_exponent: int = 10000,
     ):
-        super().__init___(name)
+        super().__init__(name)
         self.dim = dim
         self.base_exponent = base_exponent
 
         assert self.dim % 2 == 0
@@ -679,7 +692,7 @@ class RotaryEmbedding(hk.Module):
 class MultiHeadAttention(hk.Module):
-    def __init___(
+    def __init__(
         self,
         num_q_heads: int,
         num_kv_heads: int,
@@ -693,7 +706,7 @@ class MultiHeadAttention(hk.Module):
         model_axis: Union[str, Tuple[str, ...]] = "model",
         name: Optional[str] = None,
     ):
-        super().__init___(name=name)
+        super().__init__(name=name)
         self.num_q_heads = num_q_heads
         self.num_kv_heads = num_kv_heads
         self.key_size = key_size
@@ -1097,14 +1110,14 @@ class LanguageModelOutput(NamedTuple):
 class InOutEmbed(hk.Embed):
     """Module for embedding tokens in a low-dimensional space."""
 
-    def __init___(
+    def __init__(
         self,
         vocab_size: Optional[int] = None,
         embed_dim: Optional[int] = None,
         sharding: Optional[P] = None,
         name: Optional[str] = None,
     ):
-        super().__init___(
+        super().__init__(
             vocab_size=vocab_size,
             embed_dim=embed_dim,
             name=name,
@@ -1383,4 +1396,3 @@ class Transformer(hk.Module):
             embeddings=h,
             memory=Memory(layers=kv_memories),
         )
-
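
Why the fix matters: __init___ (three trailing underscores) is an ordinary method name rather than Python's constructor hook, so it is never invoked when an object is built, and super().__init___(...) raises an AttributeError because no base class defines it. Below is a minimal sketch of the failure mode, using a hypothetical Broken class rather than the Haiku modules in the diff:

    class Broken:
        def __init___(self, name):  # typo: three trailing underscores
            self.name = name

    b = Broken()               # no error: object.__init__ runs instead
    print(hasattr(b, "name"))  # False; the typo'd method never ran

    try:
        Broken("x")            # the intended constructor call now fails
    except TypeError as err:
        print(err)             # "Broken() takes no arguments"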