mirror of
https://github.com/xai-org/grok-1.git
synced 2024-12-26 03:29:52 +03:00
Added PyTorch inference code.
This commit is contained in:
parent
7050ed204b
commit
bd65964f1a
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
checkpoints/*
|
||||
!checkpoints/README.md
|
||||
pytorch/grok-1-*
|
||||
|
67
pytorch/configuration_grok_1.py
Normal file
67
pytorch/configuration_grok_1.py
Normal file
@ -0,0 +1,67 @@
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
# Copied from huggingface/transformers configuration_mixtral.py.
|
||||
# Modified to default values provided by xai-org/grok-1 run.py.
|
||||
class Grok1Config(PretrainedConfig):
|
||||
model_type = "grok-1"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=131072,
|
||||
max_position_embeddings=8192,
|
||||
output_multiplier_scale=0.5773502691896257,
|
||||
embedding_multiplier_scale=78.38367176906169,
|
||||
hidden_size=6144,
|
||||
intermediate_size=32768,
|
||||
num_hidden_layers=64,
|
||||
num_attention_heads=48,
|
||||
attn_output_multiplier=0.08838834764831845,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="gelu",
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-5,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
bos_token_id=2,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=int(1e4),
|
||||
attention_dropout=0.0,
|
||||
num_experts_per_tok=2,
|
||||
num_local_experts=8,
|
||||
output_router_logits=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.output_multiplier_scale = output_multiplier_scale,
|
||||
self.embedding_multiplier_scale = embedding_multiplier_scale
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attn_output_multiplier = attn_output_multiplier
|
||||
|
||||
# For backward compatibility.
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.num_local_experts = num_local_experts
|
||||
self.output_router_logits = output_router_logits
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
401
pytorch/modeling_grok_1.py
Normal file
401
pytorch/modeling_grok_1.py
Normal file
@ -0,0 +1,401 @@
|
||||
from configuration_grok_1 import Grok1Config
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.activations import ACT2FN
|
||||
from typing import Optional, NamedTuple, List
|
||||
|
||||
class TiedWeightEmbedding(nn.Embedding):
|
||||
"""Module for tied weight embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Grok1Config,
|
||||
):
|
||||
super().__init__(
|
||||
num_embeddings=config.vocab_size,
|
||||
embedding_dim=config.hidden_size,
|
||||
padding_idx=config.pad_token_id,
|
||||
)
|
||||
|
||||
def decode(
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return torch.matmul(inputs, self.weight.T)
|
||||
|
||||
class Gating(nn.Module):
|
||||
"""Gating module for spare MoE expert selection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Grok1Config,
|
||||
):
|
||||
super().__init__()
|
||||
self.router_weights = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
routing_logits = self.router_weights(inputs)
|
||||
routing_probs = F.softmax(routing_logits, dim=-1, dtype=torch.float32)
|
||||
|
||||
if padding_mask is not None:
|
||||
# [batch * seq, expert]
|
||||
routing_probs = routing_probs * padding_mask.view(-1).unsqueeze(-1)
|
||||
|
||||
# Note routing_probs is using float32.
|
||||
return routing_probs, routing_logits
|
||||
|
||||
class MLPExpert(nn.Module):
|
||||
"""MLP expert module for sparse MoE."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Grok1Config,
|
||||
):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
||||
self.v = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
h_w1 = self.act_fn(self.w1(inputs))
|
||||
h_v = self.v(inputs)
|
||||
h_dense = self.dense(h_w1 * h_v)
|
||||
return h_dense
|
||||
|
||||
class SparseMoEMLP(nn.Module):
|
||||
"""Sparse MoE MLP module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Grok1Config,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_local_experts
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.gating = Gating(config)
|
||||
self.experts = nn.ModuleList([MLPExpert(config) for _ in range(self.num_experts)])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# Get routing probabilities and selected experts.
|
||||
routing_probs, routing_logits = self.gating(hidden_states, padding_mask)
|
||||
routing_probs, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
|
||||
routing_probs = routing_probs / routing_probs.sum(dim=-1, keepdim=True)
|
||||
# Now routing_probs is using the hidden_states' dtype instead of float32.
|
||||
routing_probs = routing_probs.to(hidden_states.dtype)
|
||||
|
||||
# Initialize output hidden states.
|
||||
final_hidden_states = torch.zeros(
|
||||
(batch_size * sequence_length, hidden_dim),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
# Create expert mask.
|
||||
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
||||
|
||||
# Loop over experts and compute their contributions.
|
||||
for expert_idx in range(self.num_experts):
|
||||
expert_layer = self.experts[expert_idx]
|
||||
idx, top_x = torch.where(expert_mask[expert_idx])
|
||||
|
||||
if top_x.shape[0] == 0:
|
||||
continue
|
||||
|
||||
top_x_list = top_x.tolist()
|
||||
idx_list = idx.tolist()
|
||||
|
||||
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
|
||||
current_hidden_states = expert_layer(current_state) * routing_probs[top_x_list, idx_list, None]
|
||||
|
||||
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
||||
|
||||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return final_hidden_states, routing_logits
|
||||
|
||||
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
class RotaryPositionalEmbedding(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
base_exponent: int = int(1e4),
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.base_exponent = base_exponent
|
||||
assert self.dim % 2 == 0, "Embedding dimension must be even for rotary embeddings."
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
seq_dim: int,
|
||||
offset: torch.Tensor,
|
||||
const_position: Optional[int] = None,
|
||||
t: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# Compute the per-dimension frequencies.
|
||||
dtype = x.dtype
|
||||
exponents = torch.arange(0, self.dim, 2, dtype=torch.float32, device=x.device)
|
||||
inv_freq = (self.base_exponent ** (exponents / self.dim)).reciprocal()
|
||||
|
||||
if not isinstance(offset, torch.Tensor):
|
||||
offset = torch.tensor(offset, dtype=torch.float32, device=x.device)
|
||||
if offset.dim() == 0:
|
||||
# Offset can be a scalar or one offset per batch element.
|
||||
offset = offset.unsqueeze(0)
|
||||
|
||||
# Compute the per-element phase (to pass into sin and cos).
|
||||
if const_position is not None:
|
||||
t = const_position * torch.ones(
|
||||
(1, x.shape[seq_dim]),
|
||||
dtype=torch.float32,
|
||||
device=x.device,
|
||||
)
|
||||
elif t is None:
|
||||
t = torch.arange(x.shape[seq_dim], dtype=torch.float32, device=x.device)
|
||||
t = t.unsqueeze(0) + offset.unsqueeze(-1)
|
||||
|
||||
phase = torch.einsum("bi,j->bij", t, inv_freq)
|
||||
phase = torch.cat([phase, phase], dim=-1)[:, :, None, :]
|
||||
|
||||
x_rotated = x * phase.cos() + rotate_half(x) * phase.sin()
|
||||
|
||||
return x_rotated.to(dtype)
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-5):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class KVMemory(NamedTuple):
|
||||
k: Optional[torch.Tensor]
|
||||
v: Optional[torch.Tensor]
|
||||
step: Optional[torch.Tensor]
|
||||
|
||||
def init_layer_memories(
|
||||
batch_size: int,
|
||||
sequence_len: int,
|
||||
num_kv_heads: int,
|
||||
key_size: int,
|
||||
num_layers: int,
|
||||
step: Optional[torch.Tensor] = None,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
device: str = "cpu",
|
||||
):
|
||||
if step is None:
|
||||
step = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
return [
|
||||
KVMemory(
|
||||
k=torch.zeros(batch_size, sequence_len, num_kv_heads, key_size, dtype=dtype, device=device),
|
||||
v=torch.zeros(batch_size, sequence_len, num_kv_heads, key_size, dtype=dtype, device=device),
|
||||
step=step,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: Grok1Config):
|
||||
super().__init__()
|
||||
self.num_q_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.key_size = config.hidden_size // config.num_attention_heads
|
||||
self.value_size = self.key_size
|
||||
self.attn_output_multiplier = config.attn_output_multiplier
|
||||
|
||||
self.q_proj = nn.Linear(config.hidden_size, self.num_q_heads * self.key_size, bias=False)
|
||||
self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.key_size, bias=False)
|
||||
self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.value_size, bias=False)
|
||||
self.out_proj = nn.Linear(self.num_q_heads * self.value_size, config.hidden_size, bias=False)
|
||||
self.rotary_pos_emb = RotaryPositionalEmbedding(self.key_size, base_exponent=config.rope_theta)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
layer_memory: Optional[KVMemory] = None,
|
||||
):
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
|
||||
query = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_q_heads, self.key_size)
|
||||
key = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.key_size)
|
||||
value = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.value_size)
|
||||
|
||||
query = self.rotary_pos_emb(query, seq_dim=1, offset=layer_memory.step if layer_memory else 0)
|
||||
key = self.rotary_pos_emb(key, seq_dim=1, offset=layer_memory.step if layer_memory else 0)
|
||||
|
||||
if layer_memory:
|
||||
key = torch.cat([layer_memory.k, key], dim=1)
|
||||
value = torch.cat([layer_memory.v, value], dim=1)
|
||||
new_step = layer_memory.step + seq_len
|
||||
memory_mask = torch.arange(key.shape[1], device=key.device) < new_step[:, None]
|
||||
memory_mask = memory_mask[:, None, None, :]
|
||||
if mask is not None:
|
||||
mask = mask * memory_mask
|
||||
else:
|
||||
mask = memory_mask
|
||||
new_memory = KVMemory(k=key, v=value, step=new_step)
|
||||
else:
|
||||
new_memory = None
|
||||
|
||||
query = query.view(batch_size, seq_len, self.num_kv_heads, self.num_q_heads // self.num_kv_heads, self.key_size)
|
||||
attn_logits = torch.einsum("...thHd,...Thd->...hHtT", query, key).to(torch.float32)
|
||||
attn_logits *= self.attn_output_multiplier
|
||||
max_attn_val = torch.tensor(30.0, dtype=attn_logits.dtype, device=attn_logits.device)
|
||||
attn_logits = max_attn_val * torch.tanh(attn_logits / max_attn_val)
|
||||
|
||||
if mask is not None:
|
||||
mask = mask[:, :, None, :, :]
|
||||
attn_logits = torch.where(mask, attn_logits, torch.full_like(attn_logits, float("-inf")))
|
||||
|
||||
attn_weights = F.softmax(attn_logits, dim=-1).to(query.dtype)
|
||||
|
||||
attn = torch.einsum("...hHtT,...Thd->...thHd", attn_weights, value)
|
||||
attn = attn.reshape(batch_size, seq_len, -1)
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn, new_memory
|
||||
|
||||
class Decoder(nn.Module):
|
||||
|
||||
def __init__(self, config: Grok1Config):
|
||||
super().__init__()
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.attention = MultiHeadAttention(config)
|
||||
self.norm1 = RMSNorm(config.hidden_size)
|
||||
self.norm2 = RMSNorm(config.hidden_size)
|
||||
self.norm3 = RMSNorm(config.hidden_size)
|
||||
self.norm4 = RMSNorm(config.hidden_size)
|
||||
|
||||
if config.num_local_experts > 1:
|
||||
self.mlp = SparseMoEMLP(config)
|
||||
else:
|
||||
self.mlp = MLPExpert(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
layer_memory: Optional[KVMemory] = None,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
attn_output, new_memory = self.attention(hidden_states, mask, layer_memory)
|
||||
attn_output = self.norm2(attn_output)
|
||||
hidden_states = residual + attn_output
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm3(hidden_states)
|
||||
if isinstance(self.mlp, SparseMoEMLP):
|
||||
mlp_output, routing_logits = self.mlp(hidden_states, padding_mask)
|
||||
else:
|
||||
mlp_output = self.mlp(hidden_states)
|
||||
routing_logits = None
|
||||
mlp_output = self.norm4(mlp_output)
|
||||
hidden_states = residual + mlp_output
|
||||
|
||||
return hidden_states, new_memory, routing_logits
|
||||
|
||||
class Grok1PreTrainedModel(PreTrainedModel):
|
||||
config_class = Grok1Config
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Decoder"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = False
|
||||
_supports_sdpa = False
|
||||
_supports_cache_class = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
class Grok1ForCausalLM(Grok1PreTrainedModel):
|
||||
|
||||
def __init__(self, config: Grok1Config):
|
||||
super().__init__(config)
|
||||
self.embedding = TiedWeightEmbedding(self.config)
|
||||
self.layers = nn.ModuleList([Decoder(self.config) for _ in range(self.config.num_hidden_layers)])
|
||||
self.norm = RMSNorm(self.config.hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
memory: Optional[List[KVMemory]] = None,
|
||||
last_hid_only: bool = False,
|
||||
length: Optional[torch.Tensor] = None,
|
||||
):
|
||||
batch_size, seq_len = input_ids.shape
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
||||
else:
|
||||
attention_mask = attention_mask.bool()
|
||||
|
||||
padding_mask = attention_mask.view(batch_size, seq_len)
|
||||
causal_mask = torch.tril(torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool, device=input_ids.device))
|
||||
mask = padding_mask[:, None, None, :] * causal_mask
|
||||
|
||||
hidden_states = self.embedding(input_ids) * self.config.embedding_multiplier_scale
|
||||
kv_memories = []
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
layer_memory = memory[i] if memory else None
|
||||
hidden_states, new_memory, routing_logits = layer(
|
||||
hidden_states,
|
||||
mask,
|
||||
padding_mask,
|
||||
layer_memory,
|
||||
)
|
||||
kv_memories.append(new_memory)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if last_hid_only:
|
||||
last_step = torch.maximum(torch.sum(padding_mask, dim=1) - 1, torch.tensor(0, device=input_ids.device))
|
||||
hidden_states = hidden_states[torch.arange(batch_size, device=input_ids.device), last_step]
|
||||
elif length is not None:
|
||||
last_step = torch.maximum(length - 1, torch.tensor(0, device=input_ids.device))
|
||||
hidden_states = hidden_states[torch.arange(batch_size, device=input_ids.device), last_step]
|
||||
hidden_states = hidden_states.unsqueeze(1)
|
||||
|
||||
logits = self.embedding.decode(hidden_states) * torch.tensor(self.config.output_multiplier_scale, dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
|
||||
return logits, kv_memories
|
Loading…
Reference in New Issue
Block a user