From 3f6fb5f4aa76d92cd16f7875c9669e324a6b2d6f Mon Sep 17 00:00:00 2001
From: "Carlos D. Escobar-Valbuena"
Date: Mon, 18 Mar 2024 16:05:23 -0500
Subject: [PATCH] Updated docstrings for checkpoint.py

---
 checkpoint.py | 109 +++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 108 insertions(+), 1 deletion(-)

diff --git a/checkpoint.py b/checkpoint.py
index 1c6e878..4288f31 100644
--- a/checkpoint.py
+++ b/checkpoint.py
@@ -41,6 +41,17 @@ sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit
 
 @contextlib.contextmanager
 def copy_to_shm(file: str):
+    """
+    Context manager for copying a file to shared memory. If the file is already in shared memory (/dev/shm),
+    yields the same file path. Otherwise, copies the file to a temporary file in shared memory, yields the path
+    to the temporary file, and cleans up by removing the temporary file after use.
+
+    Parameters:
+    - file (str): The path to the file to be copied.
+
+    Yields:
+    - str: The path to the file in shared memory.
+    """
     if file.startswith("/dev/shm/"):
         # Nothing to do, the file is already in shared memory.
         yield file
@@ -58,6 +69,17 @@
 
 @contextlib.contextmanager
 def copy_from_shm(file: str):
+    """
+    Context manager for copying a file from shared memory to a specified path. It creates a temporary file
+    in shared memory, yields the path to this temporary file for operations, and then copies the temporary
+    file to the specified path, cleaning up the temporary file afterwards.
+
+    Parameters:
+    - file (str): The target path for the file to be copied to from shared memory.
+
+    Yields:
+    - str: The path to the temporary file in shared memory.
+    """
     tmp_dir = "/dev/shm/"
     fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
     try:
@@ -69,19 +91,48 @@
 
 
 def fast_unpickle(path: str) -> Any:
+    """
+    Unpickles and loads an object from a file, staging the read through shared memory (/dev/shm) for faster access.
+
+    Parameters:
+    - path (str): The path to the pickle file to load.
+
+    Returns:
+    - Any: The object loaded from the pickle file.
+    """
     with copy_to_shm(path) as tmp_path:
         with open(tmp_path, "rb") as f:
             return pickle.load(f)
 
 
 def fast_pickle(obj: Any, path: str) -> None:
+    """
+    Pickles and saves an object to a file, staging the write through shared memory (/dev/shm) for faster access.
+
+    Parameters:
+    - obj (Any): The object to be pickled and saved.
+    - path (str): The path where the pickle file should be saved.
+    """
     with copy_from_shm(path) as tmp_path:
         with open(tmp_path, "wb") as f:
             pickle.dump(obj, f)
 
 
 def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
-    """Loads a set of arrays."""
+    """
+    Loads tensors from files in parallel using a ThreadPoolExecutor. This function is intended for use with
+    arrays that have a predefined shape and dtype, loading them from a directory where each tensor is saved
+    in a separate file.
+
+    Parameters:
+    - shaped_arrays: A sequence of arrays providing the shapes and dtypes of the tensors to load.
+    - directory (str): The directory from which to load the tensors.
+    - mesh_config: Configuration for data parallelism across processes or devices.
+    - tensor_indices (Optional): Specific indices of tensors to load. If None, all tensors are loaded.
+
+    Returns:
+    - A list of numpy arrays, either loaded tensors or zero-filled placeholders, depending on the process's role in data parallelism.
+    """
     pool = ThreadPoolExecutor(max_workers=32)
     fs = list()
     num_tensors = 0
@@ -108,6 +159,17 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
 
 
 def path_tuple_to_string(path: tuple) -> str:
+    """
+    Converts a tuple representing a path in a nested structure to a string. This function is specifically
+    used for handling paths within the structure of saved states or parameters, making it easier to
+    identify and manipulate specific elements.
+
+    Parameters:
+    - path (tuple): A tuple representing a path in a nested structure.
+
+    Returns:
+    - str: A string representation of the path, suitable for logging or identification.
+    """
     pieces = []
     for elem in path:
         if isinstance(elem, jax.tree_util.DictKey):
@@ -124,6 +186,19 @@ def get_load_path_str(
     load_rename_rules: Optional[list[tuple[str, str]]] = None,
     load_exclude_rules: Optional[list[str]] = None,
 ) -> Optional[str]:
+    """
+    Determines the load path for a given initial path string, applying exclusion and renaming rules. This
+    function is used in the context of loading states or parameters from files, allowing for flexible
+    mapping or exclusion based on pattern matching.
+
+    Parameters:
+    - init_path_str (str): The initial path string to process.
+    - load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming paths based on pattern matching.
+    - load_exclude_rules (Optional[list[str]]): Patterns for paths to exclude from loading.
+
+    Returns:
+    - Optional[str]: The processed load path string, or None if excluded.
+    """
     # Exclusion
     if load_exclude_rules is not None:
         for search_pattern in load_exclude_rules:
@@ -148,6 +223,21 @@ def replace_with_load_state(
     load_exclude_rules: Optional[list[str]] = None,
     mesh_config: tuple = (1, 1),
 ) -> Any:
+    """
+    Replaces elements of an initial state with elements from a loaded state, applying renaming and exclusion
+    rules. This function supports conditional inclusion and transformation of state elements based on complex
+    criteria, facilitating flexible state restoration.
+
+    Parameters:
+    - init_state (Any): The initial state before replacement.
+    - load_state (Any): The state from which to load replacements.
+    - load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming paths.
+    - load_exclude_rules (Optional[list[str]]): Patterns for paths to exclude from loading.
+    - mesh_config (tuple): Configuration for data parallelism.
+
+    Returns:
+    - Any: The initial state with elements replaced from the load state.
+    """
     flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
     flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
     load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
@@ -186,6 +276,23 @@ def restore(
     state_sharding,
     init_state: Optional[Any] = None,
 ) -> Any:
+    """
+    Restores the state from a checkpoint, optionally focusing on parameters only, and applies sharding
+    configurations. This function is designed for restoring model states from disk with support for distributed
+    environments, handling the intricacies of partitioning and host-specific configurations.
+
+    Parameters:
+    - checkpoint_path (str): The path to the checkpoint directory.
+    - state_shapes (Any): The expected shapes of the state to restore.
+    - mesh: The mesh configuration for distributed environments.
+    - between_hosts_config: Configuration for data exchange between hosts.
+    - params_only (bool): Whether to restore parameters only, excluding other state parts.
+    - state_sharding: Sharding configuration for the state.
+    - init_state (Optional[Any]): The initial state to which the checkpoint data is applied.
+
+    Returns:
+    - Any: The restored state, potentially sharded across the distributed environment.
+    """
     ckpt_path = os.path.join(checkpoint_path, "ckpt-0")
     rank_logger.info("Loading checkpoint at {}".format(ckpt_path))
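
For reference, a minimal usage sketch of the two shared-memory pickle helpers documented above. This is not part of the patch; the object and the target path are made up for illustration, and it assumes checkpoint.py is importable and /dev/shm exists (as on Linux):

    from checkpoint import fast_pickle, fast_unpickle

    # fast_pickle stages the pickle in /dev/shm via copy_from_shm,
    # then copies it out to the target path.
    state = {"step": 0, "params": [1.0, 2.0, 3.0]}
    fast_pickle(state, "/tmp/example_ckpt.pkl")

    # fast_unpickle copies the file into /dev/shm via copy_to_shm
    # before unpickling it.
    restored = fast_unpickle("/tmp/example_ckpt.pkl")
    assert restored == state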