mirror of
https://github.com/xai-org/grok-1.git
synced 2024-11-24 12:39:54 +03:00
310 lines
10 KiB
Python
310 lines
10 KiB
Python
# Copyright 2024 X.AI Corp.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import bisect
|
|
import functools
|
|
import logging
|
|
import math
|
|
import numpy as np
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import haiku as hk
|
|
import sentencepiece
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, NamedTuple, Optional, Tuple
|
|
from jax.experimental import pjit
|
|
from jax.sharding import PartitionSpec as P
|
|
from model import (
|
|
LanguageModelConfig,
|
|
LanguageModelOutput,
|
|
TrainingState,
|
|
apply_rules,
|
|
Memory,
|
|
KVMemory,
|
|
)
|
|
import checkpoint as xai_checkpoint
|
|
|
|
# `logger` belongs to this module; `rank_logger` is the shared "rank" logger that the
# runner classes below report progress through. Both are pinned to INFO.
logger = logging.getLogger(__name__)
rank_logger = logging.getLogger("rank")
logger.setLevel(logging.INFO)
rank_logger.setLevel(logging.INFO)

# Number of highest-probability candidates `sample_token` reports per sampled position.
TOP_K = 8

class SampleSettings(NamedTuple):
    """Sampling controls consumed by `sample_token`.

    NOTE(review): fields are presumably batched along a leading axis
    (temperature / nucleus_p / active: (batch,), mask: (batch, vocab)) -- confirm
    against the callers.
    """

    # Logits are divided by this before sampling.
    temperature: jax.Array
    # Nucleus (top-p) probability mass handed to `top_p_filter`.
    nucleus_p: jax.Array
    # Boolean; where False, the corresponding logit is replaced by -1e10.
    mask: jax.Array
    # Carried along unchanged; `sample_token` does not read it.
    active: jax.Array

class SampleOutput(NamedTuple):
    """Result of one `sample_token` call."""

    # Sampled token id(s).
    token_id: jax.Array
    # Probability the filtered distribution assigned to `token_id`.
    prob: jax.Array
    # Ids of the TOP_K most probable tokens (as returned by `jax.lax.top_k`).
    top_k_token_ids: jax.Array
    # Probabilities matching `top_k_token_ids`.
    top_k_probs: jax.Array

def insert_slice(memory: Memory, slice: Memory, length: int, i: int) -> Memory:
    """Write a single-entry memory `slice` into position `i` of `memory`.

    Every layer of `slice` is rebuilt with its ``step`` field set to
    ``jnp.array([length])``. Then, for each matching leaf pair, ``src[0]`` is
    written at index `i` along axis 0 of the destination leaf.

    Args:
      memory: destination memory; its leaves are updated along axis 0.
      slice: source memory; only entry 0 along axis 0 of each leaf is used.
        (The name shadows the builtin but is kept for keyword-argument callers.)
      length: value stored in every layer's ``step`` field.
      i: destination index along axis 0.

    Returns:
      A new Memory with the row written; the input pytrees are not mutated.
    """
    stepped = Memory(
        layers=[
            KVMemory(layer.k, layer.v, step=jnp.array([length]))
            for layer in slice.layers
        ],
    )

    def _write_row(dst, src):
        # Drop the source's leading singleton entry and write it at row `i`.
        return jax.lax.dynamic_update_index_in_dim(dst, src[0], i, axis=0)

    # `jax.tree_map` is deprecated and removed in recent JAX releases;
    # `jax.tree_util.tree_map` is the stable equivalent.
    return jax.tree_util.tree_map(_write_row, memory, stepped)

def pad_to_size(x: jnp.ndarray, size: int) -> jnp.ndarray:
    """Return `x` truncated or zero-padded to exactly `size` entries along axis 0.

    If `x` has more than `size` entries along axis 0, only the last `size` are kept.
    Otherwise zeros are appended at the end of axis 0. Other axes are left
    untouched. The result is a NumPy array, because ``np.pad`` is used.

    Raises:
      ValueError: if `size` is negative.
    """
    if size < 0:
        raise ValueError(f"size must be non-negative, got {size}")
    n = x.shape[0]
    if n > size:
        # Keep the trailing `size` rows. Slicing from `n - size` rather than `-size`
        # also handles size == 0, where `x[-0:]` would keep the whole array.
        x = x[n - size:]
    # Pad only axis 0. A flat `[0, k]` spec would pad every axis of an N-D input.
    pad_width = [(0, size - x.shape[0])] + [(0, 0)] * (x.ndim - 1)
    return np.pad(x, pad_width, mode="constant", constant_values=0)

def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:
    """Nucleus filtering: keep only the largest logits whose softmax mass reaches `top_p`.

    Logits are sorted ascending along the last axis and their softmax is accumulated.
    The first sorted position whose cumulative mass is >= ``1 - top_p`` supplies the
    threshold logit. Entries strictly below it are replaced by -1e10; entries equal
    to it are kept.

    Args:
      logits: array whose last axis is the one being filtered.
      top_p: same rank as `logits`, broadcastable against it (e.g. trailing dims 1).

    Returns:
      `logits` with filtered-out entries set to -1e10.
    """
    assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}"
    sorted_logits = jax.lax.sort(logits, is_stable=False)
    sorted_probs = jax.nn.softmax(sorted_logits)
    reached = jnp.cumsum(sorted_probs, -1) >= 1 - top_p
    # The final cumulative sum should be 1, but rounding can leave it slightly below.
    # With top_p ~ 0, no position would then pass. argmax of an all-False row
    # returns 0, which keeps *every* token. Always count the last (largest-logit)
    # position as reached, so that case keeps just the top token.
    reached = reached.at[..., -1].set(True)
    threshold_idx = jnp.argmax(reached, axis=-1)
    threshold_largest_logits = jnp.take_along_axis(
        sorted_logits, threshold_idx[..., jnp.newaxis], axis=-1
    )
    assert threshold_largest_logits.shape == logits.shape[:-1] + (1,)
    mask = logits >= threshold_largest_logits
    return jnp.where(mask, logits, -1e10)

def sample_token(
    rngs: jax.random.PRNGKey,
    lm_outputs: LanguageModelOutput,
    settings: SampleSettings,
) -> SampleOutput:
    """Sample one token per row from temperature-scaled, masked, nucleus-filtered logits.

    Args:
      rngs: PRNG keys. They are mapped with `jax.vmap` alongside the logits' leading
        axis, so one key per row is expected.
      lm_outputs: only `.logits` is read. NOTE(review): the indexing below
        (take_along_axis on axis 2, squeeze of axis 1) implies a
        (batch, 1, vocab) layout -- confirm with callers.
      settings: sampling controls; `active` is not read here.

    Returns:
      SampleOutput holding the sampled ids, their probabilities, and the TOP_K most
      probable ids and probabilities of the filtered distribution.
    """
    raw_logits = lm_outputs.logits

    # Lift the per-row controls so they broadcast against the logits.
    temperature = jnp.expand_dims(settings.temperature, (1, 2)).astype(raw_logits.dtype)
    allowed = jnp.expand_dims(settings.mask, 1)

    scaled = raw_logits / temperature
    masked = jnp.where(allowed, scaled, -1e10)
    nucleus = jnp.expand_dims(settings.nucleus_p, (1, 2)).astype(masked.dtype)
    filtered = top_p_filter(masked, nucleus)

    sampler = jax.vmap(jax.random.categorical)
    chosen_ids = sampler(rngs, filtered)
    probs = jax.nn.softmax(filtered)

    chosen_probs = jnp.take_along_axis(probs, jnp.expand_dims(chosen_ids, 1), axis=2)
    best_probs, best_ids = jax.lax.top_k(probs, TOP_K)

    return SampleOutput(
        token_id=chosen_ids,
        prob=jnp.squeeze(chosen_probs, 1),
        top_k_token_ids=jnp.squeeze(best_ids, 1),
        top_k_probs=jnp.squeeze(best_probs, 1),
    )

def make_mesh(
    local_mesh_config: Tuple[int, int], between_hosts_config: Tuple[int, int]
) -> jax.sharding.Mesh:
    """Build the 2-D device mesh that `ModelRunner.initialize` runs under.

    `ModelRunner.initialize` calls `make_mesh`, but this module neither defined nor
    imported it, so initialization raised NameError. The mesh is built with
    `jax.experimental.mesh_utils.create_hybrid_device_mesh`, treating each process
    as one granule.

    Args:
      local_mesh_config: mesh shape inside one process (exactly 2 entries).
      between_hosts_config: mesh shape across processes (exactly 2 entries).

    Raises:
      ValueError: if either config does not have exactly two entries.
    """
    from jax.experimental import mesh_utils  # only needed to build the mesh

    if len(local_mesh_config) != 2 or len(between_hosts_config) != 2:
        raise ValueError(
            f"Expected 2-D mesh configs, got {local_mesh_config=} {between_hosts_config=}"
        )
    rank_logger.info(f"Detected {jax.device_count()} devices in mesh")
    device_mesh = mesh_utils.create_hybrid_device_mesh(
        local_mesh_config,
        between_hosts_config,
        devices=jax.devices(),
        process_is_granule=True,
    )
    # NOTE(review): the axis names must match those used by the model's partition
    # rules (defined in model.py, not visible here) -- confirm "data"/"model".
    return jax.sharding.Mesh(device_mesh, ("data", "model"))


@dataclass
class ModelRunner:
    """Holds a language-model config plus the mesh/sharding plumbing to init or restore it.

    Call `initialize(...)` first. It sets `mesh`, the forward functions and, when
    `transform_forward` is True, `state_sharding` and `init_fn`. After that,
    `load_or_init(...)` can produce a state.
    """

    model: LanguageModelConfig
    # Multiplied by the local device count and replica count to get `batch_size`.
    bs_per_device: float = 2.0
    # NOTE(review): the two load_* rule lists are not read anywhere in this file.
    load_rename_rules: Optional[list[tuple[str, str]]] = None
    load_exclude_rules: Optional[list[str]] = None
    # Seed for the PRNGKey used by get_state_sharding and load_or_init.
    rng_seed: int = 42
    # When True, make_forward_fn wraps the forward function with hk.transform.
    transform_forward: bool = False
    # Empty string means load_or_init always initializes fresh parameters.
    checkpoint_path: str = ""

    def make_forward_fn(self, mesh: Any, transform: Optional[bool] = None):
        """Return the model forward function built on `mesh`.

        The returned function maps `tokens` to ``(model_output, None)``. It is wrapped
        with `hk.transform` when `transform` is True. When `transform` is None (the
        default), it falls back to `self.transform_forward`.
        """

        def forward(tokens):
            out = self.model.make(mesh=mesh)(tokens)
            return out, None

        should_transform = self.transform_forward if transform is None else transform
        if should_transform:
            forward = hk.transform(forward)
        return forward

    def initialize(
        self,
        init_data,
        local_mesh_config: Tuple[int, int],
        between_hosts_config: Tuple[int, int],
    ):
        """Set up batch sizes, the device mesh, forward functions and (optionally) sharding.

        Args:
          init_data: example inputs; used only when `transform_forward` is True, to
            derive `state_sharding` via `get_state_sharding`.
          local_mesh_config: per-process mesh shape passed to `make_mesh`.
          between_hosts_config: cross-process mesh shape. Its product is the replica
            count used for `batch_size`.
        """
        num_replicas = math.prod(between_hosts_config)
        self.model.initialize()
        self.model.fprop_dtype = jnp.bfloat16
        num_local_gpus = len(jax.local_devices())
        self.batch_size = int(self.bs_per_device * num_local_gpus * num_replicas)
        self.local_batch_size = self.batch_size // jax.process_count()

        self.local_mesh_config = local_mesh_config
        self.between_hosts_config = between_hosts_config
        rank_logger.info(
            f"Initializing mesh for {self.local_mesh_config=} {self.between_hosts_config=}..."
        )
        self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)

        self.forward = self.make_forward_fn(mesh=self.mesh)
        self.eval_forward = self.make_forward_fn(mesh=self.mesh)
        # The logits_* wrappers *call* the forward function inside hk.transform, so
        # they need the plain callable. An hk.Transformed (what self.forward is when
        # transform_forward=True) is not callable.
        raw_forward = self.make_forward_fn(mesh=self.mesh, transform=False)
        raw_eval_forward = self.make_forward_fn(mesh=self.mesh, transform=False)
        self.logits_fn = hk.transform(lambda tokens: raw_forward(tokens)[0])
        self.logits_eval_fn = hk.transform(lambda tokens: raw_eval_forward(tokens)[0])

        if self.transform_forward:
            self.state_sharding = self.get_state_sharding(init_data)
            rank_logger.info(f"State sharding type: {type(self.state_sharding)}")
            self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding)

    def init(self, rng: jax.Array, data) -> TrainingState:
        """Initialize parameters from `data["inputs"]`; requires transform_forward."""
        assert self.transform_forward
        rng, init_rng = jax.random.split(rng)
        params = self.forward.init(init_rng, data["inputs"])
        return TrainingState(params=params)

    def get_state_sharding(self, init_data):
        """Map the abstract training-state shapes to shardings via the model's partition rules."""
        assert self.transform_forward
        rng = jax.random.PRNGKey(self.rng_seed)

        with self.mesh:
            rules = self.model.partition_rules()
            # Log the rules themselves; the original logged the bound method object.
            rank_logger.info(f"partition rules: {rules}")
            shapes = jax.eval_shape(self.init, rng, init_data)
            sharding = jax.tree_util.tree_map_with_path(apply_rules(rules), shapes)
        return sharding

    def load_or_init(
        self,
        init_data: Any,
        from_checkpoint: bool = True,
        init_fn: Optional[Callable] = None,
    ):
        """Return a freshly initialized state or one restored from `checkpoint_path`.

        A fresh state is built when `checkpoint_path` is empty or `from_checkpoint`
        is False. It uses `init_fn` if given, else the pjit-ed `self.init_fn`.
        Otherwise the state shapes are traced with `jax.eval_shape` and restored
        through `xai_checkpoint.restore` with ``params_only=True``.
        """
        rng = jax.random.PRNGKey(self.rng_seed)

        if not self.checkpoint_path or not from_checkpoint:
            rank_logger.info("Initializing model...")
            with self.mesh:
                if init_fn is not None:
                    state = init_fn(rng, init_data)
                else:
                    assert self.transform_forward
                    state = self.init_fn(rng, init_data)
            rank_logger.info("Model state is newly initialized.")
        else:
            with self.mesh:
                if init_fn is not None:
                    state_shapes = jax.eval_shape(init_fn, rng, init_data)
                else:
                    assert self.transform_forward
                    state_shapes = jax.eval_shape(self.init_fn, rng, init_data)
            init_state = None

            state = xai_checkpoint.restore(
                checkpoint_path=self.checkpoint_path,
                state_shapes=state_shapes,
                mesh=self.mesh,
                between_hosts_config=self.between_hosts_config,
                state_sharding=self.state_sharding,
                init_state=init_state,
                params_only=True,
            )
            del init_state
        return state

@dataclass
class Request:
    """One generation request handled by `InferenceRunner.predict`."""

    # Text to continue; tokenized before generation.
    prompt: str
    # Sampling temperature (logits are divided by it).
    temperature: float
    # Nucleus (top-p) probability mass.
    nucleus_p: float
    # Seed for `jax.random.PRNGKey`.
    rng_seed: int
    # Upper bound on the number of sampling steps.
    max_len: int

@dataclass
class InferenceRunner:
    """Single-prompt text generation on top of a `ModelRunner`.

    Call `initialize()` once before `predict()`.
    """

    # Free-form label; not read by this class.
    name: str
    # Owns the mesh, forward functions and parameters.
    runner: ModelRunner
    # NOTE(review): not read anywhere in this file -- confirm it is still needed.
    load: str
    # SentencePiece model file loaded by initialize().
    tokenizer_path: str = "/tmp/xai_data/tokenizer.model"
    local_mesh_config: Tuple[int, int] = (1, 1)
    between_hosts_config: Tuple[int, int] = (1, 1)
    # Must be sorted ascending: get_pad_bucket bisects it.
    pad_sizes: Tuple[int, ...] = (1024,)

    def get_pad_bucket(self, size: int) -> int:
        """Smallest entry of `pad_sizes` that is >= `size`, or the largest entry if none is."""
        i = bisect.bisect_left(self.pad_sizes, size)
        return self.pad_sizes[min(i, len(self.pad_sizes) - 1)]

    def initialize(self):
        """Set up the model runner, obtain parameters, and load the tokenizer."""
        runner = self.runner
        runner.transform_forward = True
        dummy_data = dict(
            inputs=np.zeros((1, self.get_pad_bucket(512)), dtype=np.int32),
        )
        # Builds runner.mesh / runner.eval_forward / runner.init_fn, all of which
        # load_or_init and predict depend on. The original never called this.
        runner.initialize(
            dummy_data,
            local_mesh_config=self.local_mesh_config,
            between_hosts_config=self.between_hosts_config,
        )
        # NOTE(review): from_checkpoint=False means parameters are freshly
        # initialized even when runner.checkpoint_path is set -- confirm intent.
        state = runner.load_or_init(
            dummy_data,
            from_checkpoint=False,
        )
        runner.params = state.params
        self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=self.tokenizer_path)

        def text_to_token_ids(text):
            return self.tokenizer.encode(text, out_type=int)

        self.text_to_token_ids = text_to_token_ids

    def predict(self, request: Request) -> str:
        """Sample up to `request.max_len` tokens after the prompt and decode the result.

        Each step re-runs the eval forward pass over the whole sequence so far (no
        KV cache). It samples from the last position's logits and appends the token.
        Generation stops early once token id 0, or the tokenizer's EOS id, is drawn.

        Returns:
          The decoded prompt followed by the generated continuation.

        Raises:
          ValueError: if the prompt encodes to zero tokens.
        """
        runner = self.runner
        prompt_ids = self.text_to_token_ids(request.prompt)
        if not prompt_ids:
            raise ValueError("prompt must encode to at least one token")
        token_ids = jnp.array(np.array(prompt_ids, dtype=np.int32)[np.newaxis, :])
        batch_size = token_ids.shape[0]

        # Token 0 was the original stop id. EOS is added when the tokenizer defines
        # one (SentencePiece reports -1 otherwise).
        stop_ids = {0}
        eos_id = self.tokenizer.eos_id()
        if eos_id >= 0:
            stop_ids.add(eos_id)

        rng = jax.random.PRNGKey(request.rng_seed)
        settings = None
        for _ in range(request.max_len):
            # Fresh keys each step; the original reused a single key for every draw.
            rng, forward_rng, sample_rng = jax.random.split(rng, 3)
            # eval_forward is an hk.Transformed here (transform_forward=True), so it
            # must be applied with params rather than called directly.
            with runner.mesh:
                lm_outputs, _ = runner.eval_forward.apply(
                    runner.params, forward_rng, token_ids
                )
            # NOTE(review): logits assumed (batch, seq, vocab). Keep only the last
            # position as (batch, 1, vocab), the layout sample_token indexes into.
            last_logits = lm_outputs.logits[:, -1:, :]
            if settings is None:
                vocab_size = last_logits.shape[-1]
                settings = SampleSettings(
                    temperature=jnp.full((batch_size,), request.temperature),
                    nucleus_p=jnp.full((batch_size,), request.nucleus_p),
                    # The mask is applied across the vocabulary axis, so it is
                    # (batch, vocab) -- not the shape of the prompt.
                    mask=jnp.ones((batch_size, vocab_size), dtype=bool),
                    active=jnp.ones((batch_size,), dtype=bool),
                )
            sample_output = sample_token(
                # sample_token vmaps categorical over keys: one key per batch row.
                jax.random.split(sample_rng, batch_size),
                # NOTE(review): LanguageModelOutput assumed to be a NamedTuple.
                lm_outputs._replace(logits=last_logits),
                settings,
            )
            new_token = sample_output.token_id.astype(token_ids.dtype)  # (batch, 1)
            token_ids = jnp.concatenate([token_ids, new_token], axis=-1)
            # Compare the sampled id itself; argmax of a 1x1 array is always 0,
            # which made the original loop stop after one token.
            if int(new_token[0, 0]) in stop_ids:
                break

        return self.tokenizer.decode(np.asarray(token_ids[0]).tolist())

def main():
    """Demo entry point: wire up the runners with placeholder paths and print one completion."""
    demo_request = Request(
        prompt="Sample text",
        temperature=0.7,
        nucleus_p=0.9,
        rng_seed=42,
        max_len=100,
    )

    model_runner = ModelRunner(
        model=LanguageModelConfig(),
        checkpoint_path="path_to_checkpoint",
    )
    inference = InferenceRunner(
        name="inference",
        runner=model_runner,
        load="path_to_load",
        tokenizer_path="path_to_tokenizer_model",
        local_mesh_config=(1, 1),
        between_hosts_config=(1, 1),
    )
    inference.initialize()

    print(inference.predict(demo_request))


if __name__ == "__main__":
    main()