From 3fd4e7c4d7ab379329cf019ffba1c292f8428f7a Mon Sep 17 00:00:00 2001
From: Madhav
Date: Tue, 19 Mar 2024 21:58:37 +0530
Subject: [PATCH 1/2] Enhanced Error Handling in load_tensors()

Collect future results one at a time so a failure can be attributed to a
specific tensor load (with a logged traceback) instead of surfacing as an
opaque traceback out of the list comprehension.
---
 checkpoint.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/checkpoint.py b/checkpoint.py
index 1c6e878..4539aa4 100644
--- a/checkpoint.py
+++ b/checkpoint.py
@@ -104,7 +104,15 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
         else:
             fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
     wait(fs)
-    return [f.result() for f in fs]
+    # Gather results future-by-future so a failure is logged (with traceback)
+    # before being re-raised, instead of propagating anonymously.
+    results = []
+    for future in fs:
+        try:
+            results.append(future.result())
+        except Exception:
+            # logger.exception records the stack trace of the active exception.
+            logger.exception("Failed to load tensor")
+            raise
+    return results
 
 
 def path_tuple_to_string(path: tuple) -> str:

From 5fc82399bfe557d3329557b2c52ca4af2c0aec1c Mon Sep 17 00:00:00 2001
From: Madhav
Date: Tue, 19 Mar 2024 22:06:40 +0530
Subject: [PATCH 2/2] get_load_path_str() -> get_load_path_str_cached():
 Optimizing Regex Operations with Caching

Memoize get_load_path_str: restore evaluates the same (path, rules) regex
work repeatedly.  lru_cache requires hashable arguments, so the cached
wrapper takes tuples and the call site freezes the rule lists.
---
 checkpoint.py | 23 ++++++++++++++++++++---
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/checkpoint.py b/checkpoint.py
index 4539aa4..1d81e2e 100644
--- a/checkpoint.py
+++ b/checkpoint.py
@@ -26,6 +26,9 @@ import tempfile
 from concurrent.futures import ThreadPoolExecutor, wait
 from typing import Any, Optional
 
+# lru_cache memoizes get_load_path_str_cached (see below).
+from functools import lru_cache
+
 import jax
 import numpy as np
 from jax.experimental import multihost_utils
@@ -127,6 +130,18 @@ def path_tuple_to_string(path: tuple) -> str:
     return "/".join(pieces)
 
 
+# Memoized front-end for get_load_path_str: restore walks many parameters that
+# share the same rule collections, so identical (path, rules) lookups recur.
+# The rule arguments are tuples because lru_cache requires hashable arguments;
+# get_load_path_str only iterates them, so tuples behave like the lists it
+# was written against.
+@lru_cache(maxsize=None)
+def get_load_path_str_cached(
+    init_path_str: str,
+    load_rename_rules: Optional[tuple[tuple[str, str], ...]] = None,
+    load_exclude_rules: Optional[tuple[str, ...]] = None,
+) -> Optional[str]:
+    return get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules)
+
+
 def get_load_path_str(
     init_path_str: str,
     load_rename_rules: Optional[list[tuple[str, str]]] = None,
@@ -165,7 +180,12 @@ def replace_with_load_state(
     data_model_shards = math.prod(mesh_config)
     for i, (init_path, tensor) in enumerate(flatten_init):
         init_path_str = path_tuple_to_string(init_path)
+        # lru_cache needs hashable arguments, so freeze the rule lists first.
+        load_path_str = get_load_path_str_cached(
+            init_path_str,
+            tuple(load_rename_rules) if load_rename_rules is not None else None,
+            tuple(load_exclude_rules) if load_exclude_rules is not None else None,
+        )
         if load_path_str is None:
             rank_logger.info(f"Excluded from restore: {init_path_str}.")
             replaced.append(tensor)