mirror of https://github.com/xai-org/grok-1.git (synced 2024-11-26 21:49:53 +03:00)

Merge fb9d433aa9 into 7050ed204b
This commit is contained in: commit 11b95e7a26

checkpoint.py (109 lines changed)
@@ -41,6 +41,17 @@ sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit

 @contextlib.contextmanager
 def copy_to_shm(file: str):
+    """
+    Context manager for copying a file to shared memory. If the file is already in shared memory (/dev/shm),
+    yields the same file path. Otherwise, copies the file to a temporary file in shared memory, yields the path
+    to the temporary file, and cleans up by removing the temporary file after use.
+
+    Parameters:
+    - file (str): The path to the file to be copied.
+
+    Yields:
+    - str: The path to the file in shared memory.
+    """
     if file.startswith("/dev/shm/"):
         # Nothing to do, the file is already in shared memory.
         yield file
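A minimal usage sketch for the read path (the checkpoint path below is hypothetical, not taken from the diff):

    # Stage a file into /dev/shm, then read it from there.
    with copy_to_shm("/data/checkpoints/ckpt-0/tensor00000_000") as shm_path:
        with open(shm_path, "rb") as f:
            payload = f.read()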
@@ -58,6 +69,17 @@ def copy_to_shm(file: str):

 @contextlib.contextmanager
 def copy_from_shm(file: str):
+    """
+    Context manager for copying a file from shared memory to a specified path. It creates a temporary file
+    in shared memory, yields the path to this temporary file for operations, and then copies the temporary
+    file to the specified path, cleaning up the temporary file afterwards.
+
+    Parameters:
+    - file (str): The target path for the file to be copied to from shared memory.
+
+    Yields:
+    - str: The path to the temporary file in shared memory.
+    """
     tmp_dir = "/dev/shm/"
     fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
     try:
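The write-side counterpart stages in the opposite direction; a sketch with a hypothetical target path:

    # Write into a /dev/shm temp file; it is copied to the target on exit.
    with copy_from_shm("/data/checkpoints/ckpt-0/out.bin") as tmp_path:
        with open(tmp_path, "wb") as f:
            f.write(b"serialized bytes")
    # /data/checkpoints/ckpt-0/out.bin now holds the data.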
@@ -69,19 +91,48 @@ def copy_from_shm(file: str):


 def fast_unpickle(path: str) -> Any:
+    """
+    Unpickles and loads an object from a file, optionally using shared memory for faster access.
+
+    Parameters:
+    - path (str): The path to the pickle file to load.
+
+    Returns:
+    - Any: The object loaded from the pickle file.
+    """
     with copy_to_shm(path) as tmp_path:
         with open(tmp_path, "rb") as f:
             return pickle.load(f)


 def fast_pickle(obj: Any, path: str) -> None:
+    """
+    Pickles and saves an object to a file, optionally using shared memory for faster access.
+
+    Parameters:
+    - obj (Any): The object to be pickled and saved.
+    - path (str): The path where the pickle file should be saved.
+    """
     with copy_from_shm(path) as tmp_path:
         with open(tmp_path, "wb") as f:
             pickle.dump(obj, f)

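Together the two helpers form a drop-in pickle round trip; a sketch with hypothetical paths and data:

    state = {"step": 100, "params": [1.0, 2.0, 3.0]}
    fast_pickle(state, "/data/state.pkl")              # write, staged through copy_from_shm
    assert fast_unpickle("/data/state.pkl") == state   # read, staged through copy_to_shm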
 def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
-    """Loads a set of arrays."""
+    """
+    Loads tensors from files in parallel using a ThreadPoolExecutor. This function is intended for use with
+    arrays that have a predefined shape and dtype, loading them from a directory where each tensor is saved
+    in a separate file.
+
+    Parameters:
+    - shaped_arrays: A sequence of arrays providing the shapes and dtypes of the tensors to load.
+    - directory (str): The directory from which to load the tensors.
+    - mesh_config: Configuration for data parallelism across processes or devices.
+    - tensor_indices (Optional): Specific indices of tensors to load. If None, all tensors are loaded.
+
+    Returns:
+    - List of numpy arrays or zeros depending on the process's role in data parallelism.
+    """
     pool = ThreadPoolExecutor(max_workers=32)
     fs = list()
     num_tensors = 0
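The loading body is truncated in the diff; as a rough sketch of the same fan-out pattern (the file list and the use of fast_unpickle per file are hypothetical, not confirmed by the diff):

    from concurrent.futures import ThreadPoolExecutor

    def load_all_sketch(paths):
        # Submit one load per tensor file and collect results in order.
        with ThreadPoolExecutor(max_workers=32) as pool:
            futures = [pool.submit(fast_unpickle, p) for p in paths]
            return [f.result() for f in futures]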
@@ -108,6 +159,17 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):


 def path_tuple_to_string(path: tuple) -> str:
+    """
+    Converts a tuple representing a path in a nested structure to a string. This function is specifically
+    used for handling paths within the structure of saved states or parameters, making it easier to
+    identify and manipulate specific elements.
+
+    Parameters:
+    - path (tuple): A tuple representing a path in a nested structure.
+
+    Returns:
+    - str: A string representation of the path, suitable for logging or identification.
+    """
     pieces = []
     for elem in path:
         if isinstance(elem, jax.tree_util.DictKey):
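For reference, the path tuples this function consumes come from jax.tree_util.tree_flatten_with_path; a small sketch of the kind of input it normalizes (the exact string-joining convention is not shown in the truncated body):

    import jax

    state = {"params": {"decoder_layer_0": {"w": 1.0}}}
    flat, _ = jax.tree_util.tree_flatten_with_path(state)
    for path, leaf in flat:
        # path is a tuple such as (DictKey(key='params'), DictKey(key='decoder_layer_0'), DictKey(key='w'))
        print(path, leaf)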
@@ -124,6 +186,19 @@ def get_load_path_str(
     load_rename_rules: Optional[list[tuple[str, str]]] = None,
     load_exclude_rules: Optional[list[str]] = None,
 ) -> Optional[str]:
+    """
+    Determines the load path for a given initial path string, applying exclusion and renaming rules. This
+    function is used in the context of loading states or parameters from files, allowing for flexible
+    mapping or exclusion based on pattern matching.
+
+    Parameters:
+    - init_path_str (str): The initial path string to process.
+    - load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming paths based on pattern matching.
+    - load_exclude_rules (Optional[list[str]]): Patterns for paths to exclude from loading.
+
+    Returns:
+    - Optional[str]: The processed load path string, or None if excluded.
+    """
     # Exclusion
     if load_exclude_rules is not None:
         for search_pattern in load_exclude_rules:
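Both rule lists are pattern based; a sketch of what such rules and their application might look like (these specific patterns are illustrative, not taken from the repository):

    import re

    exclude_rules = [r"optimizer_state"]            # skip optimizer buffers entirely
    rename_rules = [(r"^old_model/", "model/")]     # map an old prefix onto the new layout

    path = "old_model/decoder_layer_0/w"
    if not any(re.search(p, path) for p in exclude_rules):
        for pattern, replacement in rename_rules:
            path = re.sub(pattern, replacement, path)
    # path is now "model/decoder_layer_0/w"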
@@ -148,6 +223,21 @@ def replace_with_load_state(
     load_exclude_rules: Optional[list[str]] = None,
     mesh_config: tuple = (1, 1),
 ) -> Any:
+    """
+    Replaces elements of an initial state with elements from a loaded state, applying renaming and exclusion
+    rules. This function supports conditional inclusion and transformation of state elements based on complex
+    criteria, facilitating flexible state restoration.
+
+    Parameters:
+    - init_state (Any): The initial state before replacement.
+    - load_state (Any): The state from which to load replacements.
+    - load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming paths.
+    - load_exclude_rules (Optional[list[str]]): Patterns for paths to exclude from loading.
+    - mesh_config (tuple): Configuration for data parallelism.
+
+    Returns:
+    - Any: The initial state with elements replaced from the load state.
+    """
     flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
     flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
     load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
@@ -186,6 +276,23 @@ def restore(
     state_sharding,
     init_state: Optional[Any] = None,
 ) -> Any:
+    """
+    Restores the state from a checkpoint, optionally focusing on parameters only, and applies sharding
+    configurations. This function is designed for restoring model states from disk with support for distributed
+    environments, handling the intricacies of partitioning and host-specific configurations.
+
+    Parameters:
+    - checkpoint_path (str): The path to the checkpoint directory.
+    - state_shapes (Any): The expected shapes of the state to restore.
+    - mesh: The mesh configuration for distributed environments.
+    - between_hosts_config: Configuration for data exchange between hosts.
+    - params_only (bool): Whether to restore parameters only, excluding other state parts.
+    - state_sharding: Sharding configuration for the state.
+    - init_state (Optional[Any]): The initial state to which the checkpoint data is applied.
+
+    Returns:
+    - Any: The restored state, potentially sharded across the distributed environment.
+    """
     ckpt_path = os.path.join(checkpoint_path, "ckpt-0")

     rank_logger.info("Loading checkpoint at {}".format(ckpt_path))
run.py (21 lines changed)
@@ -22,6 +22,27 @@ CKPT_PATH = "./checkpoints/"


 def main():
+    """
+    Initializes and runs a text generation model using predefined model configurations and inference settings.
+
+    This function sets up a language model with specific configurations, including model architecture details
+    (e.g., embedding sizes, number of layers, attention heads, and MoE settings) and text generation settings
+    (e.g., vocabulary size, token identifiers). It initializes an inference runner with the model, checkpoint
+    path, tokenizer, and mesh configuration. The inference runner is then used to generate text based on a
+    given prompt and output the result.
+
+    The process involves:
+    - Creating a `LanguageModelConfig` instance with specified model parameters, including transformer
+      configurations and quantization settings for weights.
+    - Initializing an `InferenceRunner` with the model configuration, batch size per device, checkpoint path,
+      and other relevant settings.
+    - Calling the `initialize` method on the inference runner to prepare the model and tokenizer for inference.
+    - Generating text based on a provided prompt using the `sample_from_model` function, which internally
+      manages the sampling process through the inference runner.
+
+    Output:
+    - Prints the generated text continuation for a prompt to the standard output.
+    """
     grok_1_model = LanguageModelConfig(
         vocab_size=128 * 1024,
         pad_token=0,
runners.py (252 lines changed)
@@ -48,6 +48,20 @@ TOP_K = 8


 class SampleSettings(NamedTuple):
+    """
+    A NamedTuple for storing settings used during the sampling process.
+
+    Attributes:
+    - temperature (ArrayLike): The temperature controls the randomness of the output. Lower values make
+      the model outputs more deterministic, while higher values increase diversity.
+    - nucleus_p (ArrayLike): The nucleus probability controls the cumulative distribution function of
+      token probabilities, effectively truncating low-probability tokens to focus sampling on a subset
+      of plausible tokens.
+    - mask (ArrayLike): A binary mask indicating which tokens are allowed (1) and which are disallowed (0)
+      for sampling in each batch element.
+    - active (ArrayLike): A binary indicator for whether a given batch element is actively being used for
+      generation. Elements not in use are ignored in computations to save resources.
+    """
     temperature: ArrayLike
     nucleus_p: ArrayLike
     mask: ArrayLike
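A sketch of building settings for a batch (the batch size, vocabulary size, and values are illustrative, and the [batch, vocab] mask shape in particular is an assumption):

    import jax.numpy as jnp

    batch, vocab = 2, 8  # hypothetical sizes
    settings = SampleSettings(
        temperature=jnp.full((batch,), 0.8),
        nucleus_p=jnp.full((batch,), 0.95),
        mask=jnp.ones((batch, vocab), dtype=jnp.int32),  # allow every token
        active=jnp.ones((batch,), dtype=jnp.int32),      # both slots generating
    )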
@@ -56,6 +70,15 @@ class SampleSettings(NamedTuple):


 class SampleOutput(NamedTuple):
+    """
+    A NamedTuple for storing the output from the sampling process.
+
+    Attributes:
+    - token_id (ArrayLike): The sampled token ID for each batch element.
+    - prob (ArrayLike): The probability associated with the sampled token for each batch element.
+    - top_k_token_ids (ArrayLike): The IDs of the top-k most probable tokens for each batch element.
+    - top_k_probs (ArrayLike): The probabilities associated with the top-k token IDs for each batch element.
+    """
     token_id: ArrayLike
     prob: ArrayLike
     top_k_token_ids: ArrayLike
@@ -63,6 +86,19 @@ class SampleOutput(NamedTuple):


 def insert_slice(memory: Memory, slice, length, i):
+    """
+    Inserts a slice of KVMemory into a Memory object at a specified index. This function updates the Memory
+    object with a new slice that includes updated steps for each layer based on the provided length.
+
+    Parameters:
+    - memory (Memory): The original memory object to be updated.
+    - slice: A slice of the memory to be inserted, typically representing a new or modified piece of information.
+    - length (int): The step length to set for each KVMemory layer in the inserted slice.
+    - i (int): The index at which the slice is to be inserted into the memory.
+
+    Returns:
+    - Memory: The updated memory object with the new slice inserted at the specified index.
+    """
     slice = Memory(
         layers=[
             KVMemory(layer.k, layer.v, step=jnp.array([length]))
@@ -75,6 +111,17 @@ def insert_slice(memory: Memory, slice, length, i):


 def pad_to_size(x, size):
+    """
+    Pads or truncates an array to a specified size. If the array is longer than the target size, it will be
+    truncated from the left. If it is shorter, it will be right-padded with zeros.
+
+    Parameters:
+    - x (ArrayLike): The array to be padded or truncated.
+    - size (int): The target size for the array.
+
+    Returns:
+    - ArrayLike: The array adjusted to the specified size.
+    """
     if x.shape[0] > size:
         # Left truncate if the context is too long.
         x = x[-size:]
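A standalone sketch of the documented behavior, assuming 1-D token arrays (the original body is truncated after the left-truncate branch):

    import jax.numpy as jnp

    def pad_to_size_sketch(x, size):
        if x.shape[0] > size:
            # Left truncate if the context is too long.
            x = x[-size:]
        # Right-pad with zeros up to the target size.
        return jnp.pad(x, (0, size - x.shape[0]))

    # pad_to_size_sketch(jnp.array([1, 2, 3]), 5)  -> [1 2 3 0 0]
    # pad_to_size_sketch(jnp.arange(10), 5)        -> [5 6 7 8 9]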
@@ -82,7 +129,20 @@ def pad_to_size(x, size):


 def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:
-    """Performs nucleus filtering on logits."""
+    """
+    Performs nucleus filtering on logits.
+    Filters logits to retain only a subset that corresponds to the top-p cumulative probability. This
+    nucleus filtering method is used to focus the sampling process on a subset of plausible tokens, improving
+    the quality of generated text.
+
+    Parameters:
+    - logits (jax.Array): The logits to filter.
+    - top_p (jax.Array): The cumulative probability threshold for filtering logits.
+
+    Returns:
+    - jax.Array: The filtered logits with low-probability tokens set to -inf, effectively removing them
+      from consideration during sampling.
+    """
     assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}"
     sorted_logits = jax.lax.sort(logits, is_stable=False)
     sorted_probs = jax.nn.softmax(sorted_logits)
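A self-contained sketch of the nucleus-filtering idea for the 1-D case. This re-derives the technique from the docstring; the actual body (truncated in the diff) uses jax.lax.sort and is batched:

    import jax
    import jax.numpy as jnp

    def nucleus_filter_sketch(logits, top_p):
        sorted_logits = jnp.sort(logits)              # ascending
        sorted_probs = jax.nn.softmax(sorted_logits)
        cdf = jnp.cumsum(sorted_probs)
        # Smallest logit whose suffix (itself and everything above it)
        # carries at least top_p probability mass.
        cutoff = jnp.min(jnp.where(cdf >= 1 - top_p, sorted_logits, sorted_logits[-1]))
        # Everything below the cutoff is removed from consideration.
        return jnp.where(logits >= cutoff, logits, -jnp.inf)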
@@ -102,6 +162,21 @@ def sample_token(
     lm_outputs: LanguageModelOutput,
     settings: SampleSettings,
 ) -> SampleOutput:
+    """
+    Samples a token from the language model outputs using specified settings, including temperature and
+    nucleus probability filtering. This function also computes the probability of the sampled token and
+    identifies the top-k tokens and their probabilities.
+
+    Parameters:
+    - rngs (jax.random.PRNGKey): The random number generator state for sampling.
+    - lm_outputs (LanguageModelOutput): The outputs from the language model, including logits.
+    - settings (SampleSettings): The settings controlling the sampling process, including temperature,
+      nucleus probability, token masking, and active batch elements indicator.
+
+    Returns:
+    - SampleOutput: The results of the sampling process, including the sampled token ID, its probability,
+      and the top-k tokens and their probabilities for each batch element.
+    """
     # Expand the settings shape to match the logit shape.
     settings = SampleSettings(
         temperature=jnp.expand_dims(settings.temperature, (1, 2)),  # Input [B], output [B, 1, 1].
@@ -135,6 +210,25 @@ def sample_token(

 @dataclass
 class ModelRunner:
+    """
+    Manages the execution of a language model, including initialization, forward passes, and state management.
+
+    Attributes:
+    - model (LanguageModelConfig): The configuration for the language model.
+    - bs_per_device (float): The batch size per device.
+    - load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming parameters during model loading.
+    - load_exclude_rules (Optional[list[str]]): Rules for excluding parameters during model loading.
+    - rng_seed (int): The initial random number generator seed.
+    - transform_forward (bool): Flag to determine whether to transform the forward function with Haiku.
+    - checkpoint_path (str): The path to the directory containing model checkpoints for loading.
+
+    Methods:
+    - make_forward_fn: Creates the forward function for the model, potentially transforming it with Haiku.
+    - initialize: Initializes the model runner with data and mesh configuration, preparing it for execution.
+    - init: Initializes the model parameters using provided data.
+    - get_state_sharding: Determines the sharding configuration for model parameters based on the initialization data.
+    - load_or_init: Loads the model from a checkpoint or initializes it if the checkpoint is not available or not specified.
+    """
     model: LanguageModelConfig

     bs_per_device: float = 2.0
@@ -148,6 +242,18 @@ class ModelRunner:
     checkpoint_path: str = ""

     def make_forward_fn(self, mesh: Any):
+        """
+        Creates and optionally transforms the forward function of the model. This method constructs a forward
+        function that takes input tokens and produces model outputs. If `transform_forward` is set to True, the
+        method also transforms this function using Haiku's transform method to prepare it for JIT compilation
+        and execution on the mesh.
+
+        Parameters:
+        mesh (Any): The mesh configuration to be used for distributing the computation across devices.
+
+        Returns:
+        Callable: The forward function, potentially transformed by Haiku, ready for execution.
+        """
         def forward(tokens):
             out = self.model.make(mesh=mesh)(tokens)
             return out, None
@@ -162,6 +268,20 @@ class ModelRunner:
         local_mesh_config: tuple[int, int],
         between_hosts_config: tuple[int, int],
     ):
+        """
+        Initializes the model runner with necessary configurations and data. This includes setting up the mesh
+        for distributed computation, calculating batch sizes based on the provided configuration, and preparing
+        the forward function for execution.
+
+        Parameters:
+        init_data: Initial data used for model initialization. This is typically dummy data matching the
+          expected input format and dimensions of the model.
+        local_mesh_config (tuple[int, int]): The configuration of the local mesh, specifying how devices
+          are organized locally for parallel computation.
+        between_hosts_config (tuple[int, int]): The configuration for communication between hosts in a
+          distributed setup, important for multi-host environments.
+
+        """
         num_replicas = math.prod(between_hosts_config)
         self.model.initialize()
         self.model.fprop_dtype = jnp.bfloat16
@@ -191,12 +311,37 @@ class ModelRunner:
         self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding)

     def init(self, rng: jax.Array, data) -> TrainingState:
+        """
+        Initializes the model parameters using provided data and a random number generator state. This method
+        is only called when `transform_forward` is True, indicating that the forward function has been transformed
+        by Haiku and requires explicit initialization.
+
+        Parameters:
+        rng (jax.Array): The random number generator state for initializing model parameters.
+        data: The initial data used for parameter initialization, usually including input tokens and
+          possibly target tokens for supervised learning tasks.
+
+        Returns:
+        TrainingState: An instance of TrainingState containing the initialized model parameters.
+        """
         assert self.transform_forward
         rng, init_rng = jax.random.split(rng)
         params = self.forward.init(init_rng, data["inputs"])
         return TrainingState(params=params)

     def get_state_sharding(self, init_data):
+        """
+        Determines the sharding configuration for model parameters based on initialization data. This method
+        evaluates the shape of model parameters during initialization to apply the partition rules specified
+        in the model's configuration, facilitating efficient distributed computation.
+
+        Parameters:
+        init_data: The initial data used for model initialization, which helps determine the appropriate
+          sharding configuration based on model parameter shapes.
+
+        Returns:
+        A sharding configuration that specifies how model parameters should be partitioned across the mesh.
+        """
         assert self.transform_forward
         rng = jax.random.PRNGKey(self.rng_seed)
         rank_logger.info(f"partition rules: {self.model.partition_rules}")
@@ -215,6 +360,20 @@ class ModelRunner:
         from_checkpoint: bool = True,
         init_fn: Optional[Callable] = None,
     ):
+        """
+        Loads model parameters from a checkpoint or initializes them if the checkpoint is not available. This
+        method provides flexibility in starting the model from a known state or freshly initializing parameters
+        based on the model configuration and provided initial data.
+
+        Parameters:
+        init_data: Initial data used for model parameter initialization if needed.
+        from_checkpoint (bool): Flag indicating whether to attempt loading parameters from a checkpoint.
+        init_fn (Optional[Callable]): An optional initialization function that overrides the default
+          initialization method if provided.
+
+        Returns:
+        The model state, either loaded from a checkpoint or newly initialized.
+        """
         rng = jax.random.PRNGKey(self.rng_seed)

         if not self.checkpoint_path or not from_checkpoint:
@@ -251,6 +410,16 @@ class ModelRunner:

 @dataclass
 class Request:
+    """
+    Data class for storing information about a single sampling request to the model.
+
+    Attributes:
+    prompt (str): The text prompt to generate text from.
+    temperature (float): Controls the randomness of the output. Lower values result in more deterministic outputs.
+    nucleus_p (float): The cumulative probability threshold for nucleus sampling, focusing generation on high-probability tokens.
+    rng_seed (int): Seed for the random number generator to ensure reproducibility of the results.
+    max_len (int): The maximum length of the generated text sequence.
+    """
     prompt: str
     temperature: float
     nucleus_p: float
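Constructing a request is then straightforward (the field values here are illustrative):

    req = Request(
        prompt="The answer to life, the universe, and everything is ",
        temperature=0.01,   # near-greedy sampling
        nucleus_p=0.95,
        rng_seed=42,
        max_len=100,
    )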
@@ -260,6 +429,24 @@ class Request:

 @dataclass
 class InferenceRunner:
+    """
+    Manages inference operations for generating text based on prompts, including initializing the model and tokenizer,
+    prefilling the model memory, and running the actual text generation.
+
+    Attributes:
+    name (str): A name identifier for the inference runner.
+    runner (Any): An instance of ModelRunner that manages the underlying language model.
+    load (str): A path to the model checkpoint for loading.
+    tokenizer_path (str): Path to the SentencePiece tokenizer model file.
+    local_mesh_config (Tuple[int, int]): Configuration for the local device mesh.
+    between_hosts_config (Tuple[int, int]): Configuration for communication between hosts in a distributed setup.
+    pad_sizes (tuple[int]): A tuple of padding sizes for processing different lengths of prompts.
+
+    Methods:
+    get_pad_bucket: Determines the appropriate padding size for a given input length.
+    initialize: Initializes the model runner, tokenizer, and pre-compiles model functions for different input sizes.
+    run: A generator method that accepts prompts and returns generated text, managing batch slots and sampling steps.
+    """
     name: str
     runner: Any
     load: str
@@ -269,10 +456,25 @@ class InferenceRunner:
     pad_sizes: tuple[int] = (1024,)

     def get_pad_bucket(self, size):
+        """
+        Determines the appropriate padding size for a given input length. This method uses the predefined
+        pad sizes to find the smallest pad size that is larger than or equal to the given size.
+
+        Parameters:
+        size (int): The length of the input sequence to be padded.
+
+        Returns:
+        int: The selected pad size from the predefined pad sizes that best fits the input length.
+        """
         i = bisect.bisect_left(self.pad_sizes, size)
         return self.pad_sizes[min(i, len(self.pad_sizes) - 1)]

     def initialize(self):
+        """
+        Initializes the inference runner by setting up the model runner, loading the tokenizer, pre-compiling
+        model functions for different input sizes, and loading or initializing model parameters. This method
+        ensures that the inference runner is ready to generate text based on prompts.
+        """
         runner = self.runner
         self.runner.transform_forward = True
        dummy_data = dict(
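The bucket lookup is a plain bisect with a clamp to the largest bucket; a worked example over hypothetical pad sizes (the dataclass default is just (1024,)):

    import bisect

    pad_sizes = (128, 512, 1024)  # hypothetical
    for size in (100, 512, 600, 2000):
        i = bisect.bisect_left(pad_sizes, size)
        print(size, "->", pad_sizes[min(i, len(pad_sizes) - 1)])
    # 100 -> 128, 512 -> 512, 600 -> 1024, 2000 -> 1024 (clamped)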
@@ -440,7 +642,15 @@ class InferenceRunner:
         )

     def run(self):
-        """Generator that accepts prompts."""
+        """
+        Generator method that accepts prompts and manages the text generation process. This method handles
+        tokenization of prompts, setup of sampling settings, and orchestration of the sampling steps to
+        generate and return the generated text.
+
+        Yields:
+        str: The generated text for each prompt received. This method operates as a coroutine, yielding
+          generated text back to the caller and awaiting new prompts for further generation.
+        """
         runner = self.runner
         mesh = runner.mesh
         max_len = runner.model.sequence_len
@@ -580,6 +790,18 @@ class InferenceRunner:
 def make_mesh(
     local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...]
 ) -> jax.sharding.Mesh:
+    """
+    Creates a JAX mesh from the provided configuration, which is used for parallel computation across devices.
+
+    Parameters:
+    local_mesh_config (tuple[int, ...]): A tuple specifying the local mesh configuration. This usually corresponds
+      to how local devices are arranged and partitioned for computation.
+    between_hosts_config (tuple[int, ...]): A tuple specifying the mesh configuration for communication between
+      different hosts, relevant in multi-host distributed setups.
+
+    Returns:
+    jax.sharding.Mesh: A mesh object that encapsulates the device topology and partitioning for parallel computation.
+    """
     assert len(local_mesh_config) == 2
     assert len(between_hosts_config) == 2
     rank_logger.info("Detected %s devices in mesh", jax.device_count())
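For orientation, a minimal sketch of building such a mesh directly with JAX (the 2x2 shape and the axis names are assumptions for illustration and must match the devices actually present):

    import jax
    import numpy as np

    devices = np.asarray(jax.devices()[:4]).reshape(2, 2)  # hypothetical 4-device host
    mesh = jax.sharding.Mesh(devices, ("data", "model"))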
@@ -594,6 +816,32 @@ def make_mesh(


 def sample_from_model(server, prompt, max_len, temperature):
+    """
+    Samples tokens from a trained model given a prompt and sampling settings. This function is designed
+    to interact with a generator-based server that manages the state of the model and performs the actual
+    sampling. The server is expected to yield control back to this function, which then sends a request
+    object containing the prompt and settings, and waits for the generated output.
+
+    The function initializes the server (if not already done), constructs the request with the provided
+    parameters, and sends this request to the server. The sampled output, typically a continuation of the
+    prompt, is then received from the server.
+
+    Parameters:
+    - server (generator): The generator-based server that handles model inference. This server is expected
+      to manage the lifecycle of the model, including loading parameters, setting up the environment, and
+      performing the token sampling.
+    - prompt (str): The initial text prompt to which the model generates a continuation. This prompt is
+      encoded and fed into the model as part of the request.
+    - max_len (int): The maximum length of the generated sequence, including the length of the prompt.
+      This controls how long the model's output can be.
+    - temperature (float): A parameter controlling the randomness of the output. Lower values make the
+      model's outputs more deterministic (and potentially more repetitive), while higher values increase
+      diversity at the cost of coherence.
+
+    Returns:
+    - str: The generated text as a continuation of the provided prompt, up to a maximum length specified
+      by `max_len`.
+    """
     next(server)
     inp = Request(
         prompt=prompt,
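End to end, a driver would look roughly like this, mirroring the flow described in run.py's main() (it assumes an already-constructed InferenceRunner named inference_runner; the prompt and sizes are illustrative):

    inference_runner.initialize()
    gen = inference_runner.run()  # the generator-based server
    output = sample_from_model(gen, "The capital of France is", max_len=100, temperature=0.01)
    print(output)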