diff --git a/checkpoint.py b/checkpoint.py index 2b08b42..c869d94 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -125,6 +125,14 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None): return results +""" +For path_tuple_to_string(), +introducing a simple caching mechanism to avoid recomputing regex matches for paths that have already been processed. +""" +@lru_cache(maxsize=32) # maxsize counts cached entries (32 paths), not megabytes +def path_tuple_to_string_cached(path: tuple) -> str: + return path_tuple_to_string(path=path) + def path_tuple_to_string(path: tuple) -> str: pieces = [] for elem in path: @@ -184,13 +192,13 @@ def replace_with_load_state( ) -> Any: flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state) flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state) - load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load} + load_map = {path_tuple_to_string_cached(path): tensor for path, tensor in flatten_load} replaced = [] num_replicas = 1 data_model_shards = math.prod(mesh_config) for i, (init_path, tensor) in enumerate(flatten_init): - init_path_str = path_tuple_to_string(init_path) + init_path_str = path_tuple_to_string_cached(init_path) load_path_str = get_load_path_str_cached(init_path_str, load_rename_rules, load_exclude_rules) if load_path_str is None: rank_logger.info(f"Excluded from restore: {init_path_str}.")