This commit is contained in:
Madhav Kumar 2024-03-19 22:18:01 +05:30 committed by GitHub
commit a105de1c16
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -26,6 +26,10 @@ import tempfile
from concurrent.futures import ThreadPoolExecutor, wait from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Optional from typing import Any, Optional
# For get_load_path_str
# A simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
from functools import lru_cache
import jax import jax
import numpy as np import numpy as np
from jax.experimental import multihost_utils from jax.experimental import multihost_utils
@ -104,7 +108,21 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
else: else:
fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype)) fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
wait(fs) wait(fs)
return [f.result() for f in fs]
# return [f.result() for f in fs]
"""
Improve error reporting in load_tensors by catching exceptions within the futures-
and logging detailed information about the failure.
"""
results = []
for future in fs:
try:
result = future.result()
results.append(result)
except Exception as e:
logger.error(f"Failed to load tensor: {e}")
raise
return results
def path_tuple_to_string(path: tuple) -> str: def path_tuple_to_string(path: tuple) -> str:
@ -119,6 +137,22 @@ def path_tuple_to_string(path: tuple) -> str:
return "/".join(pieces) return "/".join(pieces)
"""
For get_load_path_str(),
introducing a simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
"""
@lru_cache(maxsize=None)
def get_load_path_str_cached(
init_path_str: str,
load_rename_rules: Optional[list[tuple[str, str]]] = None,
load_exclude_rules: Optional[list[str]] = None,
) -> Optional[str]:
return get_load_path_str(
init_path_str,
load_rename_rules,
load_exclude_rules
)
def get_load_path_str( def get_load_path_str(
init_path_str: str, init_path_str: str,
load_rename_rules: Optional[list[tuple[str, str]]] = None, load_rename_rules: Optional[list[tuple[str, str]]] = None,
@ -157,7 +191,7 @@ def replace_with_load_state(
data_model_shards = math.prod(mesh_config) data_model_shards = math.prod(mesh_config)
for i, (init_path, tensor) in enumerate(flatten_init): for i, (init_path, tensor) in enumerate(flatten_init):
init_path_str = path_tuple_to_string(init_path) init_path_str = path_tuple_to_string(init_path)
load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules) load_path_str = get_load_path_str_cached(init_path_str, load_rename_rules, load_exclude_rules)
if load_path_str is None: if load_path_str is None:
rank_logger.info(f"Excluded from restore: {init_path_str}.") rank_logger.info(f"Excluded from restore: {init_path_str}.")
replaced.append(tensor) replaced.append(tensor)