Updated docstrings for runners.py

Carlos D. Escobar-Valbuena 2024-03-18 16:15:48 -05:00 committed by GitHub
parent 3f6fb5f4aa
commit 2d92966a1f


@@ -48,6 +48,20 @@ TOP_K = 8
class SampleSettings(NamedTuple):
"""
A NamedTuple for storing settings used during the sampling process.
Attributes:
- temperature (ArrayLike): The temperature controls the randomness of the output. Lower values make
the model outputs more deterministic, while higher values increase diversity.
- nucleus_p (ArrayLike): The nucleus probability controls the cumulative distribution function of
token probabilities, effectively truncating low-probability tokens to focus sampling on a subset
of plausible tokens.
- mask (ArrayLike): A binary mask indicating which tokens are allowed (1) and which are disallowed (0)
for sampling in each batch element.
- active (ArrayLike): A binary indicator for whether a given batch element is actively being used for
generation. Elements not in use are ignored in computations to save resources.
"""
temperature: ArrayLike
nucleus_p: ArrayLike
mask: ArrayLike
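For illustration, a settings tuple for a batch of two prompts might be built as follows. This is a hedged sketch: the per-element shapes follow the `Input [B]` comments in `sample_token` below, and `vocab_size` is a stand-in value, not something defined in this file.

import jax.numpy as jnp

vocab_size = 131072  # hypothetical vocabulary size, for illustration only
settings = SampleSettings(
    temperature=jnp.array([0.7, 1.0]),                # one temperature per batch element, shape [B]
    nucleus_p=jnp.array([0.95, 0.9]),                 # one top-p threshold per batch element, shape [B]
    mask=jnp.ones((2, vocab_size), dtype=jnp.int32),  # 1 = token allowed, 0 = token disallowed
    active=jnp.array([1, 1], dtype=jnp.int32),        # both batch slots are currently generating
)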
@@ -56,6 +70,15 @@ class SampleSettings(NamedTuple):
class SampleOutput(NamedTuple):
"""
A NamedTuple for storing the output from the sampling process.
Attributes:
- token_id (ArrayLike): The sampled token ID for each batch element.
- prob (ArrayLike): The probability associated with the sampled token for each batch element.
- top_k_token_ids (ArrayLike): The IDs of the top-k most probable tokens for each batch element.
- top_k_probs (ArrayLike): The probabilities associated with the top-k token IDs for each batch element.
"""
token_id: ArrayLike
prob: ArrayLike
top_k_token_ids: ArrayLike
@@ -63,6 +86,19 @@ class SampleOutput(NamedTuple):
def insert_slice(memory: Memory, slice, length, i):
"""
Inserts a slice of KVMemory into a Memory object at a specified index. This function updates the Memory
object with a new slice that includes updated steps for each layer based on the provided length.
Parameters:
- memory (Memory): The original memory object to be updated.
- slice: A slice of the memory to be inserted, typically representing a new or modified piece of information.
- length (int): The step length to set for each KVMemory layer in the inserted slice.
- i (int): The index at which the slice is to be inserted into the memory.
Returns:
- Memory: The updated memory object with the new slice inserted at the specified index.
"""
slice = Memory(
layers=[
KVMemory(layer.k, layer.v, step=jnp.array([length]))
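A hedged usage sketch: after prefilling a prompt into a temporary KV cache, that slice is written into one batch slot of the running cache. The names `prompt_memory`, `prompt_len`, and `slot` are illustrative and do not appear in this file.

# prompt_memory holds the KV cache produced while prefilling one prompt;
# writing it into slot `slot` makes that batch element resume from step `prompt_len`.
memory = insert_slice(memory, prompt_memory, length=prompt_len, i=slot)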
@@ -75,6 +111,17 @@ def insert_slice(memory: Memory, slice, length, i):
def pad_to_size(x, size):
"""
Pads or truncates an array to a specified size. If the array is longer than the target size, it will be
truncated from the left. If it is shorter, it will be right-padded with zeros.
Parameters:
- x (ArrayLike): The array to be padded or truncated.
- size (int): The target size for the array.
Returns:
- ArrayLike: The array adjusted to the specified size.
"""
if x.shape[0] > size:
# Left truncate if the context is too long.
x = x[-size:]
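A small worked example of the behaviour described above, assuming the short case is right-padded with zeros as the docstring states:

import numpy as np

pad_to_size(np.array([5, 6, 7]), 5)   # shorter than 5 -> padded to [5, 6, 7, 0, 0]
pad_to_size(np.arange(8), 5)          # longer than 5 -> left-truncated to [3, 4, 5, 6, 7]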
@@ -82,7 +129,20 @@ def pad_to_size(x, size):
def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:
"""
Performs nucleus filtering on logits.
Filters logits to retain only a subset that corresponds to the top-p cumulative probability. This
nucleus filtering method is used to focus the sampling process on a subset of plausible tokens, improving
the quality of generated text.
Parameters:
- logits (jax.Array): The logits to filter.
- top_p (jax.Array): The cumulative probability threshold for filtering logits.
Returns:
- jax.Array: The filtered logits with low-probability tokens set to -inf, effectively removing them
from consideration during sampling.
"""
assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}"
sorted_logits = jax.lax.sort(logits, is_stable=False)
sorted_probs = jax.nn.softmax(sorted_logits)
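For example (a sketch, not part of the diff), filtering a single batch row of logits so that only tokens inside the top 90% of probability mass survive:

import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0, 0.5, -1.0]])   # shape [1, vocab], a single batch row
top_p = jnp.full_like(logits, 0.9)            # same ndim as logits, as the assert above requires
filtered = top_p_filter(logits, top_p)        # tokens outside the top 90% of mass become -inf
probs = jax.nn.softmax(filtered)              # renormalise over the surviving tokens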
@@ -102,6 +162,21 @@ def sample_token(
lm_outputs: LanguageModelOutput,
settings: SampleSettings,
) -> SampleOutput:
"""
Samples a token from the language model outputs using specified settings, including temperature and
nucleus probability filtering. This function also computes the probability of the sampled token and
identifies the top-k tokens and their probabilities.
Parameters:
- rngs (jax.random.PRNGKey): The random number generator state for sampling.
- lm_outputs (LanguageModelOutput): The outputs from the language model, including logits.
- settings (SampleSettings): The settings controlling the sampling process, including temperature,
nucleus probability, token masking, and active batch elements indicator.
Returns:
- SampleOutput: The results of the sampling process, including the sampled token ID, its probability,
and the top-k tokens and their probabilities for each batch element.
"""
# Expand the settings shape to match the logit shape.
settings = SampleSettings(
temperature=jnp.expand_dims(settings.temperature, (1, 2)),  # Input [B], output [B, 1, 1].
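A hedged sketch of how the pieces fit together, including which fields of the resulting SampleOutput are typically read. The names `batch_size`, `lm_outputs`, and `settings` are illustrative placeholders, and the per-element RNG split is an assumption suggested by the plural parameter name.

import jax

rngs = jax.random.split(jax.random.PRNGKey(0), batch_size)  # one key per batch element (assumed)
out = sample_token(rngs, lm_outputs, settings)               # lm_outputs is the model's forward-pass output
next_token = out.token_id          # sampled token id per batch element
confidence = out.prob              # probability assigned to that token
candidates = out.top_k_token_ids   # the TOP_K most likely alternatives...
scores = out.top_k_probs           # ...and their probabilities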
@@ -135,6 +210,25 @@ def sample_token(
@dataclass
class ModelRunner:
"""
Manages the execution of a language model, including initialization, forward passes, and state management.
Attributes:
- model (LanguageModelConfig): The configuration for the language model.
- bs_per_device (float): The batch size per device.
- load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming parameters during model loading.
- load_exclude_rules (Optional[list[str]]): Rules for excluding parameters during model loading.
- rng_seed (int): The initial random number generator seed.
- transform_forward (bool): Flag to determine whether to transform the forward function with Haiku.
- checkpoint_path (str): The path to the directory containing model checkpoints for loading.
Methods:
- make_forward_fn: Creates the forward function for the model, potentially transforming it with Haiku.
- initialize: Initializes the model runner with data and mesh configuration, preparing it for execution.
- init: Initializes the model parameters using provided data.
- get_state_sharding: Determines the sharding configuration for model parameters based on the initialization data.
- load_or_init: Loads the model from a checkpoint or initializes it if the checkpoint is not available or not specified.
"""
model: LanguageModelConfig
bs_per_device: float = 2.0
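A rough sketch of how the attributes listed above are typically supplied; `my_lm_config` is a placeholder for a LanguageModelConfig whose construction is outside this diff, and the path is illustrative.

runner = ModelRunner(
    model=my_lm_config,                # a LanguageModelConfig instance (placeholder)
    bs_per_device=2.0,                 # the default batch size per device
    rng_seed=42,
    transform_forward=True,            # wrap the forward pass with Haiku so init/apply are available
    checkpoint_path="./checkpoints/",  # illustrative checkpoint directory
)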
@@ -148,6 +242,18 @@ class ModelRunner:
checkpoint_path: str = ""
def make_forward_fn(self, mesh: Any):
"""
Creates and optionally transforms the forward function of the model. This method constructs a forward
function that takes input tokens and produces model outputs. If `transform_forward` is set to True, the
method also transforms this function using Haiku's transform method to prepare it for JIT compilation
and execution on the mesh.
Parameters:
mesh (Any): The mesh configuration to be used for distributing the computation across devices.
Returns:
Callable: The forward function, potentially transformed by Haiku, ready for execution.
"""
def forward(tokens):
out = self.model.make(mesh=mesh)(tokens)
return out, None
@@ -162,6 +268,20 @@ class ModelRunner:
local_mesh_config: tuple[int, int],
between_hosts_config: tuple[int, int],
):
"""
Initializes the model runner with necessary configurations and data. This includes setting up the mesh
for distributed computation, calculating batch sizes based on the provided configuration, and preparing
the forward function for execution.
Parameters:
init_data: Initial data used for model initialization. This is typically dummy data matching the
expected input format and dimensions of the model.
local_mesh_config (tuple[int, int]): The configuration of the local mesh, specifying how devices
are organized locally for parallel computation.
between_hosts_config (tuple[int, int]): The configuration for communication between hosts in a
distributed setup, important for multi-host environments.
"""
num_replicas = math.prod(between_hosts_config)
self.model.initialize()
self.model.fprop_dtype = jnp.bfloat16
@@ -191,12 +311,37 @@ class ModelRunner:
self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding)
def init(self, rng: jax.Array, data) -> TrainingState:
"""
Initializes the model parameters using provided data and a random number generator state. This method
is only called when `transform_forward` is True, indicating that the forward function has been transformed
by Haiku and requires explicit initialization.
Parameters:
rng (jax.Array): The random number generator state for initializing model parameters.
data: The initial data used for parameter initialization, usually including input tokens and
possibly target tokens for supervised learning tasks.
Returns:
TrainingState: An instance of TrainingState containing the initialized model parameters.
"""
assert self.transform_forward
rng, init_rng = jax.random.split(rng)
params = self.forward.init(init_rng, data["inputs"])
return TrainingState(params=params)
def get_state_sharding(self, init_data):
"""
Determines the sharding configuration for model parameters based on initialization data. This method
evaluates the shape of model parameters during initialization to apply the partition rules specified
in the model's configuration, facilitating efficient distributed computation.
Parameters:
init_data: The initial data used for model initialization, which helps determine the appropriate
sharding configuration based on model parameter shapes.
Returns:
A sharding configuration that specifies how model parameters should be partitioned across the mesh.
"""
assert self.transform_forward
rng = jax.random.PRNGKey(self.rng_seed)
rank_logger.info(f"partition rules: {self.model.partition_rules}")
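The general pattern behind this method (a hedged sketch of the idea, not the repository's exact code) is to trace the init function abstractly so parameter shapes are known without allocating device memory, and then match those shapes against the partition rules; `rng` and `init_data` are placeholders here.

import jax

# Shapes and dtypes only, no device memory is allocated during this trace.
abstract_state = jax.eval_shape(runner.init, rng, init_data)
# The resulting shapes are then mapped to partition specs via runner.model.partition_rules.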
@@ -215,6 +360,20 @@ class ModelRunner:
from_checkpoint: bool = True,
init_fn: Optional[Callable] = None,
):
"""
Loads model parameters from a checkpoint or initializes them if the checkpoint is not available. This
method provides flexibility in starting the model from a known state or freshly initializing parameters
based on the model configuration and provided initial data.
Parameters:
init_data: Initial data used for model parameter initialization if needed.
from_checkpoint (bool): Flag indicating whether to attempt loading parameters from a checkpoint.
init_fn (Optional[Callable]): An optional initialization function that overrides the default
initialization method if provided.
Returns:
The model state, either loaded from a checkpoint or newly initialized.
"""
rng = jax.random.PRNGKey(self.rng_seed)
if not self.checkpoint_path or not from_checkpoint:
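Putting the ModelRunner methods together, a typical call sequence looks roughly like this; `dummy_data` stands in for the tokenised initialisation batch and the mesh shapes are illustrative.

runner.initialize(dummy_data, local_mesh_config=(1, 8), between_hosts_config=(1, 1))  # illustrative mesh shapes
state = runner.load_or_init(dummy_data)  # restores from checkpoint_path if set, otherwise initialises fresh params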
@@ -251,6 +410,16 @@ class ModelRunner:
@dataclass
class Request:
"""
Data class for storing information about a single sampling request to the model.
Attributes:
prompt (str): The text prompt to generate text from.
temperature (float): Controls the randomness of the output. Lower values result in more deterministic outputs.
nucleus_p (float): The cumulative probability threshold for nucleus sampling, focusing generation on high-probability tokens.
rng_seed (int): Seed for the random number generator to ensure reproducibility of the results.
max_len (int): The maximum length of the generated text sequence.
"""
prompt: str
temperature: float
nucleus_p: float
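For example, a single sampling request might look like this (the prompt text and values are purely illustrative):

request = Request(
    prompt="The answer to life, the universe, and everything is",
    temperature=0.7,   # moderately deterministic output
    nucleus_p=0.95,    # keep the top 95% of probability mass
    rng_seed=42,       # reproducible sampling
    max_len=256,       # cap on the generated sequence length
)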
@@ -260,6 +429,24 @@ class Request:
@dataclass
class InferenceRunner:
"""
Manages inference operations for generating text based on prompts, including initializing the model and tokenizer,
prefilling the model memory, and running the actual text generation.
Attributes:
name (str): A name identifier for the inference runner.
runner (Any): An instance of ModelRunner that manages the underlying language model.
load (str): A path to the model checkpoint for loading.
tokenizer_path (str): Path to the SentencePiece tokenizer model file.
local_mesh_config (Tuple[int, int]): Configuration for the local device mesh.
between_hosts_config (Tuple[int, int]): Configuration for communication between hosts in a distributed setup.
pad_sizes (tuple[int]): A tuple of padding sizes for processing different lengths of prompts.
Methods:
get_pad_bucket: Determines the appropriate padding size for a given input length.
initialize: Initializes the model runner, tokenizer, and pre-compiles model functions for different input sizes.
run: A generator method that accepts prompts and returns generated text, managing batch slots and sampling steps.
"""
name: str
runner: Any
load: str
@@ -269,10 +456,25 @@ class InferenceRunner:
pad_sizes: tuple[int] = (1024,)
def get_pad_bucket(self, size):
"""
Determines the appropriate padding size for a given input length. This method uses the predefined
pad sizes to find the smallest pad size that is larger than or equal to the given size.
Parameters:
size (int): The length of the input sequence to be padded.
Returns:
int: The selected pad size from the predefined pad sizes that best fits the input length.
"""
i = bisect.bisect_left(self.pad_sizes, size)
return self.pad_sizes[min(i, len(self.pad_sizes) - 1)]
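For instance, with an illustrative bucket configuration of (256, 512, 1024) (the default above is just (1024,)), bisect_left picks the first bucket that can hold the input; this standalone sketch mirrors the two lines above:

import bisect

pad_sizes = (256, 512, 1024)  # illustrative buckets
def pick(size):
    i = bisect.bisect_left(pad_sizes, size)
    return pad_sizes[min(i, len(pad_sizes) - 1)]

pick(100)    # -> 256, smallest bucket that fits
pick(512)    # -> 512, exact fit
pick(2000)   # -> 1024, over-long inputs fall into the largest bucket (and are later left-truncated)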
def initialize(self):
"""
Initializes the inference runner by setting up the model runner, loading the tokenizer, pre-compiling
model functions for different input sizes, and loading or initializing model parameters. This method
ensures that the inference runner is ready to generate text based on prompts.
"""
runner = self.runner
self.runner.transform_forward = True
dummy_data = dict(
@@ -440,7 +642,15 @@ class InferenceRunner:
)
def run(self):
"""
Generator method that accepts prompts and manages the text generation process. This method handles
tokenization of prompts, setup of sampling settings, and orchestration of the sampling steps to
generate and return the generated text.
Yields:
str: The generated text for each prompt received. This method operates as a coroutine, yielding
generated text back to the caller and awaiting new prompts for further generation.
"""
runner = self.runner
mesh = runner.mesh
max_len = runner.model.sequence_len
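The coroutine protocol implied by this docstring mirrors what sample_from_model does at the bottom of this file; a sketch of the caller's side, with `inference_runner` standing for an initialised InferenceRunner:

gen = inference_runner.run()   # create the generator
next(gen)                      # prime it so it waits at its first yield for a request
text = gen.send(request)       # submit a Request, receive the generated continuation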
@@ -580,6 +790,18 @@ class InferenceRunner:
def make_mesh(
local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...]
) -> jax.sharding.Mesh:
"""
Creates a JAX mesh from the provided configuration, which is used for parallel computation across devices.
Parameters:
local_mesh_config (tuple[int, ...]): A tuple specifying the local mesh configuration. This usually corresponds
to how local devices are arranged and partitioned for computation.
between_hosts_config (tuple[int, ...]): A tuple specifying the mesh configuration for communication between
different hosts, relevant in multi-host distributed setups.
Returns:
jax.sharding.Mesh: A mesh object that encapsulates the device topology and partitioning for parallel computation.
"""
assert len(local_mesh_config) == 2
assert len(between_hosts_config) == 2
rank_logger.info("Detected %s devices in mesh", jax.device_count())
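A hedged usage sketch: on a single host with eight accelerators, one plausible call is

mesh = make_mesh(local_mesh_config=(2, 4), between_hosts_config=(1, 1))  # 2x4 local devices, single host
with mesh:
    ...  # pjit-compiled calls issued in this context are placed along the mesh axes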
@@ -594,6 +816,32 @@ def make_mesh(
def sample_from_model(server, prompt, max_len, temperature):
"""
Samples tokens from a trained model given a prompt and sampling settings. This function is designed
to interact with a generator-based server that manages the state of the model and performs the actual
sampling. The server is expected to yield control back to this function, which then sends a request
object containing the prompt and settings, and waits for the generated output.
The function initializes the server (if not already done), constructs the request with the provided
parameters, and sends this request to the server. The sampled output, typically a continuation of the
prompt, is then received from the server.
Parameters:
- server (generator): The generator-based server that handles model inference. This server is expected
to manage the lifecycle of the model, including loading parameters, setting up the environment, and
performing the token sampling.
- prompt (str): The initial text prompt to which the model generates a continuation. This prompt is
encoded and fed into the model as part of the request.
- max_len (int): The maximum length of the generated sequence, including the length of the prompt.
This controls how long the model's output can be.
- temperature (float): A parameter controlling the randomness of the output. Lower values make the
model's outputs more deterministic (and potentially more repetitive), while higher values increase
diversity at the cost of coherence.
Returns:
- str: The generated text as a continuation of the provided prompt, up to a maximum length specified
by `max_len`.
"""
next(server)
inp = Request(
prompt=prompt,
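End to end, a typical caller (for instance a CLI entry point) would therefore do something like the following; `inference_runner` stands for an initialised InferenceRunner and the prompt is illustrative:

inference_runner.initialize()                 # load or init weights and pre-compile the sampling functions
server = inference_runner.run()               # the generator-based "server" referred to above
output = sample_from_model(server, "The capital of France is", max_len=128, temperature=0.01)
print(output)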