Update checkpoint.py

This commit is contained in:
San 2024-04-04 12:17:34 +03:00
parent 7d46bbbdcc
commit aa2be03aee

View File

@ -39,8 +39,19 @@ rank_logger = logging.getLogger("rank")
sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit
# Utility functions for file handling and shared memory
@contextlib.contextmanager
def copy_to_shm(file: str):
"""
Context manager to copy a file to shared memory.
Args:
file (str): The path to the file to be copied.
Yields:
str: The path to the copied file in shared memory.
"""
if file.startswith("/dev/shm/"):
# Nothing to do, the file is already in shared memory.
yield file
@ -58,6 +69,15 @@ def copy_to_shm(file: str):
@contextlib.contextmanager
def copy_from_shm(file: str):
"""
Context manager to copy a file from shared memory.
Args:
file (str): The path to the file to be copied.
Yields:
str: The path to the temporary file in shared memory.
"""
tmp_dir = "/dev/shm/"
fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
try:
@ -69,19 +89,48 @@ def copy_from_shm(file: str):
def fast_unpickle(path: str) -> Any:
"""
Unpickle an object from a file using shared memory for faster loading.
Args:
path (str): The path to the file containing the pickled object.
Returns:
Any: The unpickled object.
"""
with copy_to_shm(path) as tmp_path:
with open(tmp_path, "rb") as f:
return pickle.load(f)
def fast_pickle(obj: Any, path: str) -> None:
"""
Pickle an object to a file using shared memory for faster saving.
Args:
obj (Any): The object to be pickled.
path (str): The path to the file where the object will be saved.
"""
with copy_from_shm(path) as tmp_path:
with open(tmp_path, "wb") as f:
pickle.dump(obj, f)
# Tensor loading and path handling
def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
"""Loads a set of arrays."""
"""
Load a set of arrays from files in parallel using a thread pool.
Args:
shaped_arrays (list): A list of shaped arrays to be loaded.
directory (str): The directory containing the tensor files.
mesh_config (tuple): The mesh configuration.
tensor_indices (list, optional): The indices of the tensors to load. Defaults to None.
Returns:
list: A list of loaded arrays.
"""
pool = ThreadPoolExecutor(max_workers=32)
fs = list()
num_tensors = 0
@ -108,6 +157,15 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
def path_tuple_to_string(path: tuple) -> str:
"""
Convert a path tuple to a string representation.
Args:
path (tuple): The path tuple.
Returns:
str: The string representation of the path.
"""
pieces = []
for elem in path:
if isinstance(elem, jax.tree_util.DictKey):
@ -124,6 +182,17 @@ def get_load_path_str(
load_rename_rules: Optional[list[tuple[str, str]]] = None,
load_exclude_rules: Optional[list[str]] = None,
) -> Optional[str]:
"""
Get the load path string based on the initial path string and renaming/exclusion rules.
Args:
init_path_str (str): The initial path string.
load_rename_rules (list[tuple[str, str]], optional): The renaming rules. Defaults to None.
load_exclude_rules (list[str], optional): The exclusion rules. Defaults to None.
Returns:
Optional[str]: The load path string if not excluded, otherwise None.
"""
# Exclusion
if load_exclude_rules is not None:
for search_pattern in load_exclude_rules:
@ -148,6 +217,19 @@ def replace_with_load_state(
load_exclude_rules: Optional[list[str]] = None,
mesh_config: tuple = (1, 1),
) -> Any:
"""
Replace the initial state with the loaded state based on renaming and exclusion rules.
Args:
init_state (Any): The initial state.
load_state (Any): The loaded state.
load_rename_rules (list[tuple[str, str]], optional): The renaming rules. Defaults to None.
load_exclude_rules (list[str], optional): The exclusion rules. Defaults to None.
mesh_config (tuple, optional): The mesh configuration. Defaults to (1, 1).
Returns:
Any: The replaced state.
"""
flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
@ -177,6 +259,8 @@ def replace_with_load_state(
return jax.tree_util.tree_unflatten(structure_init, replaced)
# Checkpoint restoration
def restore(
checkpoint_path: str,
state_shapes: Any,
@ -186,6 +270,21 @@ def restore(
state_sharding,
init_state: Optional[Any] = None,
) -> Any:
"""
Restore the state from a checkpoint.
Args:
checkpoint_path (str): The path to the checkpoint directory.
state_shapes (Any): The shapes of the state.
mesh: The mesh configuration.
between_hosts_config: The configuration for communication between hosts.
params_only (bool): Whether to restore only the parameters.
state_sharding: The sharding specification for the state.
init_state (Optional[Any], optional): The initial state. Defaults to None.
Returns:
Any: The restored state.
"""
ckpt_path = os.path.join(checkpoint_path, "ckpt-0")
rank_logger.info("Loading checkpoint at {}".format(ckpt_path))