From 3fd4e7c4d7ab379329cf019ffba1c292f8428f7a Mon Sep 17 00:00:00 2001 From: Madhav Date: Tue, 19 Mar 2024 21:58:37 +0530 Subject: [PATCH 1/4] Enhanced Error Handling in load_tensors() --- checkpoint.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/checkpoint.py b/checkpoint.py index 1c6e878..4539aa4 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -104,7 +104,21 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None): else: fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype)) wait(fs) - return [f.result() for f in fs] + + # return [f.result() for f in fs] + """ + Improve error reporting in load_tensors by catching exceptions within the futures- + and logging detailed information about the failure. + """ + results = [] + for future in fs: + try: + result = future.result() + results.append(result) + except Exception as e: + logger.error(f"Failed to load tensor: {e}") + raise + return results def path_tuple_to_string(path: tuple) -> str: From 5fc82399bfe557d3329557b2c52ca4af2c0aec1c Mon Sep 17 00:00:00 2001 From: Madhav Date: Tue, 19 Mar 2024 22:06:40 +0530 Subject: [PATCH 2/4] get_load_path_str() -> get_load_path_str_cached(): Optimizing Regex Operations with Caching --- checkpoint.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/checkpoint.py b/checkpoint.py index 4539aa4..1d81e2e 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -26,6 +26,10 @@ import tempfile from concurrent.futures import ThreadPoolExecutor, wait from typing import Any, Optional +# For get_load_path_str +# A simple caching mechanism to avoid recomputing regex matches for paths that have already been processed. +from functools import lru_cache + import jax import numpy as np from jax.experimental import multihost_utils @@ -133,6 +137,22 @@ def path_tuple_to_string(path: tuple) -> str: return "/".join(pieces) +""" +For get_load_path_str(), +introducing a simple caching mechanism to avoid recomputing regex matches for paths that have already been processed. +""" +@lru_cache(maxsize=None) +def get_load_path_str_cached( + init_path_str: str, + load_rename_rules: Optional[list[tuple[str, str]]] = None, + load_exclude_rules: Optional[list[str]] = None, +) -> Optional[str]: + return get_load_path_str( + init_path_str, + load_rename_rules, + load_exclude_rules + ) + def get_load_path_str( init_path_str: str, load_rename_rules: Optional[list[tuple[str, str]]] = None, @@ -171,7 +191,7 @@ def replace_with_load_state( data_model_shards = math.prod(mesh_config) for i, (init_path, tensor) in enumerate(flatten_init): init_path_str = path_tuple_to_string(init_path) - load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules) + load_path_str = get_load_path_str_cached(init_path_str, load_rename_rules, load_exclude_rules) if load_path_str is None: rank_logger.info(f"Excluded from restore: {init_path_str}.") replaced.append(tensor) From 104affb51a8c0db24bc7b1a5ed3489a702164277 Mon Sep 17 00:00:00 2001 From: Madhav Kumar <78339236+Madhav-MKNC@users.noreply.github.com> Date: Thu, 21 Mar 2024 06:48:22 +0530 Subject: [PATCH 3/4] Set maxsize to 32MB in lru_cache() for reasonable memory usage --- checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/checkpoint.py b/checkpoint.py index 1d81e2e..2b08b42 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -141,7 +141,7 @@ def path_tuple_to_string(path: tuple) -> str: For get_load_path_str(), introducing a simple caching mechanism to avoid recomputing regex matches for paths that have already been processed. """ -@lru_cache(maxsize=None) +@lru_cache(maxsize=32) # Set maxsize to 32MB def get_load_path_str_cached( init_path_str: str, load_rename_rules: Optional[list[tuple[str, str]]] = None, From 10fd2278a864486233c166298e523b6c5ca5f0d8 Mon Sep 17 00:00:00 2001 From: Madhav Date: Sat, 23 Mar 2024 15:37:18 +0530 Subject: [PATCH 4/4] path_tuple_to_string() -> path_tuple_to_string_cached(): Optimizing Regex Operations with Caching --- checkpoint.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/checkpoint.py b/checkpoint.py index 2b08b42..c869d94 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -125,6 +125,14 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None): return results +""" +For path_tuple_to_string(), +introducing a simple caching mechanism to avoid recomputing regex matches for paths that have already been processed. +""" +@lru_cache(maxsize=32) # Set maxsize to 32MB +def path_tuple_to_string_cached(path: tuple) -> str: + return path_tuple_to_string(path=path) + def path_tuple_to_string(path: tuple) -> str: pieces = [] for elem in path: @@ -184,13 +192,13 @@ def replace_with_load_state( ) -> Any: flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state) flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state) - load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load} + load_map = {path_tuple_to_string_cached(path): tensor for path, tensor in flatten_load} replaced = [] num_replicas = 1 data_model_shards = math.prod(mesh_config) for i, (init_path, tensor) in enumerate(flatten_init): - init_path_str = path_tuple_to_string(init_path) + init_path_str = path_tuple_to_string_cached(init_path) load_path_str = get_load_path_str_cached(init_path_str, load_rename_rules, load_exclude_rules) if load_path_str is None: rank_logger.info(f"Excluded from restore: {init_path_str}.")