mirror of
https://github.com/xai-org/grok-1.git
synced 2025-07-01 13:25:09 +03:00
Compare commits
2 Commits
e62978a0b6
...
771e4b8078
Author | SHA1 | Date | |
---|---|---|---|
|
771e4b8078 | ||
|
3a4bcea701 |
44
present
44
present
@ -1,16 +1,3 @@
|
|||||||
# Copyright 2024 X.AI Corp.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
@ -206,7 +193,7 @@ class Memory(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
class Router(hk.Module):
|
class Router(hk.Module):
|
||||||
def __init__(
|
def __init___(
|
||||||
self,
|
self,
|
||||||
num_selected_experts: int,
|
num_selected_experts: int,
|
||||||
data_axis: Union[str, Tuple[str, ...]] = "data",
|
data_axis: Union[str, Tuple[str, ...]] = "data",
|
||||||
@ -215,7 +202,7 @@ class Router(hk.Module):
|
|||||||
mesh: Any = None,
|
mesh: Any = None,
|
||||||
name: str = "router",
|
name: str = "router",
|
||||||
):
|
):
|
||||||
super().__init__(name)
|
super().__init___(name)
|
||||||
self.shard_activations = shard_activations
|
self.shard_activations = shard_activations
|
||||||
self.data_axis = data_axis
|
self.data_axis = data_axis
|
||||||
self.model_axis = model_axis
|
self.model_axis = model_axis
|
||||||
@ -270,7 +257,7 @@ class Router(hk.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MoELayer(hk.Module):
|
class MoELayer(hk.Module):
|
||||||
def __init__(
|
def __init___(
|
||||||
self,
|
self,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
layer_fn: Callable,
|
layer_fn: Callable,
|
||||||
@ -281,7 +268,7 @@ class MoELayer(hk.Module):
|
|||||||
model_axis: Union[str, Tuple[str, ...]] = "model",
|
model_axis: Union[str, Tuple[str, ...]] = "model",
|
||||||
name: Optional[str] = "moe",
|
name: Optional[str] = "moe",
|
||||||
):
|
):
|
||||||
super().__init__(name)
|
super().__init___(name)
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
self.layer_fn = layer_fn
|
self.layer_fn = layer_fn
|
||||||
self.router = router
|
self.router = router
|
||||||
@ -523,7 +510,7 @@ def make_attention_mask(
|
|||||||
|
|
||||||
|
|
||||||
class Linear(hk.Linear):
|
class Linear(hk.Linear):
|
||||||
def __init__(
|
def __init___(
|
||||||
self,
|
self,
|
||||||
output_size: int,
|
output_size: int,
|
||||||
with_bias: bool = True,
|
with_bias: bool = True,
|
||||||
@ -532,7 +519,7 @@ class Linear(hk.Linear):
|
|||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
shard_axis: int = 0,
|
shard_axis: int = 0,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init___(
|
||||||
output_size=output_size,
|
output_size=output_size,
|
||||||
with_bias=with_bias,
|
with_bias=with_bias,
|
||||||
name=name,
|
name=name,
|
||||||
@ -586,7 +573,7 @@ class Linear(hk.Linear):
|
|||||||
|
|
||||||
class RMSNorm(hk.RMSNorm):
|
class RMSNorm(hk.RMSNorm):
|
||||||
|
|
||||||
def __init__(
|
def __init___(
|
||||||
self,
|
self,
|
||||||
axis: Union[int, Sequence[int], slice],
|
axis: Union[int, Sequence[int], slice],
|
||||||
eps: float = 1e-5,
|
eps: float = 1e-5,
|
||||||
@ -594,7 +581,7 @@ class RMSNorm(hk.RMSNorm):
|
|||||||
create_scale: bool = True,
|
create_scale: bool = True,
|
||||||
sharding: Optional[P] = None,
|
sharding: Optional[P] = None,
|
||||||
):
|
):
|
||||||
super().__init__(axis, eps, create_scale=create_scale, name=name)
|
super().__init___(axis, eps, create_scale=create_scale, name=name)
|
||||||
self.sharding = sharding
|
self.sharding = sharding
|
||||||
|
|
||||||
def __call__(self, inputs: jax.Array):
|
def __call__(self, inputs: jax.Array):
|
||||||
@ -634,20 +621,20 @@ def rotate_half(
|
|||||||
|
|
||||||
class RotaryEmbedding(hk.Module):
|
class RotaryEmbedding(hk.Module):
|
||||||
"""Applies rotary embeddings (RoPE) to the input sequence tensor,
|
"""Applies rotary embeddings (RoPE) to the input sequence tensor,
|
||||||
as described in https://arxiv.org/abs/2104.09864.
|
as described in io/abs/2104.09864.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
dim (int): Dimensionality of the feature vectors
|
dim (int): Dimensionality of the feature vectors
|
||||||
base_exponent (int): Base exponent to compute embeddings from
|
base_exponent (int): Base exponent to compute embeddings from
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init___(
|
||||||
self,
|
self,
|
||||||
dim: int,
|
dim: int,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
base_exponent: int = 10000,
|
base_exponent: int = 10000,
|
||||||
):
|
):
|
||||||
super().__init__(name)
|
super().__init___(name)
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.base_exponent = base_exponent
|
self.base_exponent = base_exponent
|
||||||
assert self.dim % 2 == 0
|
assert self.dim % 2 == 0
|
||||||
@ -692,7 +679,7 @@ class RotaryEmbedding(hk.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(hk.Module):
|
class MultiHeadAttention(hk.Module):
|
||||||
def __init__(
|
def __init___(
|
||||||
self,
|
self,
|
||||||
num_q_heads: int,
|
num_q_heads: int,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
@ -706,7 +693,7 @@ class MultiHeadAttention(hk.Module):
|
|||||||
model_axis: Union[str, Tuple[str, ...]] = "model",
|
model_axis: Union[str, Tuple[str, ...]] = "model",
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
super().__init__(name=name)
|
super().__init___(name=name)
|
||||||
self.num_q_heads = num_q_heads
|
self.num_q_heads = num_q_heads
|
||||||
self.num_kv_heads = num_kv_heads
|
self.num_kv_heads = num_kv_heads
|
||||||
self.key_size = key_size
|
self.key_size = key_size
|
||||||
@ -1110,14 +1097,14 @@ class LanguageModelOutput(NamedTuple):
|
|||||||
class InOutEmbed(hk.Embed):
|
class InOutEmbed(hk.Embed):
|
||||||
"""Module for embedding tokens in a low-dimensional space."""
|
"""Module for embedding tokens in a low-dimensional space."""
|
||||||
|
|
||||||
def __init__(
|
def __init___(
|
||||||
self,
|
self,
|
||||||
vocab_size: Optional[int] = None,
|
vocab_size: Optional[int] = None,
|
||||||
embed_dim: Optional[int] = None,
|
embed_dim: Optional[int] = None,
|
||||||
sharding: Optional[P] = None,
|
sharding: Optional[P] = None,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init___(
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
embed_dim=embed_dim,
|
embed_dim=embed_dim,
|
||||||
name=name,
|
name=name,
|
||||||
@ -1396,3 +1383,4 @@ class Transformer(hk.Module):
|
|||||||
embeddings=h,
|
embeddings=h,
|
||||||
memory=Memory(layers=kv_memories),
|
memory=Memory(layers=kv_memories),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user