mirror of
https://github.com/xai-org/grok-1.git
synced 2024-11-22 19:49:53 +03:00
Merge 10fd2278a8
into 7050ed204b
This commit is contained in:
commit
52c7b9aa79
@ -26,6 +26,10 @@ import tempfile
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from typing import Any, Optional
|
||||
|
||||
# For get_load_path_str
|
||||
# A simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
|
||||
from functools import lru_cache
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
from jax.experimental import multihost_utils
|
||||
@ -104,9 +108,31 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
|
||||
else:
|
||||
fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
|
||||
wait(fs)
|
||||
return [f.result() for f in fs]
|
||||
|
||||
# return [f.result() for f in fs]
|
||||
"""
|
||||
Improve error reporting in load_tensors by catching exceptions within the futures-
|
||||
and logging detailed information about the failure.
|
||||
"""
|
||||
results = []
|
||||
for future in fs:
|
||||
try:
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load tensor: {e}")
|
||||
raise
|
||||
return results
|
||||
|
||||
|
||||
"""
|
||||
For path_tuple_to_string(),
|
||||
introducing a simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
|
||||
"""
|
||||
@lru_cache(maxsize=32) # Set maxsize to 32MB
|
||||
def path_tuple_to_string_cached(path: tuple) -> str:
|
||||
return path_tuple_to_string(path=path)
|
||||
|
||||
def path_tuple_to_string(path: tuple) -> str:
|
||||
pieces = []
|
||||
for elem in path:
|
||||
@ -119,6 +145,22 @@ def path_tuple_to_string(path: tuple) -> str:
|
||||
return "/".join(pieces)
|
||||
|
||||
|
||||
"""
|
||||
For get_load_path_str(),
|
||||
introducing a simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
|
||||
"""
|
||||
@lru_cache(maxsize=32) # Set maxsize to 32MB
|
||||
def get_load_path_str_cached(
|
||||
init_path_str: str,
|
||||
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
||||
load_exclude_rules: Optional[list[str]] = None,
|
||||
) -> Optional[str]:
|
||||
return get_load_path_str(
|
||||
init_path_str,
|
||||
load_rename_rules,
|
||||
load_exclude_rules
|
||||
)
|
||||
|
||||
def get_load_path_str(
|
||||
init_path_str: str,
|
||||
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
||||
@ -150,14 +192,14 @@ def replace_with_load_state(
|
||||
) -> Any:
|
||||
flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
|
||||
flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
|
||||
load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
|
||||
load_map = {path_tuple_to_string_cached(path): tensor for path, tensor in flatten_load}
|
||||
|
||||
replaced = []
|
||||
num_replicas = 1
|
||||
data_model_shards = math.prod(mesh_config)
|
||||
for i, (init_path, tensor) in enumerate(flatten_init):
|
||||
init_path_str = path_tuple_to_string(init_path)
|
||||
load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules)
|
||||
init_path_str = path_tuple_to_string_cached(init_path)
|
||||
load_path_str = get_load_path_str_cached(init_path_str, load_rename_rules, load_exclude_rules)
|
||||
if load_path_str is None:
|
||||
rank_logger.info(f"Excluded from restore: {init_path_str}.")
|
||||
replaced.append(tensor)
|
||||
|
Loading…
Reference in New Issue
Block a user