mirror of
https://github.com/xai-org/grok-1.git
synced 2024-11-26 13:39:52 +03:00
Compare commits
2 Commits
0b0576a0a8
...
52c7b9aa79
Author | SHA1 | Date | |
---|---|---|---|
|
52c7b9aa79 | ||
|
10fd2278a8 |
@ -125,6 +125,14 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
|
||||
return results
|
||||
|
||||
|
||||
"""
|
||||
For path_tuple_to_string(),
|
||||
introducing a simple caching mechanism to avoid recomputing the path-to-string conversion for path tuples that have already been processed.
|
||||
"""
|
||||
@lru_cache(maxsize=32)  # maxsize counts cached entries (up to 32 distinct paths), not megabytes
|
||||
def path_tuple_to_string_cached(path: tuple) -> str:
|
||||
return path_tuple_to_string(path=path)
|
||||
|
||||
def path_tuple_to_string(path: tuple) -> str:
|
||||
pieces = []
|
||||
for elem in path:
|
||||
@ -184,13 +192,13 @@ def replace_with_load_state(
|
||||
) -> Any:
|
||||
flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
|
||||
flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
|
||||
load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
|
||||
load_map = {path_tuple_to_string_cached(path): tensor for path, tensor in flatten_load}
|
||||
|
||||
replaced = []
|
||||
num_replicas = 1
|
||||
data_model_shards = math.prod(mesh_config)
|
||||
for i, (init_path, tensor) in enumerate(flatten_init):
|
||||
init_path_str = path_tuple_to_string(init_path)
|
||||
init_path_str = path_tuple_to_string_cached(init_path)
|
||||
load_path_str = get_load_path_str_cached(init_path_str, load_rename_rules, load_exclude_rules)
|
||||
if load_path_str is None:
|
||||
rank_logger.info(f"Excluded from restore: {init_path_str}.")
|
||||
|
Loading…
Reference in New Issue
Block a user