San, 2024-04-04 12:18:07 +03:00, committed by GitHub
commit 4416453aec
2 changed files with 123 additions and 10 deletions

checkpoint.py

@@ -39,8 +39,19 @@ rank_logger = logging.getLogger("rank")
 sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit
 
 
+# Utility functions for file handling and shared memory
 @contextlib.contextmanager
 def copy_to_shm(file: str):
+    """
+    Context manager to copy a file to shared memory.
+
+    Args:
+        file (str): The path to the file to be copied.
+
+    Yields:
+        str: The path to the copied file in shared memory.
+    """
     if file.startswith("/dev/shm/"):
         # Nothing to do, the file is already in shared memory.
         yield file
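Note: the added docstring describes copy_to_shm as a context manager that stages a file under /dev/shm before it is read. A minimal usage sketch, assuming these helpers are importable as a checkpoint module and using a made-up file path:

from checkpoint import copy_to_shm  # import path is an assumption

with copy_to_shm("/data/checkpoints/ckpt-0/some_tensor_file") as shm_path:
    with open(shm_path, "rb") as f:
        blob = f.read()  # read from shared memory rather than the original location
# On exit, the context manager is expected to clean up any temporary copy it made.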
@@ -58,6 +69,15 @@ def copy_to_shm(file: str):
 
 @contextlib.contextmanager
 def copy_from_shm(file: str):
+    """
+    Context manager to copy a file from shared memory.
+
+    Args:
+        file (str): The path to the file to be copied.
+
+    Yields:
+        str: The path to the temporary file in shared memory.
+    """
     tmp_dir = "/dev/shm/"
     fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
     try:
@@ -69,19 +89,48 @@ def copy_from_shm(file: str):
 
 
 def fast_unpickle(path: str) -> Any:
+    """
+    Unpickle an object from a file using shared memory for faster loading.
+
+    Args:
+        path (str): The path to the file containing the pickled object.
+
+    Returns:
+        Any: The unpickled object.
+    """
     with copy_to_shm(path) as tmp_path:
         with open(tmp_path, "rb") as f:
             return pickle.load(f)
 
 
 def fast_pickle(obj: Any, path: str) -> None:
+    """
+    Pickle an object to a file using shared memory for faster saving.
+
+    Args:
+        obj (Any): The object to be pickled.
+        path (str): The path to the file where the object will be saved.
+    """
     with copy_from_shm(path) as tmp_path:
         with open(tmp_path, "wb") as f:
             pickle.dump(obj, f)
 
 
+# Tensor loading and path handling
 def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
-    """Loads a set of arrays."""
+    """
+    Load a set of arrays from files in parallel using a thread pool.
+
+    Args:
+        shaped_arrays (list): A list of shaped arrays to be loaded.
+        directory (str): The directory containing the tensor files.
+        mesh_config (tuple): The mesh configuration.
+        tensor_indices (list, optional): The indices of the tensors to load. Defaults to None.
+
+    Returns:
+        list: A list of loaded arrays.
+    """
     pool = ThreadPoolExecutor(max_workers=32)
     fs = list()
     num_tensors = 0
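Taken together, fast_pickle and fast_unpickle are thin wrappers around pickle.dump/pickle.load that stage the file through /dev/shm via the context managers above. A small round-trip sketch, again assuming a hypothetical checkpoint import path:

from checkpoint import fast_pickle, fast_unpickle

state = {"step": 0, "params": [1.0, 2.0, 3.0]}
fast_pickle(state, "/tmp/state.pkl")        # pickled into /dev/shm, then copied to /tmp/state.pkl
restored = fast_unpickle("/tmp/state.pkl")  # copied into /dev/shm, then unpickled from there
assert restored == state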
@@ -108,6 +157,15 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
 
 
 def path_tuple_to_string(path: tuple) -> str:
+    """
+    Convert a path tuple to a string representation.
+
+    Args:
+        path (tuple): The path tuple.
+
+    Returns:
+        str: The string representation of the path.
+    """
     pieces = []
     for elem in path:
         if isinstance(elem, jax.tree_util.DictKey):
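For reference, path_tuple_to_string consumes the key paths produced by jax.tree_util.tree_flatten_with_path, which replace_with_load_state uses further down. A short sketch of what those path tuples look like; the exact string rendering (separator, key formatting) is not shown in this hunk, so it is left out:

import jax

tree = {"transformer": {"layer_0": {"w": 1.0}}, "step": 0}
flat, _ = jax.tree_util.tree_flatten_with_path(tree)
for path, leaf in flat:
    # For dict pytrees, each path is a tuple of jax.tree_util.DictKey entries,
    # one per nesting level, ending at the leaf value.
    print(path, leaf)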
@@ -124,6 +182,17 @@ def get_load_path_str(
     load_rename_rules: Optional[list[tuple[str, str]]] = None,
     load_exclude_rules: Optional[list[str]] = None,
 ) -> Optional[str]:
+    """
+    Get the load path string based on the initial path string and renaming/exclusion rules.
+
+    Args:
+        init_path_str (str): The initial path string.
+        load_rename_rules (list[tuple[str, str]], optional): The renaming rules. Defaults to None.
+        load_exclude_rules (list[str], optional): The exclusion rules. Defaults to None.
+
+    Returns:
+        Optional[str]: The load path string if not excluded, otherwise None.
+    """
     # Exclusion
     if load_exclude_rules is not None:
         for search_pattern in load_exclude_rules:
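The loop above iterates over exclusion patterns; a hedged sketch of how regex-style exclude and rename rules of this shape are commonly applied (apply_rules is an illustrative stand-in, not the function in this diff, and the first-match-wins rename behavior is an assumption):

import re
from typing import Optional

def apply_rules(init_path_str: str,
                rename_rules: Optional[list[tuple[str, str]]] = None,
                exclude_rules: Optional[list[str]] = None) -> Optional[str]:
    # Exclusion: any matching pattern means the tensor is not loaded at all.
    if exclude_rules is not None:
        for search_pattern in exclude_rules:
            if re.search(search_pattern, init_path_str):
                return None
    # Renaming: rewrite the path used to look the tensor up in the loaded state.
    load_path_str = init_path_str
    if rename_rules is not None:
        for search_pattern, replacement in rename_rules:
            if re.search(search_pattern, load_path_str):
                load_path_str = re.sub(search_pattern, replacement, load_path_str)
                break  # assumption: first matching rule wins
    return load_path_str

print(apply_rules("transformer/layer_0/w", [("layer_0", "layer_00")], ["lm_head"]))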
@@ -148,6 +217,19 @@ def replace_with_load_state(
     load_exclude_rules: Optional[list[str]] = None,
     mesh_config: tuple = (1, 1),
 ) -> Any:
+    """
+    Replace the initial state with the loaded state based on renaming and exclusion rules.
+
+    Args:
+        init_state (Any): The initial state.
+        load_state (Any): The loaded state.
+        load_rename_rules (list[tuple[str, str]], optional): The renaming rules. Defaults to None.
+        load_exclude_rules (list[str], optional): The exclusion rules. Defaults to None.
+        mesh_config (tuple, optional): The mesh configuration. Defaults to (1, 1).
+
+    Returns:
+        Any: The replaced state.
+    """
     flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
     flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
     load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
@@ -177,6 +259,8 @@ def replace_with_load_state(
     return jax.tree_util.tree_unflatten(structure_init, replaced)
 
 
+# Checkpoint restoration
 def restore(
     checkpoint_path: str,
     state_shapes: Any,
@@ -186,6 +270,21 @@ def restore(
     state_sharding,
     init_state: Optional[Any] = None,
 ) -> Any:
+    """
+    Restore the state from a checkpoint.
+
+    Args:
+        checkpoint_path (str): The path to the checkpoint directory.
+        state_shapes (Any): The shapes of the state.
+        mesh: The mesh configuration.
+        between_hosts_config: The configuration for communication between hosts.
+        params_only (bool): Whether to restore only the parameters.
+        state_sharding: The sharding specification for the state.
+        init_state (Optional[Any], optional): The initial state. Defaults to None.
+
+    Returns:
+        Any: The restored state.
+    """
     ckpt_path = os.path.join(checkpoint_path, "ckpt-0")
     rank_logger.info("Loading checkpoint at {}".format(ckpt_path))

run.py (32 changed lines)

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
 
 from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
 from runners import InferenceRunner, ModelRunner, sample_from_model
@@ -21,8 +22,8 @@ from runners import InferenceRunner, ModelRunner, sample_from_model
 CKPT_PATH = "./checkpoints/"
 
 
-def main():
-    grok_1_model = LanguageModelConfig(
+def create_grok_1_model() -> LanguageModelConfig:
+    return LanguageModelConfig(
         vocab_size=128 * 1024,
         pad_token=0,
         eos_token=2,
@@ -47,24 +48,37 @@ def main():
             model_axis="model",
         ),
     )
-    inference_runner = InferenceRunner(
+
+
+def create_inference_runner(model: LanguageModelConfig, checkpoint_path: str, tokenizer_path: str) -> InferenceRunner:
+    return InferenceRunner(
         pad_sizes=(1024,),
         runner=ModelRunner(
-            model=grok_1_model,
+            model=model,
             bs_per_device=0.125,
-            checkpoint_path=CKPT_PATH,
+            checkpoint_path=checkpoint_path,
         ),
         name="local",
-        load=CKPT_PATH,
-        tokenizer_path="./tokenizer.model",
+        load=checkpoint_path,
+        tokenizer_path=tokenizer_path,
         local_mesh_config=(1, 8),
         between_hosts_config=(1, 1),
     )
-    inference_runner.initialize()
+
+
+def generate_text(inference_runner: InferenceRunner, prompt: str, max_len: int = 100, temperature: float = 0.01) -> str:
     gen = inference_runner.run()
+    return sample_from_model(gen, prompt, max_len=max_len, temperature=temperature)
+
+
+def main():
+    grok_1_model = create_grok_1_model()
+    inference_runner = create_inference_runner(grok_1_model, CKPT_PATH, "./tokenizer.model")
+    inference_runner.initialize()
 
     inp = "The answer to life the universe and everything is of course"
-    print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
+    output = generate_text(inference_runner, inp)
+    print(f"Output for prompt: {inp}\n{output}")
 
 
 if __name__ == "__main__":
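With this refactor, model construction, runner setup, and sampling live in separate helpers instead of one main() body, so they can be reused or tested independently. A short usage sketch with an arbitrary prompt and temperature, assuming run.py is importable and the checkpoint and tokenizer exist at the paths used above:

from run import create_grok_1_model, create_inference_runner, generate_text

model = create_grok_1_model()
runner = create_inference_runner(model, "./checkpoints/", "./tokenizer.model")
runner.initialize()
print(generate_text(runner, "Write one sentence about sharded checkpoints.", max_len=128, temperature=0.7))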