mirror of
https://github.com/xai-org/grok-1.git
synced 2024-11-26 21:49:53 +03:00
Merge fb9d433aa9
into 7050ed204b
This commit is contained in:
commit
11b95e7a26
109
checkpoint.py
109
checkpoint.py
@ -41,6 +41,17 @@ sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit
|
||||
|
||||
@contextlib.contextmanager
|
||||
def copy_to_shm(file: str):
|
||||
"""
|
||||
Context manager for copying a file to shared memory. If the file is already in shared memory (/dev/shm),
|
||||
yields the same file path. Otherwise, copies the file to a temporary file in shared memory, yields the path
|
||||
to the temporary file, and cleans up by removing the temporary file after use.
|
||||
|
||||
Parameters:
|
||||
- file (str): The path to the file to be copied.
|
||||
|
||||
Yields:
|
||||
- str: The path to the file in shared memory.
|
||||
"""
|
||||
if file.startswith("/dev/shm/"):
|
||||
# Nothing to do, the file is already in shared memory.
|
||||
yield file
|
||||
@ -58,6 +69,17 @@ def copy_to_shm(file: str):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def copy_from_shm(file: str):
|
||||
"""
|
||||
Context manager for copying a file from shared memory to a specified path. It creates a temporary file
|
||||
in shared memory, yields the path to this temporary file for operations, and then copies the temporary
|
||||
file to the specified path, cleaning up the temporary file afterwards.
|
||||
|
||||
Parameters:
|
||||
- file (str): The target path for the file to be copied to from shared memory.
|
||||
|
||||
Yields:
|
||||
- str: The path to the temporary file in shared memory.
|
||||
"""
|
||||
tmp_dir = "/dev/shm/"
|
||||
fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
|
||||
try:
|
||||
@ -69,19 +91,48 @@ def copy_from_shm(file: str):
|
||||
|
||||
|
||||
def fast_unpickle(path: str) -> Any:
|
||||
"""
|
||||
Unpickles and loads an object from a file, optionally using shared memory for faster access.
|
||||
|
||||
Parameters:
|
||||
- path (str): The path to the pickle file to load.
|
||||
|
||||
Returns:
|
||||
- Any: The object loaded from the pickle file.
|
||||
"""
|
||||
with copy_to_shm(path) as tmp_path:
|
||||
with open(tmp_path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
def fast_pickle(obj: Any, path: str) -> None:
|
||||
"""
|
||||
Pickles and saves an object to a file, optionally using shared memory for faster access.
|
||||
|
||||
Parameters:
|
||||
- obj (Any): The object to be pickled and saved.
|
||||
- path (str): The path where the pickle file should be saved.
|
||||
"""
|
||||
with copy_from_shm(path) as tmp_path:
|
||||
with open(tmp_path, "wb") as f:
|
||||
pickle.dump(obj, f)
|
||||
|
||||
|
||||
def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
|
||||
"""Loads a set of arrays."""
|
||||
"""
|
||||
Loads tensors from files in parallel using a ThreadPoolExecutor. This function is intended for use with
|
||||
arrays that have a predefined shape and dtype, loading them from a directory where each tensor is saved
|
||||
in a separate file.
|
||||
|
||||
Parameters:
|
||||
- shaped_arrays: A sequence of arrays providing the shapes and dtypes of the tensors to load.
|
||||
- directory (str): The directory from which to load the tensors.
|
||||
- mesh_config: Configuration for data parallelism across processes or devices.
|
||||
- tensor_indices (Optional): Specific indices of tensors to load. If None, all tensors are loaded.
|
||||
|
||||
Returns:
|
||||
- List of numpy arrays or zeros depending on the process's role in data parallelism.
|
||||
"""
|
||||
pool = ThreadPoolExecutor(max_workers=32)
|
||||
fs = list()
|
||||
num_tensors = 0
|
||||
@ -108,6 +159,17 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
|
||||
|
||||
|
||||
def path_tuple_to_string(path: tuple) -> str:
|
||||
"""
|
||||
Converts a tuple representing a path in a nested structure to a string. This function is specifically
|
||||
used for handling paths within the structure of saved states or parameters, making it easier to
|
||||
identify and manipulate specific elements.
|
||||
|
||||
Parameters:
|
||||
- path (tuple): A tuple representing a path in a nested structure.
|
||||
|
||||
Returns:
|
||||
- str: A string representation of the path, suitable for logging or identification.
|
||||
"""
|
||||
pieces = []
|
||||
for elem in path:
|
||||
if isinstance(elem, jax.tree_util.DictKey):
|
||||
@ -124,6 +186,19 @@ def get_load_path_str(
|
||||
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
||||
load_exclude_rules: Optional[list[str]] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Determines the load path for a given initial path string, applying exclusion and renaming rules. This
|
||||
function is used in the context of loading states or parameters from files, allowing for flexible
|
||||
mapping or exclusion based on pattern matching.
|
||||
|
||||
Parameters:
|
||||
- init_path_str (str): The initial path string to process.
|
||||
- load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming paths based on pattern matching.
|
||||
- load_exclude_rules (Optional[list[str]]): Patterns for paths to exclude from loading.
|
||||
|
||||
Returns:
|
||||
- Optional[str]: The processed load path string, or None if excluded.
|
||||
"""
|
||||
# Exclusion
|
||||
if load_exclude_rules is not None:
|
||||
for search_pattern in load_exclude_rules:
|
||||
@ -148,6 +223,21 @@ def replace_with_load_state(
|
||||
load_exclude_rules: Optional[list[str]] = None,
|
||||
mesh_config: tuple = (1, 1),
|
||||
) -> Any:
|
||||
"""
|
||||
Replaces elements of an initial state with elements from a loaded state, applying renaming and exclusion
|
||||
rules. This function supports conditional inclusion and transformation of state elements based on complex
|
||||
criteria, facilitating flexible state restoration.
|
||||
|
||||
Parameters:
|
||||
- init_state (Any): The initial state before replacement.
|
||||
- load_state (Any): The state from which to load replacements.
|
||||
- load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming paths.
|
||||
- load_exclude_rules (Optional[list[str]]): Patterns for paths to exclude from loading.
|
||||
- mesh_config (tuple): Configuration for data parallelism.
|
||||
|
||||
Returns:
|
||||
- Any: The initial state with elements replaced from the load state.
|
||||
"""
|
||||
flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
|
||||
flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
|
||||
load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
|
||||
@ -186,6 +276,23 @@ def restore(
|
||||
state_sharding,
|
||||
init_state: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Restores the state from a checkpoint, optionally focusing on parameters only, and applies sharding
|
||||
configurations. This function is designed for restoring model states from disk with support for distributed
|
||||
environments, handling the intricacies of partitioning and host-specific configurations.
|
||||
|
||||
Parameters:
|
||||
- checkpoint_path (str): The path to the checkpoint directory.
|
||||
- state_shapes (Any): The expected shapes of the state to restore.
|
||||
- mesh: The mesh configuration for distributed environments.
|
||||
- between_hosts_config: Configuration for data exchange between hosts.
|
||||
- params_only (bool): Whether to restore parameters only, excluding other state parts.
|
||||
- state_sharding: Sharding configuration for the state.
|
||||
- init_state (Optional[Any]): The initial state to which the checkpoint data is applied.
|
||||
|
||||
Returns:
|
||||
- Any: The restored state, potentially sharded across the distributed environment.
|
||||
"""
|
||||
ckpt_path = os.path.join(checkpoint_path, "ckpt-0")
|
||||
|
||||
rank_logger.info("Loading checkpoint at {}".format(ckpt_path))
|
||||
|
21
run.py
21
run.py
@ -22,6 +22,27 @@ CKPT_PATH = "./checkpoints/"
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Initializes and runs a text generation model using predefined model configurations and inference settings.
|
||||
|
||||
This function sets up a language model with specific configurations, including model architecture details
|
||||
(e.g., embedding sizes, number of layers, attention heads, and MoE settings) and text generation settings
|
||||
(e.g., vocabulary size, token identifiers). It initializes an inference runner with the model, checkpoint
|
||||
path, tokenizer, and mesh configuration. The inference runner is then used to generate text based on a
|
||||
given prompt and output the result.
|
||||
|
||||
The process involves:
|
||||
- Creating a `LanguageModelConfig` instance with specified model parameters, including transformer
|
||||
configurations and quantization settings for weights.
|
||||
- Initializing an `InferenceRunner` with the model configuration, batch size per device, checkpoint path,
|
||||
and other relevant settings.
|
||||
- Calling the `initialize` method on the inference runner to prepare the model and tokenizer for inference.
|
||||
- Generating text based on a provided prompt using the `sample_from_model` function, which internally
|
||||
manages the sampling process through the inference runner.
|
||||
|
||||
Output:
|
||||
- Prints the generated text continuation for a prompt to the standard output.
|
||||
"""
|
||||
grok_1_model = LanguageModelConfig(
|
||||
vocab_size=128 * 1024,
|
||||
pad_token=0,
|
||||
|
252
runners.py
252
runners.py
@ -48,6 +48,20 @@ TOP_K = 8
|
||||
|
||||
|
||||
class SampleSettings(NamedTuple):
|
||||
"""
|
||||
A NamedTuple for storing settings used during the sampling process.
|
||||
|
||||
Attributes:
|
||||
- temperature (ArrayLike): The temperature controls the randomness of the output. Lower values make
|
||||
the model outputs more deterministic, while higher values increase diversity.
|
||||
- nucleus_p (ArrayLike): The nucleus probability controls the cumulative distribution function of
|
||||
token probabilities, effectively truncating low-probability tokens to focus sampling on a subset
|
||||
of plausible tokens.
|
||||
- mask (ArrayLike): A binary mask indicating which tokens are allowed (1) and which are disallowed (0)
|
||||
for sampling in each batch element.
|
||||
- active (ArrayLike): A binary indicator for whether a given batch element is actively being used for
|
||||
generation. Elements not in use are ignored in computations to save resources.
|
||||
"""
|
||||
temperature: ArrayLike
|
||||
nucleus_p: ArrayLike
|
||||
mask: ArrayLike
|
||||
@ -56,6 +70,15 @@ class SampleSettings(NamedTuple):
|
||||
|
||||
|
||||
class SampleOutput(NamedTuple):
|
||||
"""
|
||||
A NamedTuple for storing the output from the sampling process.
|
||||
|
||||
Attributes:
|
||||
- token_id (ArrayLike): The sampled token ID for each batch element.
|
||||
- prob (ArrayLike): The probability associated with the sampled token for each batch element.
|
||||
- top_k_token_ids (ArrayLike): The IDs of the top-k most probable tokens for each batch element.
|
||||
- top_k_probs (ArrayLike): The probabilities associated with the top-k token IDs for each batch element.
|
||||
"""
|
||||
token_id: ArrayLike
|
||||
prob: ArrayLike
|
||||
top_k_token_ids: ArrayLike
|
||||
@ -63,6 +86,19 @@ class SampleOutput(NamedTuple):
|
||||
|
||||
|
||||
def insert_slice(memory: Memory, slice, length, i):
|
||||
"""
|
||||
Inserts a slice of KVMemory into a Memory object at a specified index. This function updates the Memory
|
||||
object with a new slice that includes updated steps for each layer based on the provided length.
|
||||
|
||||
Parameters:
|
||||
- memory (Memory): The original memory object to be updated.
|
||||
- slice: A slice of the memory to be inserted, typically representing a new or modified piece of information.
|
||||
- length (int): The step length to set for each KVMemory layer in the inserted slice.
|
||||
- i (int): The index at which the slice is to be inserted into the memory.
|
||||
|
||||
Returns:
|
||||
- Memory: The updated memory object with the new slice inserted at the specified index.
|
||||
"""
|
||||
slice = Memory(
|
||||
layers=[
|
||||
KVMemory(layer.k, layer.v, step=jnp.array([length]))
|
||||
@ -75,6 +111,17 @@ def insert_slice(memory: Memory, slice, length, i):
|
||||
|
||||
|
||||
def pad_to_size(x, size):
|
||||
"""
|
||||
Pads or truncates an array to a specified size. If the array is longer than the target size, it will be
|
||||
truncated from the left. If it is shorter, it will be right-padded with zeros.
|
||||
|
||||
Parameters:
|
||||
- x (ArrayLike): The array to be padded or truncated.
|
||||
- size (int): The target size for the array.
|
||||
|
||||
Returns:
|
||||
- ArrayLike: The array adjusted to the specified size.
|
||||
"""
|
||||
if x.shape[0] > size:
|
||||
# Left truncate if the context is too long.
|
||||
x = x[-size:]
|
||||
@ -82,7 +129,20 @@ def pad_to_size(x, size):
|
||||
|
||||
|
||||
def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:
|
||||
"""Performs nucleus filtering on logits."""
|
||||
"""
|
||||
Performs nucleus filtering on logits.
|
||||
Filters logits to retain only a subset that corresponds to the top-p cumulative probability. This
|
||||
nucleus filtering method is used to focus the sampling process on a subset of plausible tokens, improving
|
||||
the quality of generated text.
|
||||
|
||||
Parameters:
|
||||
- logits (jax.Array): The logits to filter.
|
||||
- top_p (jax.Array): The cumulative probability threshold for filtering logits.
|
||||
|
||||
Returns:
|
||||
- jax.Array: The filtered logits with low-probability tokens set to -inf, effectively removing them
|
||||
from consideration during sampling.
|
||||
"""
|
||||
assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}"
|
||||
sorted_logits = jax.lax.sort(logits, is_stable=False)
|
||||
sorted_probs = jax.nn.softmax(sorted_logits)
|
||||
@ -102,6 +162,21 @@ def sample_token(
|
||||
lm_outputs: LanguageModelOutput,
|
||||
settings: SampleSettings,
|
||||
) -> SampleOutput:
|
||||
"""
|
||||
Samples a token from the language model outputs using specified settings, including temperature and
|
||||
nucleus probability filtering. This function also computes the probability of the sampled token and
|
||||
identifies the top-k tokens and their probabilities.
|
||||
|
||||
Parameters:
|
||||
- rngs (jax.random.PRNGKey): The random number generator state for sampling.
|
||||
- lm_outputs (LanguageModelOutput): The outputs from the language model, including logits.
|
||||
- settings (SampleSettings): The settings controlling the sampling process, including temperature,
|
||||
nucleus probability, token masking, and active batch elements indicator.
|
||||
|
||||
Returns:
|
||||
- SampleOutput: The results of the sampling process, including the sampled token ID, its probability,
|
||||
and the top-k tokens and their probabilities for each batch element.
|
||||
"""
|
||||
# Expand the settings shape to match the logit shape.
|
||||
settings = SampleSettings(
|
||||
temperature=jnp.expand_dims(settings.temperature, (1, 2)), # Input [B], output [B, 1, 1].
|
||||
@ -135,6 +210,25 @@ def sample_token(
|
||||
|
||||
@dataclass
|
||||
class ModelRunner:
|
||||
"""
|
||||
Manages the execution of a language model, including initialization, forward passes, and state management.
|
||||
|
||||
Attributes:
|
||||
- model (LanguageModelConfig): The configuration for the language model.
|
||||
- bs_per_device (float): The batch size per device.
|
||||
- load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming parameters during model loading.
|
||||
- load_exclude_rules (Optional[list[str]]): Rules for excluding parameters during model loading.
|
||||
- rng_seed (int): The initial random number generator seed.
|
||||
- transform_forward (bool): Flag to determine whether to transform the forward function with Haiku.
|
||||
- checkpoint_path (str): The path to the directory containing model checkpoints for loading.
|
||||
|
||||
Methods:
|
||||
- make_forward_fn: Creates the forward function for the model, potentially transforming it with Haiku.
|
||||
- initialize: Initializes the model runner with data and mesh configuration, preparing it for execution.
|
||||
- init: Initializes the model parameters using provided data.
|
||||
- get_state_sharding: Determines the sharding configuration for model parameters based on the initialization data.
|
||||
- load_or_init: Loads the model from a checkpoint or initializes it if the checkpoint is not available or not specified.
|
||||
"""
|
||||
model: LanguageModelConfig
|
||||
|
||||
bs_per_device: float = 2.0
|
||||
@ -148,6 +242,18 @@ class ModelRunner:
|
||||
checkpoint_path: str = ""
|
||||
|
||||
def make_forward_fn(self, mesh: Any):
|
||||
"""
|
||||
Creates and optionally transforms the forward function of the model. This method constructs a forward
|
||||
function that takes input tokens and produces model outputs. If `transform_forward` is set to True, the
|
||||
method also transforms this function using Haiku's transform method to prepare it for JIT compilation
|
||||
and execution on the mesh.
|
||||
|
||||
Parameters:
|
||||
mesh (Any): The mesh configuration to be used for distributing the computation across devices.
|
||||
|
||||
Returns:
|
||||
Callable: The forward function, potentially transformed by Haiku, ready for execution.
|
||||
"""
|
||||
def forward(tokens):
|
||||
out = self.model.make(mesh=mesh)(tokens)
|
||||
return out, None
|
||||
@ -162,6 +268,20 @@ class ModelRunner:
|
||||
local_mesh_config: tuple[int, int],
|
||||
between_hosts_config: tuple[int, int],
|
||||
):
|
||||
"""
|
||||
Initializes the model runner with necessary configurations and data. This includes setting up the mesh
|
||||
for distributed computation, calculating batch sizes based on the provided configuration, and preparing
|
||||
the forward function for execution.
|
||||
|
||||
Parameters:
|
||||
init_data: Initial data used for model initialization. This is typically dummy data matching the
|
||||
expected input format and dimensions of the model.
|
||||
local_mesh_config (tuple[int, int]): The configuration of the local mesh, specifying how devices
|
||||
are organized locally for parallel computation.
|
||||
between_hosts_config (tuple[int, int]): The configuration for communication between hosts in a
|
||||
distributed setup, important for multi-host environments.
|
||||
|
||||
"""
|
||||
num_replicas = math.prod(between_hosts_config)
|
||||
self.model.initialize()
|
||||
self.model.fprop_dtype = jnp.bfloat16
|
||||
@ -191,12 +311,37 @@ class ModelRunner:
|
||||
self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding)
|
||||
|
||||
def init(self, rng: jax.Array, data) -> TrainingState:
|
||||
"""
|
||||
Initializes the model parameters using provided data and a random number generator state. This method
|
||||
is only called when `transform_forward` is True, indicating that the forward function has been transformed
|
||||
by Haiku and requires explicit initialization.
|
||||
|
||||
Parameters:
|
||||
rng (jax.Array): The random number generator state for initializing model parameters.
|
||||
data: The initial data used for parameter initialization, usually including input tokens and
|
||||
possibly target tokens for supervised learning tasks.
|
||||
|
||||
Returns:
|
||||
TrainingState: An instance of TrainingState containing the initialized model parameters.
|
||||
"""
|
||||
assert self.transform_forward
|
||||
rng, init_rng = jax.random.split(rng)
|
||||
params = self.forward.init(init_rng, data["inputs"])
|
||||
return TrainingState(params=params)
|
||||
|
||||
def get_state_sharding(self, init_data):
|
||||
"""
|
||||
Determines the sharding configuration for model parameters based on initialization data. This method
|
||||
evaluates the shape of model parameters during initialization to apply the partition rules specified
|
||||
in the model's configuration, facilitating efficient distributed computation.
|
||||
|
||||
Parameters:
|
||||
init_data: The initial data used for model initialization, which helps determine the appropriate
|
||||
sharding configuration based on model parameter shapes.
|
||||
|
||||
Returns:
|
||||
A sharding configuration that specifies how model parameters should be partitioned across the mesh.
|
||||
"""
|
||||
assert self.transform_forward
|
||||
rng = jax.random.PRNGKey(self.rng_seed)
|
||||
rank_logger.info(f"partition rules: {self.model.partition_rules}")
|
||||
@ -215,6 +360,20 @@ class ModelRunner:
|
||||
from_checkpoint: bool = True,
|
||||
init_fn: Optional[Callable] = None,
|
||||
):
|
||||
"""
|
||||
Loads model parameters from a checkpoint or initializes them if the checkpoint is not available. This
|
||||
method provides flexibility in starting the model from a known state or freshly initializing parameters
|
||||
based on the model configuration and provided initial data.
|
||||
|
||||
Parameters:
|
||||
init_data: Initial data used for model parameter initialization if needed.
|
||||
from_checkpoint (bool): Flag indicating whether to attempt loading parameters from a checkpoint.
|
||||
init_fn (Optional[Callable]): An optional initialization function that overrides the default
|
||||
initialization method if provided.
|
||||
|
||||
Returns:
|
||||
The model state, either loaded from a checkpoint or newly initialized.
|
||||
"""
|
||||
rng = jax.random.PRNGKey(self.rng_seed)
|
||||
|
||||
if not self.checkpoint_path or not from_checkpoint:
|
||||
@ -251,6 +410,16 @@ class ModelRunner:
|
||||
|
||||
@dataclass
|
||||
class Request:
|
||||
"""
|
||||
Data class for storing information about a single sampling request to the model.
|
||||
|
||||
Attributes:
|
||||
prompt (str): The text prompt to generate text from.
|
||||
temperature (float): Controls the randomness of the output. Lower values result in more deterministic outputs.
|
||||
nucleus_p (float): The cumulative probability threshold for nucleus sampling, focusing generation on high-probability tokens.
|
||||
rng_seed (int): Seed for the random number generator to ensure reproducibility of the results.
|
||||
max_len (int): The maximum length of the generated text sequence.
|
||||
"""
|
||||
prompt: str
|
||||
temperature: float
|
||||
nucleus_p: float
|
||||
@ -260,6 +429,24 @@ class Request:
|
||||
|
||||
@dataclass
|
||||
class InferenceRunner:
|
||||
"""
|
||||
Manages inference operations for generating text based on prompts, including initializing the model and tokenizer,
|
||||
prefilling the model memory, and running the actual text generation.
|
||||
|
||||
Attributes:
|
||||
name (str): A name identifier for the inference runner.
|
||||
runner (Any): An instance of ModelRunner that manages the underlying language model.
|
||||
load (str): A path to the model checkpoint for loading.
|
||||
tokenizer_path (str): Path to the SentencePiece tokenizer model file.
|
||||
local_mesh_config (Tuple[int, int]): Configuration for the local device mesh.
|
||||
between_hosts_config (Tuple[int, int]): Configuration for communication between hosts in a distributed setup.
|
||||
pad_sizes (tuple[int]): A tuple of padding sizes for processing different lengths of prompts.
|
||||
|
||||
Methods:
|
||||
get_pad_bucket: Determines the appropriate padding size for a given input length.
|
||||
initialize: Initializes the model runner, tokenizer, and pre-compiles model functions for different input sizes.
|
||||
run: A generator method that accepts prompts and returns generated text, managing batch slots and sampling steps.
|
||||
"""
|
||||
name: str
|
||||
runner: Any
|
||||
load: str
|
||||
@ -269,10 +456,25 @@ class InferenceRunner:
|
||||
pad_sizes: tuple[int] = (1024,)
|
||||
|
||||
def get_pad_bucket(self, size):
|
||||
"""
|
||||
Determines the appropriate padding size for a given input length. This method uses the predefined
|
||||
pad sizes to find the smallest pad size that is larger than or equal to the given size.
|
||||
|
||||
Parameters:
|
||||
size (int): The length of the input sequence to be padded.
|
||||
|
||||
Returns:
|
||||
int: The selected pad size from the predefined pad sizes that best fits the input length.
|
||||
"""
|
||||
i = bisect.bisect_left(self.pad_sizes, size)
|
||||
return self.pad_sizes[min(i, len(self.pad_sizes) - 1)]
|
||||
|
||||
def initialize(self):
|
||||
"""
|
||||
Initializes the inference runner by setting up the model runner, loading the tokenizer, pre-compiling
|
||||
model functions for different input sizes, and loading or initializing model parameters. This method
|
||||
ensures that the inference runner is ready to generate text based on prompts.
|
||||
"""
|
||||
runner = self.runner
|
||||
self.runner.transform_forward = True
|
||||
dummy_data = dict(
|
||||
@ -440,7 +642,15 @@ class InferenceRunner:
|
||||
)
|
||||
|
||||
def run(self):
|
||||
"""Generator that accepts prompts."""
|
||||
"""
|
||||
Generator method that accepts prompts and manages the text generation process. This method handles
|
||||
tokenization of prompts, setup of sampling settings, and orchestration of the sampling steps to
|
||||
generate and return the generated text.
|
||||
|
||||
Yields:
|
||||
str: The generated text for each prompt received. This method operates as a coroutine, yielding
|
||||
generated text back to the caller and awaiting new prompts for further generation.
|
||||
"""
|
||||
runner = self.runner
|
||||
mesh = runner.mesh
|
||||
max_len = runner.model.sequence_len
|
||||
@ -580,6 +790,18 @@ class InferenceRunner:
|
||||
def make_mesh(
|
||||
local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...]
|
||||
) -> jax.sharding.Mesh:
|
||||
"""
|
||||
Creates a JAX mesh from the provided configuration, which is used for parallel computation across devices.
|
||||
|
||||
Parameters:
|
||||
local_mesh_config (tuple[int, ...]): A tuple specifying the local mesh configuration. This usually corresponds
|
||||
to how local devices are arranged and partitioned for computation.
|
||||
between_hosts_config (tuple[int, ...]): A tuple specifying the mesh configuration for communication between
|
||||
different hosts, relevant in multi-host distributed setups.
|
||||
|
||||
Returns:
|
||||
jax.sharding.Mesh: A mesh object that encapsulates the device topology and partitioning for parallel computation.
|
||||
"""
|
||||
assert len(local_mesh_config) == 2
|
||||
assert len(between_hosts_config) == 2
|
||||
rank_logger.info("Detected %s devices in mesh", jax.device_count())
|
||||
@ -594,6 +816,32 @@ def make_mesh(
|
||||
|
||||
|
||||
def sample_from_model(server, prompt, max_len, temperature):
|
||||
"""
|
||||
Samples tokens from a trained model given a prompt and sampling settings. This function is designed
|
||||
to interact with a generator-based server that manages the state of the model and performs the actual
|
||||
sampling. The server is expected to yield control back to this function, which then sends a request
|
||||
object containing the prompt and settings, and waits for the generated output.
|
||||
|
||||
The function initializes the server (if not already done), constructs the request with the provided
|
||||
parameters, and sends this request to the server. The sampled output, typically a continuation of the
|
||||
prompt, is then received from the server.
|
||||
|
||||
Parameters:
|
||||
- server (generator): The generator-based server that handles model inference. This server is expected
|
||||
to manage the lifecycle of the model, including loading parameters, setting up the environment, and
|
||||
performing the token sampling.
|
||||
- prompt (str): The initial text prompt to which the model generates a continuation. This prompt is
|
||||
encoded and fed into the model as part of the request.
|
||||
- max_len (int): The maximum length of the generated sequence, including the length of the prompt.
|
||||
This controls how long the model's output can be.
|
||||
- temperature (float): A parameter controlling the randomness of the output. Lower values make the
|
||||
model's outputs more deterministic (and potentially more repetitive), while higher values increase
|
||||
diversity at the cost of coherence.
|
||||
|
||||
Returns:
|
||||
- str: The generated text as a continuation of the provided prompt, up to a maximum length specified
|
||||
by `max_len`.
|
||||
"""
|
||||
next(server)
|
||||
inp = Request(
|
||||
prompt=prompt,
|
||||
|
Loading…
Reference in New Issue
Block a user