diff --git a/checkpoint.py b/checkpoint.py
index 1c6e878..c869d94 100644
--- a/checkpoint.py
+++ b/checkpoint.py
@@ -26,6 +26,10 @@
 import tempfile
 from concurrent.futures import ThreadPoolExecutor, wait
 from typing import Any, Optional
+# For the *_cached wrappers below: a simple memoization layer that avoids
+# recomputing regex matches for paths that have already been processed.
+from functools import lru_cache
+
 import jax
 import numpy as np
 from jax.experimental import multihost_utils
@@ -104,9 +108,25 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
         else:
             fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
     wait(fs)
-    return [f.result() for f in fs]
+    # Improve error reporting: log which tensor failed to load and why,
+    # instead of letting a bare exception propagate out of the thread pool.
+    results = []
+    for future in fs:
+        try:
+            results.append(future.result())
+        except Exception as e:
+            rank_logger.error(f"Failed to load tensor: {e}")
+            raise
+    return results
+
+
+# Cache path-to-string conversions; key-path tuples from jax.tree_util are
+# hashable, so they can be memoized directly (maxsize counts entries, not bytes).
+@lru_cache(maxsize=32)
+def path_tuple_to_string_cached(path: tuple) -> str:
+    return path_tuple_to_string(path=path)
 
 
 def path_tuple_to_string(path: tuple) -> str:
     pieces = []
     for elem in path:
@@ -119,6 +139,22 @@ def path_tuple_to_string(path: tuple) -> str:
     return "/".join(pieces)
 
 
+# Cache rename/exclude lookups for paths that have already been processed.
+# lru_cache hashes its arguments, so the rule sequences are declared as
+# tuples here; callers must convert their lists (see replace_with_load_state).
+@lru_cache(maxsize=32)
+def get_load_path_str_cached(
+    init_path_str: str,
+    load_rename_rules: Optional[tuple[tuple[str, str], ...]] = None,
+    load_exclude_rules: Optional[tuple[str, ...]] = None,
+) -> Optional[str]:
+    return get_load_path_str(
+        init_path_str,
+        load_rename_rules,
+        load_exclude_rules,
+    )
+
+
 def get_load_path_str(
     init_path_str: str,
     load_rename_rules: Optional[list[tuple[str, str]]] = None,
@@ -150,14 +186,19 @@ def replace_with_load_state(
 ) -> Any:
     flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
     flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
-    load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
+    load_map = {path_tuple_to_string_cached(path): tensor for path, tensor in flatten_load}
 
     replaced = []
     num_replicas = 1
     data_model_shards = math.prod(mesh_config)
     for i, (init_path, tensor) in enumerate(flatten_init):
-        init_path_str = path_tuple_to_string(init_path)
-        load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules)
+        init_path_str = path_tuple_to_string_cached(init_path)
+        # lru_cache needs hashable arguments, so pass the rule lists as tuples.
+        load_path_str = get_load_path_str_cached(
+            init_path_str,
+            tuple(load_rename_rules) if load_rename_rules is not None else None,
+            tuple(load_exclude_rules) if load_exclude_rules is not None else None,
+        )
         if load_path_str is None:
             rank_logger.info(f"Excluded from restore: {init_path_str}.")
             replaced.append(tensor)
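
A note on the tuple conversion above: functools.lru_cache builds its cache key by hashing the call arguments, so a wrapped function can never be called with a list. The standalone sketch below (not part of checkpoint.py; match_rule_cached and its patterns are hypothetical illustrations) reproduces the failure and confirms the caching via cache_info().

import re
from functools import lru_cache
from typing import Optional


@lru_cache(maxsize=32)
def match_rule_cached(path: str, exclude_rules: Optional[tuple[str, ...]] = None) -> bool:
    # Hypothetical helper: True if `path` matches any exclude pattern.
    return any(re.search(pattern, path) for pattern in exclude_rules or ())


rules = ["step", "optimizer_state"]  # rule sets typically arrive as lists

try:
    match_rule_cached("model/decoder_0/w", rules)  # lists are unhashable
except TypeError as e:
    print(f"list argument rejected: {e}")

match_rule_cached("model/decoder_0/w", tuple(rules))  # miss: computed
match_rule_cached("model/decoder_0/w", tuple(rules))  # hit: served from cache
print(match_rule_cached.cache_info())  # CacheInfo(hits=1, misses=1, ...)

Converting at the call site keeps get_load_path_str's public list-based signature unchanged while still letting the wrapper memoize.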