From 10fd2278a864486233c166298e523b6c5ca5f0d8 Mon Sep 17 00:00:00 2001 From: Madhav Date: Sat, 23 Mar 2024 15:37:18 +0530 Subject: [PATCH] path_tuple_to_string() -> path_tuple_to_string_cached(): Optimizing Regex Operations with Caching --- checkpoint.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/checkpoint.py b/checkpoint.py index 2b08b42..c869d94 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -125,6 +125,14 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None): return results +""" +For path_tuple_to_string(), +introducing a simple caching mechanism to avoid recomputing regex matches for paths that have already been processed. +""" +@lru_cache(maxsize=32) # Set maxsize to 32MB +def path_tuple_to_string_cached(path: tuple) -> str: + return path_tuple_to_string(path=path) + def path_tuple_to_string(path: tuple) -> str: pieces = [] for elem in path: @@ -184,13 +192,13 @@ def replace_with_load_state( ) -> Any: flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state) flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state) - load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load} + load_map = {path_tuple_to_string_cached(path): tensor for path, tensor in flatten_load} replaced = [] num_replicas = 1 data_model_shards = math.prod(mesh_config) for i, (init_path, tensor) in enumerate(flatten_init): - init_path_str = path_tuple_to_string(init_path) + init_path_str = path_tuple_to_string_cached(init_path) load_path_str = get_load_path_str_cached(init_path_str, load_rename_rules, load_exclude_rules) if load_path_str is None: rank_logger.info(f"Excluded from restore: {init_path_str}.")