mirror of
https://github.com/xai-org/grok-1.git
synced 2024-11-23 03:59:53 +03:00
Merge 10fd2278a8
into 7050ed204b
This commit is contained in:
commit
52c7b9aa79
@ -26,6 +26,10 @@ import tempfile
|
|||||||
from concurrent.futures import ThreadPoolExecutor, wait
|
from concurrent.futures import ThreadPoolExecutor, wait
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
# For get_load_path_str
|
||||||
|
# A simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from jax.experimental import multihost_utils
|
from jax.experimental import multihost_utils
|
||||||
@ -104,9 +108,31 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
|
|||||||
else:
|
else:
|
||||||
fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
|
fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
|
||||||
wait(fs)
|
wait(fs)
|
||||||
return [f.result() for f in fs]
|
|
||||||
|
# return [f.result() for f in fs]
|
||||||
|
"""
|
||||||
|
Improve error reporting in load_tensors by catching exceptions within the futures-
|
||||||
|
and logging detailed information about the failure.
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for future in fs:
|
||||||
|
try:
|
||||||
|
result = future.result()
|
||||||
|
results.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load tensor: {e}")
|
||||||
|
raise
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
For path_tuple_to_string(),
|
||||||
|
introducing a simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
|
||||||
|
"""
|
||||||
|
@lru_cache(maxsize=32) # Set maxsize to 32MB
|
||||||
|
def path_tuple_to_string_cached(path: tuple) -> str:
|
||||||
|
return path_tuple_to_string(path=path)
|
||||||
|
|
||||||
def path_tuple_to_string(path: tuple) -> str:
|
def path_tuple_to_string(path: tuple) -> str:
|
||||||
pieces = []
|
pieces = []
|
||||||
for elem in path:
|
for elem in path:
|
||||||
@ -119,6 +145,22 @@ def path_tuple_to_string(path: tuple) -> str:
|
|||||||
return "/".join(pieces)
|
return "/".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
For get_load_path_str(),
|
||||||
|
introducing a simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
|
||||||
|
"""
|
||||||
|
@lru_cache(maxsize=32) # Set maxsize to 32MB
|
||||||
|
def get_load_path_str_cached(
|
||||||
|
init_path_str: str,
|
||||||
|
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
||||||
|
load_exclude_rules: Optional[list[str]] = None,
|
||||||
|
) -> Optional[str]:
|
||||||
|
return get_load_path_str(
|
||||||
|
init_path_str,
|
||||||
|
load_rename_rules,
|
||||||
|
load_exclude_rules
|
||||||
|
)
|
||||||
|
|
||||||
def get_load_path_str(
|
def get_load_path_str(
|
||||||
init_path_str: str,
|
init_path_str: str,
|
||||||
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
||||||
@ -150,14 +192,14 @@ def replace_with_load_state(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
|
flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
|
||||||
flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
|
flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
|
||||||
load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
|
load_map = {path_tuple_to_string_cached(path): tensor for path, tensor in flatten_load}
|
||||||
|
|
||||||
replaced = []
|
replaced = []
|
||||||
num_replicas = 1
|
num_replicas = 1
|
||||||
data_model_shards = math.prod(mesh_config)
|
data_model_shards = math.prod(mesh_config)
|
||||||
for i, (init_path, tensor) in enumerate(flatten_init):
|
for i, (init_path, tensor) in enumerate(flatten_init):
|
||||||
init_path_str = path_tuple_to_string(init_path)
|
init_path_str = path_tuple_to_string_cached(init_path)
|
||||||
load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules)
|
load_path_str = get_load_path_str_cached(init_path_str, load_rename_rules, load_exclude_rules)
|
||||||
if load_path_str is None:
|
if load_path_str is None:
|
||||||
rank_logger.info(f"Excluded from restore: {init_path_str}.")
|
rank_logger.info(f"Excluded from restore: {init_path_str}.")
|
||||||
replaced.append(tensor)
|
replaced.append(tensor)
|
||||||
|
Loading…
Reference in New Issue
Block a user