# Copyright 2024 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import bisect
import functools
import logging
import math
import numpy as np
import jax
import jax.numpy as jnp
import haiku as hk
import sentencepiece
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple, Optional, Tuple
from jax.experimental import pjit
from jax.sharding import PartitionSpec as P
from model import (
    LanguageModelConfig,
    LanguageModelOutput,
    TrainingState,
    apply_rules,
    Memory,
    KVMemory,
)
import checkpoint as xai_checkpoint

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
rank_logger = logging.getLogger("rank")
rank_logger.setLevel(logging.INFO)

TOP_K = 8


class SampleSettings(NamedTuple):
    temperature: jax.Array
    nucleus_p: jax.Array
    mask: jax.Array
    active: jax.Array


class SampleOutput(NamedTuple):
    token_id: jax.Array
    prob: jax.Array
    top_k_token_ids: jax.Array
    top_k_probs: jax.Array


def insert_slice(memory: Memory, slice: Memory, length: int, i: int) -> Memory:
    # Rebuild the slice with each layer's step counter set to the prompt length.
    slice = Memory(
        layers=[
            KVMemory(layer.k, layer.v, step=jnp.array([length]))
            for layer in slice.layers
        ],
    )
    # Write the batch-1 slice into row i of the batched memory.
    return jax.tree_map(
        lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0),
        memory,
        slice,
    )
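
# Usage sketch (hypothetical shapes): given a batched KV cache whose k/v leaves
# are (batch, seq, ...), this writes a freshly prefilled batch-1 `slice` into
# row `i` and records that `length` positions of it are already populated.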


def pad_to_size(x: jnp.ndarray, size: int) -> jnp.ndarray:
    if x.shape[0] > size:
        # Keep the right-most tokens when the input exceeds the bucket size.
        x = x[-size:]
    return np.pad(x, [0, size - x.shape[0]], mode="constant", constant_values=0)
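
# Padding keeps input shapes within a fixed set of buckets so compilation is
# amortized across requests. With hypothetical values:
#   pad_to_size(np.array([5, 6, 7], dtype=np.int32), 8)
# returns array([5, 6, 7, 0, 0, 0, 0, 0], dtype=int32).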


def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:
    """Performs nucleus (top-p) filtering of logits."""
    assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}"
    sorted_logits = jax.lax.sort(logits, is_stable=False)
    sorted_probs = jax.nn.softmax(sorted_logits)
    threshold_idx = jnp.argmax(jnp.cumsum(sorted_probs, -1) >= 1 - top_p, axis=-1)
    threshold_largest_logits = jnp.take_along_axis(
        sorted_logits, threshold_idx[..., jnp.newaxis], axis=-1
    )
    assert threshold_largest_logits.shape == logits.shape[:-1] + (1,)
    mask = logits >= threshold_largest_logits
    # Set logits outside the nucleus to a large negative value.
    logits = jnp.where(mask, logits, -1e10)
    return logits
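
# Worked example (hypothetical numbers): for one row with probabilities
# [0.1, 0.2, 0.3, 0.4] and top_p = 0.7, the ascending cumulative sum
# [0.1, 0.3, 0.6, 1.0] first reaches 1 - 0.7 = 0.3 at the 0.2-probability
# token, so that token becomes the threshold and only the 0.1 token is
# masked down to -1e10.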


def sample_token(
    rngs: jax.random.PRNGKey,
    lm_outputs: LanguageModelOutput,
    settings: SampleSettings,
) -> SampleOutput:
    # Expand the settings shapes to match the (batch, seq, vocab) logits.
    settings = SampleSettings(
        temperature=jnp.expand_dims(settings.temperature, (1, 2)),
        nucleus_p=jnp.expand_dims(settings.nucleus_p, (1, 2)),
        mask=jnp.expand_dims(settings.mask, 1),
        active=settings.active,
    )
    logits = lm_outputs.logits / settings.temperature.astype(lm_outputs.logits.dtype)
    # Mask out disallowed tokens, then apply nucleus (top-p) filtering.
    logits = jnp.where(settings.mask, logits, -1e10)
    logits = top_p_filter(logits, settings.nucleus_p.astype(logits.dtype))

    new_token = jax.vmap(jax.random.categorical)(rngs, logits)

    probabilities = jax.nn.softmax(logits)
    token_prob = jnp.take_along_axis(probabilities, jnp.expand_dims(new_token, 1), axis=2)
    token_prob = jnp.squeeze(token_prob, 1)

    # Gather the top-k token ids and probabilities alongside the sampled token.
    top_k_probs, top_k_token_ids = jax.lax.top_k(probabilities, TOP_K)
    top_k_probs = jnp.squeeze(top_k_probs, 1)
    top_k_token_ids = jnp.squeeze(top_k_token_ids, 1)
    return SampleOutput(
        new_token,
        token_prob,
        top_k_token_ids,
        top_k_probs,
    )
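
# Shape sketch for a single decode step with batch B and vocabulary V:
# lm_outputs.logits is (B, 1, V); settings.temperature and nucleus_p are (B,);
# settings.mask is a (B, V) vocabulary mask; rngs is a (B,)-batched stack of
# PRNG keys, one per sequence, consumed by jax.vmap(jax.random.categorical).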


@dataclass
class ModelRunner:
    model: LanguageModelConfig
    bs_per_device: float = 2.0
    load_rename_rules: Optional[list[tuple[str, str]]] = None
    load_exclude_rules: Optional[list[str]] = None
    rng_seed: int = 42
    transform_forward: bool = False
    checkpoint_path: str = ""

    def make_forward_fn(self, mesh: Any):
        def forward(tokens):
            out = self.model.make(mesh=mesh)(tokens)
            return out, None

        # Optionally transform the forward pass into a Haiku pure function.
        if self.transform_forward:
            forward = hk.transform(forward)
        return forward
    def initialize(
        self,
        init_data,
        local_mesh_config: Tuple[int, int],
        between_hosts_config: Tuple[int, int],
    ):
        num_replicas = math.prod(between_hosts_config)
        self.model.initialize()
        self.model.fprop_dtype = jnp.bfloat16
        num_local_gpus = len(jax.local_devices())

        # Derive the global batch size from the per-device batch size and the
        # degree of replication across hosts.
        self.batch_size = int(self.bs_per_device * num_local_gpus * num_replicas)
        self.local_batch_size = self.batch_size // jax.process_count()

        self.local_mesh_config = local_mesh_config
        self.between_hosts_config = between_hosts_config
        rank_logger.info(
            f"Initializing mesh for {self.local_mesh_config=} {self.between_hosts_config=}..."
        )
        self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
        self.forward = self.make_forward_fn(mesh=self.mesh)
        self.logits_fn = hk.transform(lambda tokens: self.forward(tokens)[0])

        self.eval_forward = self.make_forward_fn(mesh=self.mesh)
        self.logits_eval_fn = hk.transform(lambda tokens: self.eval_forward(tokens)[0])

        if self.transform_forward:
            self.state_sharding = self.get_state_sharding(init_data)
            rank_logger.info(f"State sharding type: {type(self.state_sharding)}")
            self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding)
    def init(self, rng: jax.Array, data) -> TrainingState:
        assert self.transform_forward
        rng, init_rng = jax.random.split(rng)
        params = self.forward.init(init_rng, data["inputs"])
        return TrainingState(params=params)

    def get_state_sharding(self, init_data):
        assert self.transform_forward
        rng = jax.random.PRNGKey(self.rng_seed)
        rank_logger.info(f"partition rules: {self.model.partition_rules}")
        with self.mesh:
            # Evaluate parameter shapes without allocating, then map the
            # partition rules over the shape tree to get a sharding per leaf.
            shapes = jax.eval_shape(self.init, rng, init_data)
            sharding = jax.tree_util.tree_map_with_path(
                apply_rules(self.model.partition_rules()),
                shapes,
            )
        return sharding
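
    # The partition rules (defined in model.py) pair parameter-path patterns
    # with PartitionSpecs; tree_map_with_path applies them to every leaf of the
    # shape tree, so a hypothetical rule could shard an attention weight as
    # P("data", "model") across the two mesh axes.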
    def load_or_init(
        self,
        init_data: Any,
        from_checkpoint: bool = True,
        init_fn: Optional[Callable] = None,
    ):
        rng = jax.random.PRNGKey(self.rng_seed)

        if not self.checkpoint_path or not from_checkpoint:
            rank_logger.info("Initializing model...")
            with self.mesh:
                if init_fn is not None:
                    state = init_fn(rng, init_data)
                else:
                    assert self.transform_forward
                    state = self.init_fn(rng, init_data)
            rank_logger.info("Model state is newly initialized.")
        else:
            with self.mesh:
                if init_fn:
                    state_shapes = jax.eval_shape(init_fn, rng, init_data)
                else:
                    assert self.transform_forward
                    state_shapes = jax.eval_shape(self.init_fn, rng, init_data)
            init_state = None
            state = xai_checkpoint.restore(
                checkpoint_path=self.checkpoint_path,
                state_shapes=state_shapes,
                mesh=self.mesh,
                between_hosts_config=self.between_hosts_config,
                state_sharding=self.state_sharding,
                init_state=init_state,
                params_only=True,
            )
            del init_state
        return state
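
# Typical flow: ModelRunner.initialize builds the mesh and the pjit'd init_fn,
# after which load_or_init either initializes fresh sharded parameters or
# restores them from checkpoint_path using the sharding from get_state_sharding.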


@dataclass
class Request:
    prompt: str
    temperature: float
    nucleus_p: float
    rng_seed: int
    max_len: int


@dataclass
class InferenceRunner:
    name: str
    runner: ModelRunner
    load: str
    tokenizer_path: str = "/tmp/xai_data/tokenizer.model"
    local_mesh_config: Tuple[int, int] = (1, 1)
    between_hosts_config: Tuple[int, int] = (1, 1)
    pad_sizes: Tuple[int, ...] = (1024,)

    def get_pad_bucket(self, size: int) -> int:
        i = bisect.bisect_left(self.pad_sizes, size)
        return self.pad_sizes[min(i, len(self.pad_sizes) - 1)]
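
    # Example: with hypothetical pad_sizes=(256, 512, 1024), get_pad_bucket(300)
    # returns 512, and get_pad_bucket(2048) clamps to the largest bucket, 1024.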
    def initialize(self):
        runner = self.runner
        self.runner.transform_forward = True
        dummy_data = dict(
            inputs=np.zeros((1, self.get_pad_bucket(512)), dtype=np.int32),
        )
        # Build the mesh, shardings, and pjit'd init function first;
        # load_or_init relies on runner.mesh and runner.init_fn existing.
        runner.initialize(dummy_data, self.local_mesh_config, self.between_hosts_config)
        # NOTE: from_checkpoint=False initializes fresh weights; this abridged
        # runner never reads the `load` path.
        state = runner.load_or_init(
            dummy_data,
            from_checkpoint=False,
        )
        runner.params = state.params

        self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=self.tokenizer_path)

        def text_to_token_ids(text):
            ids = self.tokenizer.encode(text, out_type=int)
            return ids

        self.text_to_token_ids = text_to_token_ids
    def predict(self, request: Request) -> str:
        rng = jax.random.PRNGKey(request.rng_seed)
        token_ids = self.text_to_token_ids(request.prompt)
        inputs = np.array(token_ids, dtype=np.int32)[np.newaxis, :]
        token_ids = jnp.array(inputs)
        params = self.runner.params
        settings = SampleSettings(
            temperature=jnp.array([request.temperature]),
            nucleus_p=jnp.array([request.nucleus_p]),
            # Allow every vocabulary entry: shape (batch, vocab_size).
            mask=jnp.ones((1, self.runner.model.vocab_size), dtype=bool),
            active=jnp.ones((1,), dtype=bool),
        )
        for _ in range(request.max_len):
            # Fresh keys each step so successive samples differ.
            rng, fwd_rng, gen_rng = jax.random.split(rng, 3)
            # eval_forward is Haiku-transformed, so call it via .apply; it
            # returns the (output, None) pair built in make_forward_fn.
            lm_outputs, _ = self.runner.eval_forward.apply(params, fwd_rng, token_ids)
            # Sample from the logits of the final position only.
            last_step = lm_outputs._replace(logits=lm_outputs.logits[:, -1:, :])
            # The vmapped sampler expects a batch of keys, here batch size 1.
            sample_output = sample_token(gen_rng[jnp.newaxis], last_step, settings)
            new_token = sample_output.token_id
            token_ids = jnp.concatenate([token_ids, new_token], axis=-1)
            # Stop once the sampler emits the tokenizer's EOS token.
            if new_token[0, 0] == self.tokenizer.eos_id():
                break
        return self.tokenizer.decode(np.asarray(token_ids[0]).tolist())
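

# `make_mesh` is referenced by ModelRunner.initialize but is missing from this
# abridged file. A minimal sketch, assuming a two-axis ("data", "model") device
# mesh built with jax.experimental.mesh_utils, as in the full grok-1 release:
from jax.experimental import mesh_utils


def make_mesh(
    local_mesh_config: Tuple[int, int],
    between_hosts_config: Tuple[int, int],
) -> jax.sharding.Mesh:
    assert len(local_mesh_config) == 2
    assert len(between_hosts_config) == 2
    rank_logger.info("Detected %s devices in mesh", jax.device_count())
    # Group devices so each host's accelerators form the inner (local) mesh
    # axes and hosts form the outer (between-hosts) axes.
    device_mesh = mesh_utils.create_hybrid_device_mesh(
        local_mesh_config,
        between_hosts_config,
        devices=jax.devices(),
        process_is_granule=True,
    )
    return jax.sharding.Mesh(device_mesh, ("data", "model"))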


def main():
    # The values below are placeholders: supply a fully populated
    # LanguageModelConfig and point the paths at a real checkpoint and
    # tokenizer model before running.
    runner = ModelRunner(
        model=LanguageModelConfig(),
        checkpoint_path="path_to_checkpoint",
    )
    inference_runner = InferenceRunner(
        name="inference",
        runner=runner,
        load="path_to_load",
        tokenizer_path="path_to_tokenizer_model",
        local_mesh_config=(1, 1),
        between_hosts_config=(1, 1),
    )
    inference_runner.initialize()

    request = Request(
        prompt="Sample text",
        temperature=0.7,
        nucleus_p=0.9,
        rng_seed=42,
        max_len=100,
    )
    result = inference_runner.predict(request)
    print(result)


if __name__ == "__main__":
    main()