This commit is contained in:
Carlos D. Escobar-Valbuena 2024-03-18 16:53:36 -05:00 committed by GitHub
commit 80c6ed8d33
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 1095 additions and 20 deletions

View File

@ -41,6 +41,17 @@ sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit
@contextlib.contextmanager @contextlib.contextmanager
def copy_to_shm(file: str): def copy_to_shm(file: str):
"""
Context manager for copying a file to shared memory. If the file is already in shared memory (/dev/shm),
yields the same file path. Otherwise, copies the file to a temporary file in shared memory, yields the path
to the temporary file, and cleans up by removing the temporary file after use.
Parameters:
- file (str): The path to the file to be copied.
Yields:
- str: The path to the file in shared memory.
"""
if file.startswith("/dev/shm/"): if file.startswith("/dev/shm/"):
# Nothing to do, the file is already in shared memory. # Nothing to do, the file is already in shared memory.
yield file yield file
@ -58,6 +69,17 @@ def copy_to_shm(file: str):
@contextlib.contextmanager @contextlib.contextmanager
def copy_from_shm(file: str): def copy_from_shm(file: str):
"""
Context manager for copying a file from shared memory to a specified path. It creates a temporary file
in shared memory, yields the path to this temporary file for operations, and then copies the temporary
file to the specified path, cleaning up the temporary file afterwards.
Parameters:
- file (str): The target path for the file to be copied to from shared memory.
Yields:
- str: The path to the temporary file in shared memory.
"""
tmp_dir = "/dev/shm/" tmp_dir = "/dev/shm/"
fd, tmp_path = tempfile.mkstemp(dir=tmp_dir) fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
try: try:
@ -69,19 +91,48 @@ def copy_from_shm(file: str):
def fast_unpickle(path: str) -> Any: def fast_unpickle(path: str) -> Any:
"""
Unpickles and loads an object from a file, optionally using shared memory for faster access.
Parameters:
- path (str): The path to the pickle file to load.
Returns:
- Any: The object loaded from the pickle file.
"""
with copy_to_shm(path) as tmp_path: with copy_to_shm(path) as tmp_path:
with open(tmp_path, "rb") as f: with open(tmp_path, "rb") as f:
return pickle.load(f) return pickle.load(f)
def fast_pickle(obj: Any, path: str) -> None: def fast_pickle(obj: Any, path: str) -> None:
"""
Pickles and saves an object to a file, optionally using shared memory for faster access.
Parameters:
- obj (Any): The object to be pickled and saved.
- path (str): The path where the pickle file should be saved.
"""
with copy_from_shm(path) as tmp_path: with copy_from_shm(path) as tmp_path:
with open(tmp_path, "wb") as f: with open(tmp_path, "wb") as f:
pickle.dump(obj, f) pickle.dump(obj, f)
def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None): def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
"""Loads a set of arrays.""" """
Loads tensors from files in parallel using a ThreadPoolExecutor. This function is intended for use with
arrays that have a predefined shape and dtype, loading them from a directory where each tensor is saved
in a separate file.
Parameters:
- shaped_arrays: A sequence of arrays providing the shapes and dtypes of the tensors to load.
- directory (str): The directory from which to load the tensors.
- mesh_config: Configuration for data parallelism across processes or devices.
- tensor_indices (Optional): Specific indices of tensors to load. If None, all tensors are loaded.
Returns:
- List of numpy arrays or zeros depending on the process's role in data parallelism.
"""
pool = ThreadPoolExecutor(max_workers=32) pool = ThreadPoolExecutor(max_workers=32)
fs = list() fs = list()
num_tensors = 0 num_tensors = 0
@ -108,6 +159,17 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
def path_tuple_to_string(path: tuple) -> str: def path_tuple_to_string(path: tuple) -> str:
"""
Converts a tuple representing a path in a nested structure to a string. This function is specifically
used for handling paths within the structure of saved states or parameters, making it easier to
identify and manipulate specific elements.
Parameters:
- path (tuple): A tuple representing a path in a nested structure.
Returns:
- str: A string representation of the path, suitable for logging or identification.
"""
pieces = [] pieces = []
for elem in path: for elem in path:
if isinstance(elem, jax.tree_util.DictKey): if isinstance(elem, jax.tree_util.DictKey):
@ -124,6 +186,19 @@ def get_load_path_str(
load_rename_rules: Optional[list[tuple[str, str]]] = None, load_rename_rules: Optional[list[tuple[str, str]]] = None,
load_exclude_rules: Optional[list[str]] = None, load_exclude_rules: Optional[list[str]] = None,
) -> Optional[str]: ) -> Optional[str]:
"""
Determines the load path for a given initial path string, applying exclusion and renaming rules. This
function is used in the context of loading states or parameters from files, allowing for flexible
mapping or exclusion based on pattern matching.
Parameters:
- init_path_str (str): The initial path string to process.
- load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming paths based on pattern matching.
- load_exclude_rules (Optional[list[str]]): Patterns for paths to exclude from loading.
Returns:
- Optional[str]: The processed load path string, or None if excluded.
"""
# Exclusion # Exclusion
if load_exclude_rules is not None: if load_exclude_rules is not None:
for search_pattern in load_exclude_rules: for search_pattern in load_exclude_rules:
@ -148,6 +223,21 @@ def replace_with_load_state(
load_exclude_rules: Optional[list[str]] = None, load_exclude_rules: Optional[list[str]] = None,
mesh_config: tuple = (1, 1), mesh_config: tuple = (1, 1),
) -> Any: ) -> Any:
"""
Replaces elements of an initial state with elements from a loaded state, applying renaming and exclusion
rules. This function supports conditional inclusion and transformation of state elements based on complex
criteria, facilitating flexible state restoration.
Parameters:
- init_state (Any): The initial state before replacement.
- load_state (Any): The state from which to load replacements.
- load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming paths.
- load_exclude_rules (Optional[list[str]]): Patterns for paths to exclude from loading.
- mesh_config (tuple): Configuration for data parallelism.
Returns:
- Any: The initial state with elements replaced from the load state.
"""
flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state) flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state) flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load} load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
@ -186,6 +276,23 @@ def restore(
state_sharding, state_sharding,
init_state: Optional[Any] = None, init_state: Optional[Any] = None,
) -> Any: ) -> Any:
"""
Restores the state from a checkpoint, optionally focusing on parameters only, and applies sharding
configurations. This function is designed for restoring model states from disk with support for distributed
environments, handling the intricacies of partitioning and host-specific configurations.
Parameters:
- checkpoint_path (str): The path to the checkpoint directory.
- state_shapes (Any): The expected shapes of the state to restore.
- mesh: The mesh configuration for distributed environments.
- between_hosts_config: Configuration for data exchange between hosts.
- params_only (bool): Whether to restore parameters only, excluding other state parts.
- state_sharding: Sharding configuration for the state.
- init_state (Optional[Any]): The initial state to which the checkpoint data is applied.
Returns:
- Any: The restored state, potentially sharded across the distributed environment.
"""
ckpt_path = os.path.join(checkpoint_path, "ckpt-0") ckpt_path = os.path.join(checkpoint_path, "ckpt-0")
rank_logger.info("Loading checkpoint at {}".format(ckpt_path)) rank_logger.info("Loading checkpoint at {}".format(ckpt_path))

733
model.py

File diff suppressed because it is too large Load Diff

21
run.py
View File

@ -22,6 +22,27 @@ CKPT_PATH = "./checkpoints/"
def main(): def main():
"""
Initializes and runs a text generation model using predefined model configurations and inference settings.
This function sets up a language model with specific configurations, including model architecture details
(e.g., embedding sizes, number of layers, attention heads, and MoE settings) and text generation settings
(e.g., vocabulary size, token identifiers). It initializes an inference runner with the model, checkpoint
path, tokenizer, and mesh configuration. The inference runner is then used to generate text based on a
given prompt and output the result.
The process involves:
- Creating a `LanguageModelConfig` instance with specified model parameters, including transformer
configurations and quantization settings for weights.
- Initializing an `InferenceRunner` with the model configuration, batch size per device, checkpoint path,
and other relevant settings.
- Calling the `initialize` method on the inference runner to prepare the model and tokenizer for inference.
- Generating text based on a provided prompt using the `sample_from_model` function, which internally
manages the sampling process through the inference runner.
Output:
- Prints the generated text continuation for a prompt to the standard output.
"""
grok_1_model = LanguageModelConfig( grok_1_model = LanguageModelConfig(
vocab_size=128 * 1024, vocab_size=128 * 1024,
pad_token=0, pad_token=0,

View File

@ -48,6 +48,20 @@ TOP_K = 8
class SampleSettings(NamedTuple): class SampleSettings(NamedTuple):
"""
A NamedTuple for storing settings used during the sampling process.
Attributes:
- temperature (ArrayLike): The temperature controls the randomness of the output. Lower values make
the model outputs more deterministic, while higher values increase diversity.
- nucleus_p (ArrayLike): The nucleus probability controls the cumulative distribution function of
token probabilities, effectively truncating low-probability tokens to focus sampling on a subset
of plausible tokens.
- mask (ArrayLike): A binary mask indicating which tokens are allowed (1) and which are disallowed (0)
for sampling in each batch element.
- active (ArrayLike): A binary indicator for whether a given batch element is actively being used for
generation. Elements not in use are ignored in computations to save resources.
"""
temperature: ArrayLike temperature: ArrayLike
nucleus_p: ArrayLike nucleus_p: ArrayLike
mask: ArrayLike mask: ArrayLike
@ -56,6 +70,15 @@ class SampleSettings(NamedTuple):
class SampleOutput(NamedTuple): class SampleOutput(NamedTuple):
"""
A NamedTuple for storing the output from the sampling process.
Attributes:
- token_id (ArrayLike): The sampled token ID for each batch element.
- prob (ArrayLike): The probability associated with the sampled token for each batch element.
- top_k_token_ids (ArrayLike): The IDs of the top-k most probable tokens for each batch element.
- top_k_probs (ArrayLike): The probabilities associated with the top-k token IDs for each batch element.
"""
token_id: ArrayLike token_id: ArrayLike
prob: ArrayLike prob: ArrayLike
top_k_token_ids: ArrayLike top_k_token_ids: ArrayLike
@ -63,6 +86,19 @@ class SampleOutput(NamedTuple):
def insert_slice(memory: Memory, slice, length, i): def insert_slice(memory: Memory, slice, length, i):
"""
Inserts a slice of KVMemory into a Memory object at a specified index. This function updates the Memory
object with a new slice that includes updated steps for each layer based on the provided length.
Parameters:
- memory (Memory): The original memory object to be updated.
- slice: A slice of the memory to be inserted, typically representing a new or modified piece of information.
- length (int): The step length to set for each KVMemory layer in the inserted slice.
- i (int): The index at which the slice is to be inserted into the memory.
Returns:
- Memory: The updated memory object with the new slice inserted at the specified index.
"""
slice = Memory( slice = Memory(
layers=[ layers=[
KVMemory(layer.k, layer.v, step=jnp.array([length])) KVMemory(layer.k, layer.v, step=jnp.array([length]))
@ -75,6 +111,17 @@ def insert_slice(memory: Memory, slice, length, i):
def pad_to_size(x, size): def pad_to_size(x, size):
"""
Pads or truncates an array to a specified size. If the array is longer than the target size, it will be
truncated from the left. If it is shorter, it will be right-padded with zeros.
Parameters:
- x (ArrayLike): The array to be padded or truncated.
- size (int): The target size for the array.
Returns:
- ArrayLike: The array adjusted to the specified size.
"""
if x.shape[0] > size: if x.shape[0] > size:
# Left truncate if the context is too long. # Left truncate if the context is too long.
x = x[-size:] x = x[-size:]
@ -82,7 +129,20 @@ def pad_to_size(x, size):
def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array: def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:
"""Performs nucleus filtering on logits.""" """
Performs nucleus filtering on logits.
Filters logits to retain only a subset that corresponds to the top-p cumulative probability. This
nucleus filtering method is used to focus the sampling process on a subset of plausible tokens, improving
the quality of generated text.
Parameters:
- logits (jax.Array): The logits to filter.
- top_p (jax.Array): The cumulative probability threshold for filtering logits.
Returns:
- jax.Array: The filtered logits with low-probability tokens set to -inf, effectively removing them
from consideration during sampling.
"""
assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}" assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}"
sorted_logits = jax.lax.sort(logits, is_stable=False) sorted_logits = jax.lax.sort(logits, is_stable=False)
sorted_probs = jax.nn.softmax(sorted_logits) sorted_probs = jax.nn.softmax(sorted_logits)
@ -102,6 +162,21 @@ def sample_token(
lm_outputs: LanguageModelOutput, lm_outputs: LanguageModelOutput,
settings: SampleSettings, settings: SampleSettings,
) -> SampleOutput: ) -> SampleOutput:
"""
Samples a token from the language model outputs using specified settings, including temperature and
nucleus probability filtering. This function also computes the probability of the sampled token and
identifies the top-k tokens and their probabilities.
Parameters:
- rngs (jax.random.PRNGKey): The random number generator state for sampling.
- lm_outputs (LanguageModelOutput): The outputs from the language model, including logits.
- settings (SampleSettings): The settings controlling the sampling process, including temperature,
nucleus probability, token masking, and active batch elements indicator.
Returns:
- SampleOutput: The results of the sampling process, including the sampled token ID, its probability,
and the top-k tokens and their probabilities for each batch element.
"""
# Expand the settings shape to match the logit shape. # Expand the settings shape to match the logit shape.
settings = SampleSettings( settings = SampleSettings(
temperature=jnp.expand_dims(settings.temperature, (1, 2)), # Input [B], output [B, 1, 1]. temperature=jnp.expand_dims(settings.temperature, (1, 2)), # Input [B], output [B, 1, 1].
@ -135,6 +210,25 @@ def sample_token(
@dataclass @dataclass
class ModelRunner: class ModelRunner:
"""
Manages the execution of a language model, including initialization, forward passes, and state management.
Attributes:
- model (LanguageModelConfig): The configuration for the language model.
- bs_per_device (float): The batch size per device.
- load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming parameters during model loading.
- load_exclude_rules (Optional[list[str]]): Rules for excluding parameters during model loading.
- rng_seed (int): The initial random number generator seed.
- transform_forward (bool): Flag to determine whether to transform the forward function with Haiku.
- checkpoint_path (str): The path to the directory containing model checkpoints for loading.
Methods:
- make_forward_fn: Creates the forward function for the model, potentially transforming it with Haiku.
- initialize: Initializes the model runner with data and mesh configuration, preparing it for execution.
- init: Initializes the model parameters using provided data.
- get_state_sharding: Determines the sharding configuration for model parameters based on the initialization data.
- load_or_init: Loads the model from a checkpoint or initializes it if the checkpoint is not available or not specified.
"""
model: LanguageModelConfig model: LanguageModelConfig
bs_per_device: float = 2.0 bs_per_device: float = 2.0
@ -148,6 +242,18 @@ class ModelRunner:
checkpoint_path: str = "" checkpoint_path: str = ""
def make_forward_fn(self, mesh: Any): def make_forward_fn(self, mesh: Any):
"""
Creates and optionally transforms the forward function of the model. This method constructs a forward
function that takes input tokens and produces model outputs. If `transform_forward` is set to True, the
method also transforms this function using Haiku's transform method to prepare it for JIT compilation
and execution on the mesh.
Parameters:
mesh (Any): The mesh configuration to be used for distributing the computation across devices.
Returns:
Callable: The forward function, potentially transformed by Haiku, ready for execution.
"""
def forward(tokens): def forward(tokens):
out = self.model.make(mesh=mesh)(tokens) out = self.model.make(mesh=mesh)(tokens)
return out, None return out, None
@ -162,6 +268,20 @@ class ModelRunner:
local_mesh_config: tuple[int, int], local_mesh_config: tuple[int, int],
between_hosts_config: tuple[int, int], between_hosts_config: tuple[int, int],
): ):
"""
Initializes the model runner with necessary configurations and data. This includes setting up the mesh
for distributed computation, calculating batch sizes based on the provided configuration, and preparing
the forward function for execution.
Parameters:
init_data: Initial data used for model initialization. This is typically dummy data matching the
expected input format and dimensions of the model.
local_mesh_config (tuple[int, int]): The configuration of the local mesh, specifying how devices
are organized locally for parallel computation.
between_hosts_config (tuple[int, int]): The configuration for communication between hosts in a
distributed setup, important for multi-host environments.
"""
num_replicas = math.prod(between_hosts_config) num_replicas = math.prod(between_hosts_config)
self.model.initialize() self.model.initialize()
self.model.fprop_dtype = jnp.bfloat16 self.model.fprop_dtype = jnp.bfloat16
@ -191,12 +311,37 @@ class ModelRunner:
self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding) self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding)
def init(self, rng: jax.Array, data) -> TrainingState: def init(self, rng: jax.Array, data) -> TrainingState:
"""
Initializes the model parameters using provided data and a random number generator state. This method
is only called when `transform_forward` is True, indicating that the forward function has been transformed
by Haiku and requires explicit initialization.
Parameters:
rng (jax.Array): The random number generator state for initializing model parameters.
data: The initial data used for parameter initialization, usually including input tokens and
possibly target tokens for supervised learning tasks.
Returns:
TrainingState: An instance of TrainingState containing the initialized model parameters.
"""
assert self.transform_forward assert self.transform_forward
rng, init_rng = jax.random.split(rng) rng, init_rng = jax.random.split(rng)
params = self.forward.init(init_rng, data["inputs"]) params = self.forward.init(init_rng, data["inputs"])
return TrainingState(params=params) return TrainingState(params=params)
def get_state_sharding(self, init_data): def get_state_sharding(self, init_data):
"""
Determines the sharding configuration for model parameters based on initialization data. This method
evaluates the shape of model parameters during initialization to apply the partition rules specified
in the model's configuration, facilitating efficient distributed computation.
Parameters:
init_data: The initial data used for model initialization, which helps determine the appropriate
sharding configuration based on model parameter shapes.
Returns:
A sharding configuration that specifies how model parameters should be partitioned across the mesh.
"""
assert self.transform_forward assert self.transform_forward
rng = jax.random.PRNGKey(self.rng_seed) rng = jax.random.PRNGKey(self.rng_seed)
rank_logger.info(f"partition rules: {self.model.partition_rules}") rank_logger.info(f"partition rules: {self.model.partition_rules}")
@ -215,6 +360,20 @@ class ModelRunner:
from_checkpoint: bool = True, from_checkpoint: bool = True,
init_fn: Optional[Callable] = None, init_fn: Optional[Callable] = None,
): ):
"""
Loads model parameters from a checkpoint or initializes them if the checkpoint is not available. This
method provides flexibility in starting the model from a known state or freshly initializing parameters
based on the model configuration and provided initial data.
Parameters:
init_data: Initial data used for model parameter initialization if needed.
from_checkpoint (bool): Flag indicating whether to attempt loading parameters from a checkpoint.
init_fn (Optional[Callable]): An optional initialization function that overrides the default
initialization method if provided.
Returns:
The model state, either loaded from a checkpoint or newly initialized.
"""
rng = jax.random.PRNGKey(self.rng_seed) rng = jax.random.PRNGKey(self.rng_seed)
if not self.checkpoint_path or not from_checkpoint: if not self.checkpoint_path or not from_checkpoint:
@ -251,6 +410,16 @@ class ModelRunner:
@dataclass @dataclass
class Request: class Request:
"""
Data class for storing information about a single sampling request to the model.
Attributes:
prompt (str): The text prompt to generate text from.
temperature (float): Controls the randomness of the output. Lower values result in more deterministic outputs.
nucleus_p (float): The cumulative probability threshold for nucleus sampling, focusing generation on high-probability tokens.
rng_seed (int): Seed for the random number generator to ensure reproducibility of the results.
max_len (int): The maximum length of the generated text sequence.
"""
prompt: str prompt: str
temperature: float temperature: float
nucleus_p: float nucleus_p: float
@ -260,6 +429,24 @@ class Request:
@dataclass @dataclass
class InferenceRunner: class InferenceRunner:
"""
Manages inference operations for generating text based on prompts, including initializing the model and tokenizer,
prefilling the model memory, and running the actual text generation.
Attributes:
name (str): A name identifier for the inference runner.
runner (Any): An instance of ModelRunner that manages the underlying language model.
load (str): A path to the model checkpoint for loading.
tokenizer_path (str): Path to the SentencePiece tokenizer model file.
local_mesh_config (Tuple[int, int]): Configuration for the local device mesh.
between_hosts_config (Tuple[int, int]): Configuration for communication between hosts in a distributed setup.
pad_sizes (tuple[int]): A tuple of padding sizes for processing different lengths of prompts.
Methods:
get_pad_bucket: Determines the appropriate padding size for a given input length.
initialize: Initializes the model runner, tokenizer, and pre-compiles model functions for different input sizes.
run: A generator method that accepts prompts and returns generated text, managing batch slots and sampling steps.
"""
name: str name: str
runner: Any runner: Any
load: str load: str
@ -269,10 +456,25 @@ class InferenceRunner:
pad_sizes: tuple[int] = (1024,) pad_sizes: tuple[int] = (1024,)
def get_pad_bucket(self, size): def get_pad_bucket(self, size):
"""
Determines the appropriate padding size for a given input length. This method uses the predefined
pad sizes to find the smallest pad size that is larger than or equal to the given size.
Parameters:
size (int): The length of the input sequence to be padded.
Returns:
int: The selected pad size from the predefined pad sizes that best fits the input length.
"""
i = bisect.bisect_left(self.pad_sizes, size) i = bisect.bisect_left(self.pad_sizes, size)
return self.pad_sizes[min(i, len(self.pad_sizes) - 1)] return self.pad_sizes[min(i, len(self.pad_sizes) - 1)]
def initialize(self): def initialize(self):
"""
Initializes the inference runner by setting up the model runner, loading the tokenizer, pre-compiling
model functions for different input sizes, and loading or initializing model parameters. This method
ensures that the inference runner is ready to generate text based on prompts.
"""
runner = self.runner runner = self.runner
self.runner.transform_forward = True self.runner.transform_forward = True
dummy_data = dict( dummy_data = dict(
@ -440,7 +642,15 @@ class InferenceRunner:
) )
def run(self): def run(self):
"""Generator that accepts prompts.""" """
Generator method that accepts prompts and manages the text generation process. This method handles
tokenization of prompts, setup of sampling settings, and orchestration of the sampling steps to
generate and return the generated text.
Yields:
str: The generated text for each prompt received. This method operates as a coroutine, yielding
generated text back to the caller and awaiting new prompts for further generation.
"""
runner = self.runner runner = self.runner
mesh = runner.mesh mesh = runner.mesh
max_len = runner.model.sequence_len max_len = runner.model.sequence_len
@ -580,6 +790,18 @@ class InferenceRunner:
def make_mesh( def make_mesh(
local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...] local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...]
) -> jax.sharding.Mesh: ) -> jax.sharding.Mesh:
"""
Creates a JAX mesh from the provided configuration, which is used for parallel computation across devices.
Parameters:
local_mesh_config (tuple[int, ...]): A tuple specifying the local mesh configuration. This usually corresponds
to how local devices are arranged and partitioned for computation.
between_hosts_config (tuple[int, ...]): A tuple specifying the mesh configuration for communication between
different hosts, relevant in multi-host distributed setups.
Returns:
jax.sharding.Mesh: A mesh object that encapsulates the device topology and partitioning for parallel computation.
"""
assert len(local_mesh_config) == 2 assert len(local_mesh_config) == 2
assert len(between_hosts_config) == 2 assert len(between_hosts_config) == 2
rank_logger.info("Detected %s devices in mesh", jax.device_count()) rank_logger.info("Detected %s devices in mesh", jax.device_count())
@ -594,6 +816,32 @@ def make_mesh(
def sample_from_model(server, prompt, max_len, temperature): def sample_from_model(server, prompt, max_len, temperature):
"""
Samples tokens from a trained model given a prompt and sampling settings. This function is designed
to interact with a generator-based server that manages the state of the model and performs the actual
sampling. The server is expected to yield control back to this function, which then sends a request
object containing the prompt and settings, and waits for the generated output.
The function initializes the server (if not already done), constructs the request with the provided
parameters, and sends this request to the server. The sampled output, typically a continuation of the
prompt, is then received from the server.
Parameters:
- server (generator): The generator-based server that handles model inference. This server is expected
to manage the lifecycle of the model, including loading parameters, setting up the environment, and
performing the token sampling.
- prompt (str): The initial text prompt to which the model generates a continuation. This prompt is
encoded and fed into the model as part of the request.
- max_len (int): The maximum length of the generated sequence, including the length of the prompt.
This controls how long the model's output can be.
- temperature (float): A parameter controlling the randomness of the output. Lower values make the
model's outputs more deterministic (and potentially more repetitive), while higher values increase
diversity at the cost of coherence.
Returns:
- str: The generated text as a continuation of the provided prompt, up to a maximum length specified
by `max_len`.
"""
next(server) next(server)
inp = Request( inp = Request(
prompt=prompt, prompt=prompt,