diff --git a/checkpoint.py b/checkpoint.py index 1c6e878..2b08b42 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -26,6 +26,10 @@ import tempfile from concurrent.futures import ThreadPoolExecutor, wait from typing import Any, Optional +# For get_load_path_str +# A simple caching mechanism to avoid recomputing regex matches for paths that have already been processed. +from functools import lru_cache + import jax import numpy as np from jax.experimental import multihost_utils @@ -104,7 +108,21 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None): else: fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype)) wait(fs) - return [f.result() for f in fs] + + # return [f.result() for f in fs] + """ + Improve error reporting in load_tensors by catching exceptions within the futures- + and logging detailed information about the failure. + """ + results = [] + for future in fs: + try: + result = future.result() + results.append(result) + except Exception as e: + logger.error(f"Failed to load tensor: {e}") + raise + return results def path_tuple_to_string(path: tuple) -> str: @@ -119,6 +137,22 @@ def path_tuple_to_string(path: tuple) -> str: return "/".join(pieces) +""" +For get_load_path_str(), +introducing a simple caching mechanism to avoid recomputing regex matches for paths that have already been processed. +""" +@lru_cache(maxsize=32) # Set maxsize to 32MB +def get_load_path_str_cached( + init_path_str: str, + load_rename_rules: Optional[list[tuple[str, str]]] = None, + load_exclude_rules: Optional[list[str]] = None, +) -> Optional[str]: + return get_load_path_str( + init_path_str, + load_rename_rules, + load_exclude_rules + ) + def get_load_path_str( init_path_str: str, load_rename_rules: Optional[list[tuple[str, str]]] = None, @@ -157,7 +191,7 @@ def replace_with_load_state( data_model_shards = math.prod(mesh_config) for i, (init_path, tensor) in enumerate(flatten_init): init_path_str = path_tuple_to_string(init_path) - load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules) + load_path_str = get_load_path_str_cached(init_path_str, load_rename_rules, load_exclude_rules) if load_path_str is None: rank_logger.info(f"Excluded from restore: {init_path_str}.") replaced.append(tensor)