# Copyright 2024 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import contextlib
import logging
import math
import os
import pickle
import re
import shutil
import sys
import tempfile
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Optional

# For get_load_path_str
# A simple caching mechanism to avoid recomputing regex matches for paths that have already been processed.
from functools import lru_cache

import jax
import numpy as np
from jax.experimental import multihost_utils

from model import QuantizedWeight8bit

logger = logging.getLogger(__name__)
rank_logger = logging.getLogger("rank")

# Needed for loading the checkpoint with pickle.
sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit


@contextlib.contextmanager
def copy_to_shm(file: str):
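    """Copy `file` into shared memory (/dev/shm) and yield the path of the copy.

    If the file is already under /dev/shm it is yielded unchanged; otherwise the
    temporary copy is removed when the context exits. Staging through tmpfs lets
    `pickle.load` read from memory instead of a potentially slow filesystem.
    """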
    if file.startswith("/dev/shm/"):
        # Nothing to do, the file is already in shared memory.
        yield file
        return

    tmp_dir = "/dev/shm/"
    fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
    try:
        shutil.copyfile(file, tmp_path)
        yield tmp_path
    finally:
        os.remove(tmp_path)
        os.close(fd)


@contextlib.contextmanager
def copy_from_shm(file: str):
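    """Yield a temporary path under /dev/shm and, on successful exit, copy it to `file`.

    The temporary file is always removed afterwards. This mirrors `copy_to_shm` for
    writes: the data is first written to memory, then copied out in a single pass.
    """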
    tmp_dir = "/dev/shm/"
    fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
    try:
        yield tmp_path
        shutil.copyfile(tmp_path, file)
    finally:
        os.remove(tmp_path)
        os.close(fd)


def fast_unpickle(path: str) -> Any:
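    """Unpickle `path`, staging the read through shared memory via `copy_to_shm`."""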
    with copy_to_shm(path) as tmp_path:
        with open(tmp_path, "rb") as f:
            return pickle.load(f)


def fast_pickle(obj: Any, path: str) -> None:
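    """Pickle `obj` to `path`, staging the write through shared memory via `copy_from_shm`."""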
    with copy_from_shm(path) as tmp_path:
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)


def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
    """Loads a set of arrays."""
    pool = ThreadPoolExecutor(max_workers=32)
    fs = list()
    num_tensors = 0
    num_replicas = 1
    data_model_shards = math.prod(mesh_config)
    if tensor_indices is None:
        iterator = enumerate(shaped_arrays)
    else:
        iterator = zip(tensor_indices, shaped_arrays)
    for i, t in iterator:
        if (i % num_replicas) == ((jax.process_index() // data_model_shards) % num_replicas):
            idx = (
                jax.process_index() // (num_replicas * data_model_shards) * data_model_shards
                + jax.process_index() % data_model_shards
            )
            fs.append(
                pool.submit(fast_unpickle, os.path.join(directory, f"tensor{i:05d}_{idx:03d}"))
            )
            num_tensors += 1
        else:
            fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
    wait(fs)
    # Collect results one future at a time (rather than `[f.result() for f in fs]`) so
    # that a failed load can be logged in detail before the exception is re-raised.
    results = []
    for i, future in enumerate(fs):
        try:
            result = future.result()
            results.append(result)
        except Exception as e:
            logger.error(f"Failed to load tensor {i}: {e}")
            raise
    return results


def path_tuple_to_string(path: tuple) -> str:
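    """Render a pytree key path as a '/'-joined string of dict keys and attribute names.

    Sequence and flattened-index entries are allowed but contribute no path component.
    """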
    pieces = []
    for elem in path:
        if isinstance(elem, jax.tree_util.DictKey):
            pieces.append(elem.key)
        elif isinstance(elem, jax.tree_util.GetAttrKey):
            pieces.append(elem.name)
        else:
            assert isinstance(elem, (jax.tree_util.FlattenedIndexKey, jax.tree_util.SequenceKey))
    return "/".join(pieces)


# Cached wrapper around get_load_path_str(): a simple caching mechanism to avoid
# recomputing regex matches for paths that have already been processed.
# Note: lru_cache requires hashable arguments, so the rule lists must be passed as tuples.
@lru_cache(maxsize=None)
def get_load_path_str_cached(
    init_path_str: str,
    load_rename_rules: Optional[tuple[tuple[str, str], ...]] = None,
    load_exclude_rules: Optional[tuple[str, ...]] = None,
) -> Optional[str]:
    return get_load_path_str(
        init_path_str,
        load_rename_rules,
        load_exclude_rules,
    )


def get_load_path_str(
    init_path_str: str,
    load_rename_rules: Optional[list[tuple[str, str]]] = None,
    load_exclude_rules: Optional[list[str]] = None,
) -> Optional[str]:
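    """Map an initial parameter path to the checkpoint path it should be loaded from.

    Returns None if `init_path_str` matches any regex in `load_exclude_rules`.
    Otherwise the first matching (pattern, replacement) pair from `load_rename_rules`
    is applied via `re.sub`; for example, a rule such as (r"^decoder/", "transformer/")
    (purely illustrative) would redirect paths that were saved under a different prefix.
    """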
    # Exclusion
    if load_exclude_rules is not None:
        for search_pattern in load_exclude_rules:
            if re.search(search_pattern, init_path_str):
                return None

    # Renaming
    load_path_str = init_path_str
    if load_rename_rules is not None:
        for search_pattern, replacement_pattern in load_rename_rules:
            if re.search(search_pattern, load_path_str):
                load_path_str = re.sub(search_pattern, replacement_pattern, load_path_str)
                break

    return load_path_str


def replace_with_load_state(
    init_state: Any,
    load_state: Any,
    load_rename_rules: Optional[list[tuple[str, str]]] = None,
    load_exclude_rules: Optional[list[str]] = None,
    mesh_config: tuple = (1, 1),
) -> Any:
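    """Replace leaves of `init_state` with tensors of the same path from `load_state`.

    Paths are matched after applying the exclude/rename rules. Excluded leaves keep
    their initial value; leaves missing from the checkpoint keep their initial value
    on the hosts responsible for them and are filled with zeros elsewhere.
    """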
    flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
    flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
    load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}

    replaced = []
    num_replicas = 1
    data_model_shards = math.prod(mesh_config)
    # Convert the rule lists to tuples once so the lru_cache'd helper receives
    # hashable arguments.
    rename_rules = tuple(load_rename_rules) if load_rename_rules is not None else None
    exclude_rules = tuple(load_exclude_rules) if load_exclude_rules is not None else None
    for i, (init_path, tensor) in enumerate(flatten_init):
        init_path_str = path_tuple_to_string(init_path)
        load_path_str = get_load_path_str_cached(init_path_str, rename_rules, exclude_rules)
        if load_path_str is None:
            rank_logger.info(f"Excluded from restore: {init_path_str}.")
            replaced.append(tensor)
        elif load_path_str in load_map:
            if load_path_str == init_path_str:
                rank_logger.info(f"Restored from ckpt: {init_path_str}.")
            else:
                rank_logger.info(f"Restored from ckpt: {init_path_str} <-- {load_path_str}.")
            replaced.append(load_map[load_path_str])
        else:
            rank_logger.info(f"Not found in ckpt: {init_path_str}.")
            if (i % num_replicas) == ((jax.process_index() // data_model_shards) % num_replicas):
                replaced.append(tensor)
            else:
                replaced.append(np.zeros_like(tensor))

    return jax.tree_util.tree_unflatten(structure_init, replaced)


def restore(
    checkpoint_path: str,
    state_shapes: Any,
    mesh,
    between_hosts_config,
    params_only,
    state_sharding,
    init_state: Optional[Any] = None,
) -> Any:
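    """Load the checkpoint under `checkpoint_path/ckpt-0` and return the state pytree.

    `state_shapes` is a pytree of shape/dtype descriptors (e.g. jax.ShapeDtypeStruct)
    matching the checkpoint layout, `between_hosts_config` is the host mesh used to
    shard loading across processes, and `state_sharding` maps each leaf to a
    PartitionSpec. Host-local arrays are assembled into global arrays on `mesh`;
    if `params_only` is True, only the `params` sub-tree is returned.
    """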
    ckpt_path = os.path.join(checkpoint_path, "ckpt-0")

    rank_logger.info("Loading checkpoint at {}".format(ckpt_path))
    ckpt_shapes = state_shapes
    ckpt_shapes_with_path, structure = jax.tree_util.tree_flatten_with_path(ckpt_shapes)

    ckpt_shapes_flat = [elem[1] for elem in ckpt_shapes_with_path]
    loaded_tensors = load_tensors(ckpt_shapes_flat, ckpt_path, between_hosts_config)

    state = jax.tree_util.tree_unflatten(structure, loaded_tensors)

    # Sanity check to give a better error message.
    ckpt_keys = set(state.params.keys())
    code_keys = set(state_sharding.params.keys())

    if ckpt_keys != code_keys and init_state is None:
        missing_in_ckpt = code_keys - ckpt_keys
        missing_locally = ckpt_keys - code_keys
        raise ValueError(
            "Parameters in the code do not match the checkpoint parameters.\n"
            "Params missing in checkpoint: {}\nParams missing in code: {}".format(
                missing_in_ckpt, missing_locally
            )
        )
    state_sharding = jax.tree_util.tree_map(
        lambda x: jax.sharding.PartitionSpec() if x is None else x,
        state_sharding,
        is_leaf=lambda x: x is None,
    )
    state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)
    if params_only:
        state = state.params
    return state
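# Illustrative usage (the names below are placeholders, not part of this module):
#
#     state_shapes = jax.eval_shape(init_fn, jax.random.PRNGKey(0))
#     params = restore(
#         checkpoint_path="./checkpoints",
#         state_shapes=state_shapes,
#         mesh=mesh,
#         between_hosts_config=(1, 1),
#         params_only=True,
#         state_sharding=state_sharding,
#         init_state=None,
#     )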