mirror of
https://github.com/xai-org/grok-1.git
synced 2025-02-21 13:59:59 +03:00
Updated docstrings for checkpoint.py
This commit is contained in:
parent
429a83e5d9
commit
3f6fb5f4aa
109
checkpoint.py
109
checkpoint.py
@ -41,6 +41,17 @@ sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit
|
||||
|
||||
@contextlib.contextmanager
|
||||
def copy_to_shm(file: str):
|
||||
"""
|
||||
Context manager for copying a file to shared memory. If the file is already in shared memory (/dev/shm),
|
||||
yields the same file path. Otherwise, copies the file to a temporary file in shared memory, yields the path
|
||||
to the temporary file, and cleans up by removing the temporary file after use.
|
||||
|
||||
Parameters:
|
||||
- file (str): The path to the file to be copied.
|
||||
|
||||
Yields:
|
||||
- str: The path to the file in shared memory.
|
||||
"""
|
||||
if file.startswith("/dev/shm/"):
|
||||
# Nothing to do, the file is already in shared memory.
|
||||
yield file
|
||||
@ -58,6 +69,17 @@ def copy_to_shm(file: str):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def copy_from_shm(file: str):
|
||||
"""
|
||||
Context manager for copying a file from shared memory to a specified path. It creates a temporary file
|
||||
in shared memory, yields the path to this temporary file for operations, and then copies the temporary
|
||||
file to the specified path, cleaning up the temporary file afterwards.
|
||||
|
||||
Parameters:
|
||||
- file (str): The target path for the file to be copied to from shared memory.
|
||||
|
||||
Yields:
|
||||
- str: The path to the temporary file in shared memory.
|
||||
"""
|
||||
tmp_dir = "/dev/shm/"
|
||||
fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
|
||||
try:
|
||||
@ -69,19 +91,48 @@ def copy_from_shm(file: str):
|
||||
|
||||
|
||||
def fast_unpickle(path: str) -> Any:
|
||||
"""
|
||||
Unpickles and loads an object from a file, optionally using shared memory for faster access.
|
||||
|
||||
Parameters:
|
||||
- path (str): The path to the pickle file to load.
|
||||
|
||||
Returns:
|
||||
- Any: The object loaded from the pickle file.
|
||||
"""
|
||||
with copy_to_shm(path) as tmp_path:
|
||||
with open(tmp_path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
def fast_pickle(obj: Any, path: str) -> None:
|
||||
"""
|
||||
Pickles and saves an object to a file, optionally using shared memory for faster access.
|
||||
|
||||
Parameters:
|
||||
- obj (Any): The object to be pickled and saved.
|
||||
- path (str): The path where the pickle file should be saved.
|
||||
"""
|
||||
with copy_from_shm(path) as tmp_path:
|
||||
with open(tmp_path, "wb") as f:
|
||||
pickle.dump(obj, f)
|
||||
|
||||
|
||||
def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
|
||||
"""Loads a set of arrays."""
|
||||
"""
|
||||
Loads tensors from files in parallel using a ThreadPoolExecutor. This function is intended for use with
|
||||
arrays that have a predefined shape and dtype, loading them from a directory where each tensor is saved
|
||||
in a separate file.
|
||||
|
||||
Parameters:
|
||||
- shaped_arrays: A sequence of arrays providing the shapes and dtypes of the tensors to load.
|
||||
- directory (str): The directory from which to load the tensors.
|
||||
- mesh_config: Configuration for data parallelism across processes or devices.
|
||||
- tensor_indices (Optional): Specific indices of tensors to load. If None, all tensors are loaded.
|
||||
|
||||
Returns:
|
||||
- List of numpy arrays or zeros depending on the process's role in data parallelism.
|
||||
"""
|
||||
pool = ThreadPoolExecutor(max_workers=32)
|
||||
fs = list()
|
||||
num_tensors = 0
|
||||
@ -108,6 +159,17 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
|
||||
|
||||
|
||||
def path_tuple_to_string(path: tuple) -> str:
|
||||
"""
|
||||
Converts a tuple representing a path in a nested structure to a string. This function is specifically
|
||||
used for handling paths within the structure of saved states or parameters, making it easier to
|
||||
identify and manipulate specific elements.
|
||||
|
||||
Parameters:
|
||||
- path (tuple): A tuple representing a path in a nested structure.
|
||||
|
||||
Returns:
|
||||
- str: A string representation of the path, suitable for logging or identification.
|
||||
"""
|
||||
pieces = []
|
||||
for elem in path:
|
||||
if isinstance(elem, jax.tree_util.DictKey):
|
||||
@ -124,6 +186,19 @@ def get_load_path_str(
|
||||
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
||||
load_exclude_rules: Optional[list[str]] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Determines the load path for a given initial path string, applying exclusion and renaming rules. This
|
||||
function is used in the context of loading states or parameters from files, allowing for flexible
|
||||
mapping or exclusion based on pattern matching.
|
||||
|
||||
Parameters:
|
||||
- init_path_str (str): The initial path string to process.
|
||||
- load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming paths based on pattern matching.
|
||||
- load_exclude_rules (Optional[list[str]]): Patterns for paths to exclude from loading.
|
||||
|
||||
Returns:
|
||||
- Optional[str]: The processed load path string, or None if excluded.
|
||||
"""
|
||||
# Exclusion
|
||||
if load_exclude_rules is not None:
|
||||
for search_pattern in load_exclude_rules:
|
||||
@ -148,6 +223,21 @@ def replace_with_load_state(
|
||||
load_exclude_rules: Optional[list[str]] = None,
|
||||
mesh_config: tuple = (1, 1),
|
||||
) -> Any:
|
||||
"""
|
||||
Replaces elements of an initial state with elements from a loaded state, applying renaming and exclusion
|
||||
rules. This function supports conditional inclusion and transformation of state elements based on complex
|
||||
criteria, facilitating flexible state restoration.
|
||||
|
||||
Parameters:
|
||||
- init_state (Any): The initial state before replacement.
|
||||
- load_state (Any): The state from which to load replacements.
|
||||
- load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming paths.
|
||||
- load_exclude_rules (Optional[list[str]]): Patterns for paths to exclude from loading.
|
||||
- mesh_config (tuple): Configuration for data parallelism.
|
||||
|
||||
Returns:
|
||||
- Any: The initial state with elements replaced from the load state.
|
||||
"""
|
||||
flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
|
||||
flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
|
||||
load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
|
||||
@ -186,6 +276,23 @@ def restore(
|
||||
state_sharding,
|
||||
init_state: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Restores the state from a checkpoint, optionally focusing on parameters only, and applies sharding
|
||||
configurations. This function is designed for restoring model states from disk with support for distributed
|
||||
environments, handling the intricacies of partitioning and host-specific configurations.
|
||||
|
||||
Parameters:
|
||||
- checkpoint_path (str): The path to the checkpoint directory.
|
||||
- state_shapes (Any): The expected shapes of the state to restore.
|
||||
- mesh: The mesh configuration for distributed environments.
|
||||
- between_hosts_config: Configuration for data exchange between hosts.
|
||||
- params_only (bool): Whether to restore parameters only, excluding other state parts.
|
||||
- state_sharding: Sharding configuration for the state.
|
||||
- init_state (Optional[Any]): The initial state to which the checkpoint data is applied.
|
||||
|
||||
Returns:
|
||||
- Any: The restored state, potentially sharded across the distributed environment.
|
||||
"""
|
||||
ckpt_path = os.path.join(checkpoint_path, "ckpt-0")
|
||||
|
||||
rank_logger.info("Loading checkpoint at {}".format(ckpt_path))
|
||||
|
Loading…
Reference in New Issue
Block a user