2024-03-15 01:03:58 +03:00
|
|
|
# Copyright 2024 X.AI Corp.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import contextlib
|
|
|
|
import logging
|
|
|
|
import math
|
|
|
|
import os
|
|
|
|
import pickle
|
|
|
|
import re
|
|
|
|
import shutil
|
|
|
|
import sys
|
|
|
|
import tempfile
|
|
|
|
from concurrent.futures import ThreadPoolExecutor, wait
|
|
|
|
from typing import Any, Optional
|
|
|
|
|
|
|
|
import jax
|
|
|
|
import numpy as np
|
|
|
|
from jax.experimental import multihost_utils
|
|
|
|
|
|
|
|
from model import QuantizedWeight8bit
|
|
|
|
|
|
|
|
# Module logger plus a separately named "rank" logger used for per-process restore messages.
logger = logging.getLogger(__name__)
rank_logger = logging.getLogger("rank")

# Needed for loading the checkpoint with pickle: the class is exposed as an attribute of
# the `__main__` module. Presumably the checkpoint pickles reference it as
# `__main__.QuantizedWeight8bit` (i.e. they were written from a script) - TODO confirm.
setattr(sys.modules["__main__"], "QuantizedWeight8bit", QuantizedWeight8bit)
|
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def copy_to_shm(file: str):
    """
    Context manager for copying a file to shared memory. If the file is already in shared memory (/dev/shm),
    yields the same file path. Otherwise, copies the file to a temporary file in shared memory, yields the path
    to the temporary file, and cleans up by removing the temporary file after use.

    Parameters:
    - file (str): The path to the file to be copied.

    Yields:
    - str: The path to the file in shared memory.

    Raises:
    - OSError (e.g. FileNotFoundError): if /dev/shm is unavailable, the source cannot be copied,
      or the temporary file cannot be removed during cleanup.
    """
    if file.startswith("/dev/shm/"):
        # Nothing to do, the file is already in shared memory.
        yield file
        return

    tmp_dir = "/dev/shm/"
    fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
    # The descriptor returned by mkstemp is never used (copyfile reopens the file by path).
    # Close it right away so it cannot leak if the copy, the caller, or os.remove fails.
    os.close(fd)
    try:
        shutil.copyfile(file, tmp_path)
        yield tmp_path
    finally:
        os.remove(tmp_path)
|
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def copy_from_shm(file: str):
    """
    Context manager for copying a file from shared memory to a specified path. It creates a temporary file
    in shared memory, yields the path to this temporary file for operations, and then copies the temporary
    file to the specified path, cleaning up the temporary file afterwards.

    If the body of the `with` statement raises, nothing is copied to `file`; the temporary
    file is still removed.

    Parameters:
    - file (str): The target path for the file to be copied to from shared memory.

    Yields:
    - str: The path to the temporary file in shared memory.
    """
    tmp_dir = "/dev/shm/"
    fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
    # The caller writes to tmp_path by name, so the mkstemp descriptor is never used.
    # Close it immediately so it cannot leak if the copy or os.remove below raises.
    os.close(fd)
    try:
        yield tmp_path
        # Only reached when the with-body finished without raising.
        shutil.copyfile(tmp_path, file)
    finally:
        os.remove(tmp_path)
|
|
|
|
|
|
|
|
|
|
|
|
def fast_unpickle(path: str) -> Any:
    """
    Load a pickled object from `path`, reading it through a shared-memory copy.

    The file is staged via `copy_to_shm` (which reuses the path as-is when it already lives
    under /dev/shm) and then deserialized with `pickle.load`. Only use this on trusted files:
    unpickling can execute arbitrary code.

    Parameters:
    - path (str): The path to the pickle file to load.

    Returns:
    - Any: The object loaded from the pickle file.
    """
    with copy_to_shm(path) as shm_path, open(shm_path, "rb") as handle:
        return pickle.load(handle)
|
|
|
|
|
|
|
|
|
|
|
|
def fast_pickle(obj: Any, path: str) -> None:
    """
    Serialize `obj` with pickle to `path`, writing it through a shared-memory staging file.

    The object is dumped into a temporary file under /dev/shm (via `copy_from_shm`), which is
    copied to `path` only after the dump completes without raising.

    Parameters:
    - obj (Any): The object to be pickled and saved.
    - path (str): The path where the pickle file should be saved.
    """
    with copy_from_shm(path) as staging_path, open(staging_path, "wb") as handle:
        pickle.dump(obj, handle)
|
|
|
|
|
|
|
|
|
|
|
|
def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None, max_workers: int = 32):
    """
    Loads tensors from files in parallel using a ThreadPoolExecutor. Each tensor is read from its own
    file in `directory`, named `tensor{i:05d}_{idx:03d}`, where `i` is the tensor index and `idx`
    is a shard index derived from this process's index and `mesh_config`.

    Parameters:
    - shaped_arrays: A sequence of objects exposing `.shape` and `.dtype` for the tensors to load.
    - directory (str): The directory from which to load the tensors.
    - mesh_config: Sequence of integers whose product is the number of data/model shards
      (presumably the (data, model) mesh sizes - TODO confirm against callers).
    - tensor_indices (Optional): Indices to use for each entry of `shaped_arrays`. If None,
      entries are numbered 0, 1, 2, ... Paired with `shaped_arrays` via `zip`, so the
      shorter of the two determines how many tensors are produced.
    - max_workers (int): Number of loader threads (default 32).

    Returns:
    - List with one entry per loaded tensor, in iteration order: the unpickled file contents
      for tensors assigned to this process's replica, otherwise `np.zeros` of the matching
      shape/dtype. With `num_replicas` fixed at 1, every tensor is assigned to this process.
    """
    num_replicas = 1
    data_model_shards = math.prod(mesh_config)

    # Both values depend only on the process index, so compute them once instead of per tensor.
    process_index = jax.process_index()
    my_replica = (process_index // data_model_shards) % num_replicas
    shard_idx = (
        process_index // (num_replicas * data_model_shards) * data_model_shards
        + process_index % data_model_shards
    )

    if tensor_indices is None:
        iterator = enumerate(shaped_arrays)
    else:
        iterator = zip(tensor_indices, shaped_arrays)

    futures = []
    # Using the executor as a context manager shuts it down (joining its worker threads)
    # once all submitted loads complete, instead of leaking a 32-thread pool per call.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        for i, t in iterator:
            if (i % num_replicas) == my_replica:
                path = os.path.join(directory, f"tensor{i:05d}_{shard_idx:03d}")
                futures.append(pool.submit(fast_unpickle, path))
            else:
                futures.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
    # All futures are done here; result() re-raises the first failure in submission order.
    return [f.result() for f in futures]
|
|
|
|
|
|
|
|
|
|
|
|
def path_tuple_to_string(path: tuple) -> str:
    """
    Converts a pytree key path (as produced by `jax.tree_util.tree_flatten_with_path`) into a
    "/"-joined string, used as a lookup key when matching saved and initial state leaves.

    Dict keys and attribute names become path components. Sequence and flattened-index keys
    are deliberately omitted, so leaves that differ only by list position map to the same string.

    Parameters:
    - path (tuple): A tuple of jax.tree_util key objects.

    Returns:
    - str: The "/"-joined path, e.g. "params/layer_0/w".

    Raises:
    - TypeError: if the path contains a key type other than DictKey, GetAttrKey,
      SequenceKey or FlattenedIndexKey.
    """
    pieces = []
    for elem in path:
        if isinstance(elem, jax.tree_util.DictKey):
            # Dict keys need not be strings (e.g. ints); str() keeps "/".join from failing.
            pieces.append(str(elem.key))
        elif isinstance(elem, jax.tree_util.GetAttrKey):
            pieces.append(elem.name)
        elif not isinstance(elem, (jax.tree_util.FlattenedIndexKey, jax.tree_util.SequenceKey)):
            # Explicit check rather than `assert`, which is stripped under `python -O` and would
            # silently drop unknown keys, producing colliding names.
            raise TypeError(f"Unsupported key type in pytree path: {elem!r}")
    return "/".join(pieces)
|
|
|
|
|
|
|
|
|
|
|
|
def get_load_path_str(
    init_path_str: str,
    load_rename_rules: Optional[list[tuple[str, str]]] = None,
    load_exclude_rules: Optional[list[str]] = None,
) -> Optional[str]:
    """
    Map a parameter path from the initial state to the path to look up in the loaded state.

    Exclusion is checked first: if any pattern in `load_exclude_rules` matches (via
    `re.search`), None is returned. Otherwise the first rename rule whose pattern matches is
    applied with `re.sub` (replacing all occurrences), and no further rules are tried. With
    no matching rule, the path is returned unchanged.

    Parameters:
    - init_path_str (str): The initial path string to process.
    - load_rename_rules (Optional[list[tuple[str, str]]]): (regex, replacement) pairs, tried in order.
    - load_exclude_rules (Optional[list[str]]): Regex patterns for paths to exclude from loading.

    Returns:
    - Optional[str]: The path to load from, or None if the path is excluded.
    """
    exclude_patterns = load_exclude_rules if load_exclude_rules is not None else []
    if any(re.search(pattern, init_path_str) for pattern in exclude_patterns):
        return None

    rename_rules = load_rename_rules if load_rename_rules is not None else []
    for pattern, replacement in rename_rules:
        if re.search(pattern, init_path_str):
            # Only the first matching rule is applied.
            return re.sub(pattern, replacement, init_path_str)

    return init_path_str
|
|
|
|
|
|
|
|
|
|
|
|
def replace_with_load_state(
    init_state: Any,
    load_state: Any,
    load_rename_rules: Optional[list[tuple[str, str]]] = None,
    load_exclude_rules: Optional[list[str]] = None,
    mesh_config: tuple = (1, 1),
) -> Any:
    """
    Build a pytree shaped like `init_state` whose leaves are taken from `load_state` where possible.

    Each init leaf's path string (from `path_tuple_to_string`) is mapped through
    `get_load_path_str`. Then:
    - excluded paths keep the init leaf;
    - paths present in `load_state` take the loaded leaf;
    - missing paths keep the init leaf on the process whose replica matches the leaf index,
      and get `np.zeros_like(init leaf)` elsewhere. With `num_replicas` fixed at 1, every
      process keeps the init leaf.
    Every decision is logged on `rank_logger`.

    Parameters:
    - init_state (Any): Pytree providing the structure and fallback leaves.
    - load_state (Any): Pytree providing the restored leaves.
    - load_rename_rules (Optional[list[tuple[str, str]]]): Rules for renaming paths.
    - load_exclude_rules (Optional[list[str]]): Patterns for paths to exclude from loading.
    - mesh_config (tuple): Sequence of shard counts; its product is the number of data/model shards.

    Returns:
    - Any: A pytree with the structure of `init_state`.
    """
    load_leaves, _ = jax.tree_util.tree_flatten_with_path(load_state)
    init_leaves, init_treedef = jax.tree_util.tree_flatten_with_path(init_state)
    ckpt_by_name = dict((path_tuple_to_string(key_path), leaf) for key_path, leaf in load_leaves)

    num_replicas = 1
    data_model_shards = math.prod(mesh_config)

    new_leaves = []
    for leaf_idx, (key_path, leaf) in enumerate(init_leaves):
        name = path_tuple_to_string(key_path)
        ckpt_name = get_load_path_str(name, load_rename_rules, load_exclude_rules)

        if ckpt_name is None:
            rank_logger.info(f"Excluded from restore: {name}.")
            new_leaves.append(leaf)
            continue

        if ckpt_name not in ckpt_by_name:
            rank_logger.info(f"Not found in ckpt: {name}.")
            owns_leaf = (leaf_idx % num_replicas) == (
                (jax.process_index() // data_model_shards) % num_replicas
            )
            new_leaves.append(leaf if owns_leaf else np.zeros_like(leaf))
            continue

        if ckpt_name == name:
            rank_logger.info(f"Restored from ckpt: {name}.")
        else:
            rank_logger.info(f"Restored from ckpt: {name} <-- {ckpt_name}.")
        new_leaves.append(ckpt_by_name[ckpt_name])

    return jax.tree_util.tree_unflatten(init_treedef, new_leaves)
|
|
|
|
|
|
|
|
|
|
|
|
def restore(
    checkpoint_path: str,
    state_shapes: Any,
    mesh,
    between_hosts_config,
    params_only,
    state_sharding,
    init_state: Optional[Any] = None,
) -> Any:
    """
    Load the checkpoint stored under `<checkpoint_path>/ckpt-0` and turn it into global arrays.

    Steps:
    1. Load one tensor per leaf of `state_shapes` (via `load_tensors`) and rebuild the pytree.
    2. Compare the keys of `state.params` with those of `state_sharding.params`. Raise
       ValueError on a mismatch unless `init_state` is given.
    3. Replace None leaves of `state_sharding` with an empty PartitionSpec and convert the
       host-local arrays with `multihost_utils.host_local_array_to_global_array`.

    NOTE(review): `init_state` is only used to suppress the key-mismatch error; its values
    are not merged into the result here.

    Parameters:
    - checkpoint_path (str): Directory containing the "ckpt-0" checkpoint subdirectory.
    - state_shapes (Any): Pytree of shape/dtype descriptors for the state to restore.
    - mesh: Device mesh passed to `host_local_array_to_global_array`.
    - between_hosts_config: Passed to `load_tensors` as its `mesh_config`.
    - params_only (bool): If true, return only `state.params`.
    - state_sharding: Pytree of PartitionSpecs (None allowed) matching the state; must expose `.params`.
    - init_state (Optional[Any]): If not None, skip the parameter-key sanity check.

    Returns:
    - Any: The restored global-array state, or just its `.params` when `params_only` is set.

    Raises:
    - ValueError: if checkpoint and code parameter keys differ and `init_state` is None.
    """
    ckpt_path = os.path.join(checkpoint_path, "ckpt-0")
    rank_logger.info("Loading checkpoint at {}".format(ckpt_path))

    leaves_with_path, treedef = jax.tree_util.tree_flatten_with_path(state_shapes)
    shape_leaves = [leaf for _, leaf in leaves_with_path]
    state = jax.tree_util.tree_unflatten(
        treedef, load_tensors(shape_leaves, ckpt_path, between_hosts_config)
    )

    # Sanity check to give a better error message.
    ckpt_keys = set(state.params.keys())
    code_keys = set(state_sharding.params.keys())
    if init_state is None and ckpt_keys != code_keys:
        raise ValueError(
            "Parameters in the code are not matching checkpoint parameters.\n"
            "Params missing in checkpoint: {}\nParams missing in code: {}".format(
                code_keys - ckpt_keys, ckpt_keys - code_keys
            )
        )

    def _default_spec(spec):
        # An unsharded (replicated) PartitionSpec stands in for missing sharding entries.
        return jax.sharding.PartitionSpec() if spec is None else spec

    state_sharding = jax.tree_util.tree_map(
        _default_spec, state_sharding, is_leaf=lambda x: x is None
    )
    state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)
    return state.params if params_only else state
|