# grok-1/checkpoint.py
# 2024-03-18 16:05:23 -05:00
#
# 329 lines
# 12 KiB
# Python

# Copyright 2024 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import logging
import math
import os
import pickle
import re
import shutil
import sys
import tempfile
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Optional
import jax
import numpy as np
from jax.experimental import multihost_utils
from model import QuantizedWeight8bit
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Separate logger registered under the fixed name "rank"; the restore helpers
# in this file use it for their progress messages.
rank_logger = logging.getLogger("rank")
# Needed for loading the checkpoint with pickle.
# This makes QuantizedWeight8bit resolvable as `__main__.QuantizedWeight8bit`.
# NOTE(review): presumably the checkpoint was pickled by a script in which the
# class lived in `__main__` -- confirm against the code that wrote it.
sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit
@contextlib.contextmanager
def copy_to_shm(file: str, shm_dir: str = "/dev/shm/"):
    """
    Context manager for copying a file to shared memory. If the file is already located under the shared
    memory directory, yields the same file path. Otherwise, copies the file to a temporary file in that
    directory, yields the path to the temporary file, and removes the temporary file afterwards.
    Parameters:
    - file (str): The path to the file to be copied.
    - shm_dir (str): Directory that holds the temporary copy. Defaults to "/dev/shm/".
    Yields:
    - str: The path to the file in shared memory.
    """
    # Normalise to a trailing separator so that e.g. "/dev/shm" cannot match "/dev/shmfoo/x".
    shm_prefix = os.path.join(shm_dir, "")
    if file.startswith(shm_prefix):
        # Nothing to do, the file is already in shared memory.
        yield file
        return
    fd, tmp_path = tempfile.mkstemp(dir=shm_dir)
    # The descriptor returned by mkstemp is never used for I/O here (the copy goes
    # through the path), so close it immediately instead of holding it open for
    # the lifetime of the context and risking a leak if cleanup fails.
    os.close(fd)
    try:
        shutil.copyfile(file, tmp_path)
        yield tmp_path
    finally:
        # Best-effort cleanup: the consumer may already have removed the file.
        with contextlib.suppress(FileNotFoundError):
            os.remove(tmp_path)
@contextlib.contextmanager
def copy_from_shm(file: str, shm_dir: str = "/dev/shm/"):
    """
    Context manager for copying a file from shared memory to a specified path. It creates a temporary file
    in the shared memory directory, yields the path to this temporary file for operations, and then copies
    the temporary file to the specified path, removing the temporary file afterwards. If the body of the
    `with` statement raises, nothing is copied to the target path, but the temporary file is still removed.
    Parameters:
    - file (str): The target path for the file to be copied to from shared memory.
    - shm_dir (str): Directory that holds the temporary file. Defaults to "/dev/shm/".
    Yields:
    - str: The path to the temporary file in shared memory.
    """
    fd, tmp_path = tempfile.mkstemp(dir=shm_dir)
    # The caller writes through the yielded path, not through this descriptor, so
    # release it right away rather than keeping it open until (and only if)
    # cleanup succeeds.
    os.close(fd)
    try:
        yield tmp_path
        # Only reached when the with-body finished without raising.
        shutil.copyfile(tmp_path, file)
    finally:
        # Best-effort cleanup: the caller may already have removed the file.
        with contextlib.suppress(FileNotFoundError):
            os.remove(tmp_path)
def fast_unpickle(path: str) -> Any:
    """
    Loads a pickled object from disk, reading it through a copy staged by `copy_to_shm`.
    Parameters:
    - path (str): Location of the pickle file.
    Returns:
    - Any: Whatever object the pickle file contained.
    """
    # NOTE: unpickling can execute arbitrary code; only point this at trusted files.
    with copy_to_shm(path) as staged_path, open(staged_path, "rb") as stream:
        payload = pickle.load(stream)
    return payload
def fast_pickle(obj: Any, path: str) -> None:
    """
    Serialises an object with pickle into a temporary file provided by `copy_from_shm`, which then
    copies the result to the requested destination.
    Parameters:
    - obj (Any): Object to serialise.
    - path (str): Destination path of the pickle file.
    """
    # The sink is closed first, then copy_from_shm copies the finished file to `path`.
    with copy_from_shm(path) as staged_path, open(staged_path, "wb") as sink:
        pickle.dump(obj, sink)
def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
    """
    Loads tensors from files in parallel using a ThreadPoolExecutor. Each tensor is read from its own
    file named `tensor{i:05d}_{idx:03d}` inside `directory`, where `i` is the tensor index and `idx` is
    derived from `jax.process_index()` and the product of `mesh_config`.
    Parameters:
    - shaped_arrays: A sequence of objects exposing `.shape` and `.dtype`, one per tensor to load.
    - directory (str): The directory from which to load the tensors.
    - mesh_config: Iterable of integers; their product is used as the number of data/model shards.
    - tensor_indices (Optional): Indices to pair with `shaped_arrays`. If None, tensors are numbered
      0, 1, 2, ... in order.
    Returns:
    - list: One entry per input, in input order: the unpickled tensor, or a zero array of the given
      shape/dtype when this process is not the selected replica for that index. With `num_replicas`
      fixed at 1 the replica test is always true for integer indices, so every tensor is read from disk.
    """
    num_replicas = 1
    data_model_shards = math.prod(mesh_config)
    # These depend only on the process, not on the tensor, so compute them once
    # rather than on every loop iteration.
    process_index = jax.process_index()
    replica_id = (process_index // data_model_shards) % num_replicas
    shard_idx = (
        process_index // (num_replicas * data_model_shards) * data_model_shards
        + process_index % data_model_shards
    )

    if tensor_indices is None:
        indexed_arrays = enumerate(shaped_arrays)
    else:
        indexed_arrays = zip(tensor_indices, shaped_arrays)

    futures = []
    # The context manager shuts the pool down on exit (waiting for all submitted
    # work), so the worker threads are always released, even if submission fails.
    with ThreadPoolExecutor(max_workers=32) as pool:
        for i, t in indexed_arrays:
            if (i % num_replicas) == replica_id:
                tensor_path = os.path.join(directory, f"tensor{i:05d}_{shard_idx:03d}")
                futures.append(pool.submit(fast_unpickle, tensor_path))
            else:
                futures.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
    # All futures are complete here; result() re-raises any worker exception.
    return [future.result() for future in futures]
def path_tuple_to_string(path: tuple) -> str:
    """
    Renders a JAX key path as a "/"-separated string. Dictionary keys contribute their key, attribute
    keys contribute their attribute name, and positional keys (sequence / flattened index) contribute
    nothing to the result.
    Parameters:
    - path (tuple): Key-path entries as produced by `jax.tree_util.tree_flatten_with_path`.
    Returns:
    - str: The named components of the path joined with "/".
    """
    tree_util = jax.tree_util
    positional_kinds = (tree_util.FlattenedIndexKey, tree_util.SequenceKey)
    segments = []
    for key in path:
        if isinstance(key, tree_util.DictKey):
            segments.append(key.key)
            continue
        if isinstance(key, tree_util.GetAttrKey):
            segments.append(key.name)
            continue
        # Anything else must be a purely positional key, which is skipped.
        assert isinstance(key, positional_kinds)
    return "/".join(segments)
def get_load_path_str(
    init_path_str: str,
    load_rename_rules: Optional[list[tuple[str, str]]] = None,
    load_exclude_rules: Optional[list[str]] = None,
) -> Optional[str]:
    """
    Maps a path of the initial state to the path that should be looked up in the loaded state.
    Exclusion is checked first; if no exclusion pattern matches, the first rename rule whose pattern
    matches is applied (via regex substitution) and later rules are ignored.
    Parameters:
    - init_path_str (str): The path string to translate.
    - load_rename_rules (Optional[list[tuple[str, str]]]): (search pattern, replacement) pairs, tried in order.
    - load_exclude_rules (Optional[list[str]]): Regex patterns; a match anywhere in the path excludes it.
    Returns:
    - Optional[str]: The translated path, the unchanged path if no rename rule matches, or None if excluded.
    """
    # Exclusion wins over renaming.
    if load_exclude_rules is not None and any(
        re.search(exclude_pattern, init_path_str) for exclude_pattern in load_exclude_rules
    ):
        return None

    if load_rename_rules is None:
        return init_path_str

    # Only the first matching rename rule is applied.
    for search_pattern, replacement_pattern in load_rename_rules:
        matcher = re.compile(search_pattern)
        if matcher.search(init_path_str):
            return matcher.sub(replacement_pattern, init_path_str)
    return init_path_str
def replace_with_load_state(
    init_state: Any,
    load_state: Any,
    load_rename_rules: Optional[list[tuple[str, str]]] = None,
    load_exclude_rules: Optional[list[str]] = None,
    mesh_config: tuple = (1, 1),
) -> Any:
    """
    Builds a state with the tree structure of `init_state` whose leaves are taken from `load_state`
    wherever a matching path exists. Paths are compared as strings produced by `path_tuple_to_string`
    after being translated by `get_load_path_str`.
    Parameters:
    - init_state (Any): The state providing the tree structure and the fallback leaves.
    - load_state (Any): The state providing replacement leaves.
    - load_rename_rules (Optional[list[tuple[str, str]]]): Rename rules passed to `get_load_path_str`.
    - load_exclude_rules (Optional[list[str]]): Exclusion patterns passed to `get_load_path_str`.
    - mesh_config (tuple): Integers whose product is used as the number of data/model shards.
    Returns:
    - Any: A tree shaped like `init_state`. Excluded leaves keep their initial value; leaves found in
      `load_state` are replaced; leaves not found keep their initial value on the selected replica and
      become zeros elsewhere.
    """
    load_leaves, _ = jax.tree_util.tree_flatten_with_path(load_state)
    init_leaves, init_treedef = jax.tree_util.tree_flatten_with_path(init_state)

    # Index the loaded leaves by their string path for direct lookup.
    ckpt_by_path = {}
    for load_path, load_tensor in load_leaves:
        ckpt_by_path[path_tuple_to_string(load_path)] = load_tensor

    num_replicas = 1
    data_model_shards = math.prod(mesh_config)

    merged = []
    for position, (init_path, init_tensor) in enumerate(init_leaves):
        init_path_str = path_tuple_to_string(init_path)
        load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules)

        if load_path_str is None:
            rank_logger.info(f"Excluded from restore: {init_path_str}.")
            merged.append(init_tensor)
            continue

        if load_path_str not in ckpt_by_path:
            rank_logger.info(f"Not found in ckpt: {init_path_str}.")
            selected_replica = (jax.process_index() // data_model_shards) % num_replicas
            if (position % num_replicas) == selected_replica:
                merged.append(init_tensor)
            else:
                merged.append(np.zeros_like(init_tensor))
            continue

        if load_path_str == init_path_str:
            rank_logger.info(f"Restored from ckpt: {init_path_str}.")
        else:
            rank_logger.info(f"Restored from ckpt: {init_path_str} <-- {load_path_str}.")
        merged.append(ckpt_by_path[load_path_str])

    return jax.tree_util.tree_unflatten(init_treedef, merged)
def restore(
    checkpoint_path: str,
    state_shapes: Any,
    mesh,
    between_hosts_config,
    params_only,
    state_sharding,
    init_state: Optional[Any] = None,
) -> Any:
    """
    Loads the checkpoint stored under `<checkpoint_path>/ckpt-0`, rebuilds a state with the tree
    structure of `state_shapes`, and converts it to global arrays using `mesh` and `state_sharding`.
    Parameters:
    - checkpoint_path (str): Directory that contains the "ckpt-0" sub-directory.
    - state_shapes (Any): Tree whose leaves describe the tensors to load; also defines the result structure.
    - mesh: Mesh passed to `multihost_utils.host_local_array_to_global_array`.
    - between_hosts_config: Passed to `load_tensors` as its `mesh_config` argument.
    - params_only (bool): If truthy, only the `params` attribute of the restored state is returned.
    - state_sharding: Sharding tree for the state; `None` entries are replaced by an empty PartitionSpec.
    - init_state (Optional[Any]): Only consulted by the sanity check: a parameter-key mismatch raises
      solely when this is None.
    Returns:
    - Any: The restored state, or its `params` when `params_only` is set.
    Raises:
    - ValueError: If the checkpoint's parameter keys differ from those of `state_sharding` and
      `init_state` is None.
    """
    ckpt_path = os.path.join(checkpoint_path, "ckpt-0")
    rank_logger.info("Loading checkpoint at {}".format(ckpt_path))

    leaves_with_path, treedef = jax.tree_util.tree_flatten_with_path(state_shapes)
    shape_leaves = [leaf for _, leaf in leaves_with_path]
    tensors = load_tensors(shape_leaves, ckpt_path, between_hosts_config)
    state = jax.tree_util.tree_unflatten(treedef, tensors)

    # Sanity check to give a better error message.
    ckpt_keys = set(state.params.keys())
    code_keys = set(state_sharding.params.keys())
    if init_state is None and ckpt_keys != code_keys:
        raise ValueError(
            "Parameters in the code are not matching checkpoint parameters.\n"
            "Params missing in checkpoint: {}\nParams missing in code: {}".format(
                code_keys - ckpt_keys, ckpt_keys - code_keys
            )
        )

    def _spec_or_default(spec):
        # Entries without a sharding get an empty PartitionSpec.
        if spec is None:
            return jax.sharding.PartitionSpec()
        return spec

    # `None` must be treated as a leaf, otherwise tree_map would skip it.
    state_sharding = jax.tree_util.tree_map(
        _spec_or_default,
        state_sharding,
        is_leaf=lambda node: node is None,
    )
    state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)
    return state.params if params_only else state