Mirror of https://github.com/xai-org/grok-1.git
Synced 2025-04-03 18:00:10 +03:00
Merge aa2be03aee into 7050ed204b
This commit is contained in: commit 4416453aec
checkpoint.py (101 lines changed)
@@ -39,8 +39,19 @@ rank_logger = logging.getLogger("rank")
 sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit
 
 
+# Utility functions for file handling and shared memory
+
 @contextlib.contextmanager
 def copy_to_shm(file: str):
+    """
+    Context manager to copy a file to shared memory.
+
+    Args:
+        file (str): The path to the file to be copied.
+
+    Yields:
+        str: The path to the copied file in shared memory.
+    """
     if file.startswith("/dev/shm/"):
         # Nothing to do, the file is already in shared memory.
         yield file
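Usage sketch for the manager documented above (the shard path is hypothetical; the function is assumed to copy the file into /dev/shm and remove the copy when the block exits):

from checkpoint import copy_to_shm

# Hypothetical tensor shard; any on-disk file works.
with copy_to_shm("/data/ckpt-0/tensor00000_000") as shm_path:
    with open(shm_path, "rb") as f:
        payload = f.read()  # reads hit RAM-backed tmpfs instead of disk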
@@ -58,6 +69,15 @@ def copy_to_shm(file: str):
 
 @contextlib.contextmanager
 def copy_from_shm(file: str):
+    """
+    Context manager to copy a file from shared memory.
+
+    Args:
+        file (str): The path to the file to be copied.
+
+    Yields:
+        str: The path to the temporary file in shared memory.
+    """
     tmp_dir = "/dev/shm/"
     fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
     try:
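The write-side counterpart works the same way in reverse; a minimal sketch, assuming the temporary tmpfs file is copied to the (hypothetical) destination path when the block exits:

from checkpoint import copy_from_shm

with copy_from_shm("/data/ckpt-0/new_tensor.pkl") as tmp_path:
    with open(tmp_path, "wb") as f:
        f.write(b"\x00" * 1024)  # stand-in payload; lands in tmpfs first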
@@ -69,19 +89,48 @@ def copy_from_shm(file: str):
 
 
 def fast_unpickle(path: str) -> Any:
+    """
+    Unpickle an object from a file using shared memory for faster loading.
+
+    Args:
+        path (str): The path to the file containing the pickled object.
+
+    Returns:
+        Any: The unpickled object.
+    """
     with copy_to_shm(path) as tmp_path:
         with open(tmp_path, "rb") as f:
             return pickle.load(f)
 
 
 def fast_pickle(obj: Any, path: str) -> None:
+    """
+    Pickle an object to a file using shared memory for faster saving.
+
+    Args:
+        obj (Any): The object to be pickled.
+        path (str): The path to the file where the object will be saved.
+    """
     with copy_from_shm(path) as tmp_path:
         with open(tmp_path, "wb") as f:
             pickle.dump(obj, f)
 
 
+# Tensor loading and path handling
+
 def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
-    """Loads a set of arrays."""
+    """
+    Load a set of arrays from files in parallel using a thread pool.
+
+    Args:
+        shaped_arrays (list): A list of shaped arrays to be loaded.
+        directory (str): The directory containing the tensor files.
+        mesh_config (tuple): The mesh configuration.
+        tensor_indices (list, optional): The indices of the tensors to load. Defaults to None.
+
+    Returns:
+        list: A list of loaded arrays.
+    """
     pool = ThreadPoolExecutor(max_workers=32)
     fs = list()
     num_tensors = 0
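Round-trip sketch for the two pickle helpers above, assuming a Linux host with /dev/shm mounted; the object and path are toy stand-ins:

from checkpoint import fast_pickle, fast_unpickle

obj = {"step": 0, "weights": [1.0, 2.0, 3.0]}  # toy stand-in for a checkpoint shard
fast_pickle(obj, "/tmp/example.pkl")
assert fast_unpickle("/tmp/example.pkl") == obj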
@@ -108,6 +157,15 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
 
 
 def path_tuple_to_string(path: tuple) -> str:
+    """
+    Convert a path tuple to a string representation.
+
+    Args:
+        path (tuple): The path tuple.
+
+    Returns:
+        str: The string representation of the path.
+    """
     pieces = []
     for elem in path:
         if isinstance(elem, jax.tree_util.DictKey):
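The path tuples here come from jax.tree_util.tree_flatten_with_path; a small sketch of the expected behavior, assuming dict keys are joined with "/" (the visible DictKey branch suggests this, but the full body is truncated in this hunk):

import jax
from checkpoint import path_tuple_to_string

flat, _ = jax.tree_util.tree_flatten_with_path({"params": {"mlp": {"w": 1.0}}})
for path, _leaf in flat:
    print(path_tuple_to_string(path))  # expected: params/mlp/w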
@@ -124,6 +182,17 @@ def get_load_path_str(
     load_rename_rules: Optional[list[tuple[str, str]]] = None,
     load_exclude_rules: Optional[list[str]] = None,
 ) -> Optional[str]:
+    """
+    Get the load path string based on the initial path string and renaming/exclusion rules.
+
+    Args:
+        init_path_str (str): The initial path string.
+        load_rename_rules (list[tuple[str, str]], optional): The renaming rules. Defaults to None.
+        load_exclude_rules (list[str], optional): The exclusion rules. Defaults to None.
+
+    Returns:
+        Optional[str]: The load path string if not excluded, otherwise None.
+    """
     # Exclusion
     if load_exclude_rules is not None:
         for search_pattern in load_exclude_rules:
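A semantics sketch for get_load_path_str; the rules are assumed to be regular expressions applied with re.search/re.sub (the search_pattern loop suggests this), and the paths are hypothetical:

from checkpoint import get_load_path_str

# Renaming: map a freshly initialized parameter path onto the checkpoint's name.
get_load_path_str("params/decoder_0/w", load_rename_rules=[(r"decoder_(\d+)", r"layer_\1")])
# expected: "params/layer_0/w"

# Exclusion: a path matching an exclude rule returns None and is skipped on load.
get_load_path_str("params/step", load_exclude_rules=[r"step"])
# expected: None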
@@ -148,6 +217,19 @@ def replace_with_load_state(
     load_exclude_rules: Optional[list[str]] = None,
     mesh_config: tuple = (1, 1),
 ) -> Any:
+    """
+    Replace the initial state with the loaded state based on renaming and exclusion rules.
+
+    Args:
+        init_state (Any): The initial state.
+        load_state (Any): The loaded state.
+        load_rename_rules (list[tuple[str, str]], optional): The renaming rules. Defaults to None.
+        load_exclude_rules (list[str], optional): The exclusion rules. Defaults to None.
+        mesh_config (tuple, optional): The mesh configuration. Defaults to (1, 1).
+
+    Returns:
+        Any: The replaced state.
+    """
     flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
     flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
     load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
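A usage sketch under stated assumptions: each leaf of init_state looks up its (possibly renamed) path string in load_map, and leaves excluded by get_load_path_str are assumed to keep their init values:

from checkpoint import replace_with_load_state

init_state = {"params": {"w": 0.0}, "step": 0}
load_state = {"params": {"w": 1.5}, "step": 100}
merged = replace_with_load_state(
    init_state,
    load_state,
    load_exclude_rules=[r"step"],  # keep init_state's fresh step counter
)
# expected: merged == {"params": {"w": 1.5}, "step": 0}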
@@ -177,6 +259,8 @@ def replace_with_load_state(
     return jax.tree_util.tree_unflatten(structure_init, replaced)
 
 
+# Checkpoint restoration
+
 def restore(
     checkpoint_path: str,
     state_shapes: Any,
@@ -186,6 +270,21 @@ def restore(
     state_sharding,
     init_state: Optional[Any] = None,
 ) -> Any:
+    """
+    Restore the state from a checkpoint.
+
+    Args:
+        checkpoint_path (str): The path to the checkpoint directory.
+        state_shapes (Any): The shapes of the state.
+        mesh: The mesh configuration.
+        between_hosts_config: The configuration for communication between hosts.
+        params_only (bool): Whether to restore only the parameters.
+        state_sharding: The sharding specification for the state.
+        init_state (Optional[Any], optional): The initial state. Defaults to None.
+
+    Returns:
+        Any: The restored state.
+    """
     ckpt_path = os.path.join(checkpoint_path, "ckpt-0")
 
     rank_logger.info("Loading checkpoint at {}".format(ckpt_path))
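A hedged wrapper sketch showing the call shape of restore; the state_shapes, mesh, and state_sharding arguments are produced by the model setup elsewhere in the repo, so they are taken as parameters here rather than constructed:

from checkpoint import restore

def load_params(checkpoint_path, state_shapes, mesh, state_sharding):
    # Forwards to restore() with params_only=True; shapes, mesh, and sharding
    # must match the model that wrote the checkpoint.
    return restore(
        checkpoint_path=checkpoint_path,
        state_shapes=state_shapes,
        mesh=mesh,
        between_hosts_config=(1, 1),
        params_only=True,
        state_sharding=state_sharding,
    )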
run.py (32 lines changed)
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
 
 from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
 from runners import InferenceRunner, ModelRunner, sample_from_model
@@ -21,8 +22,8 @@ from runners import InferenceRunner, ModelRunner, sample_from_model
 CKPT_PATH = "./checkpoints/"
 
 
-def main():
-    grok_1_model = LanguageModelConfig(
+def create_grok_1_model() -> LanguageModelConfig:
+    return LanguageModelConfig(
         vocab_size=128 * 1024,
         pad_token=0,
         eos_token=2,
@@ -47,24 +48,37 @@ def main():
             model_axis="model",
         ),
     )
-    inference_runner = InferenceRunner(
+
+
+def create_inference_runner(model: LanguageModelConfig, checkpoint_path: str, tokenizer_path: str) -> InferenceRunner:
+    return InferenceRunner(
         pad_sizes=(1024,),
         runner=ModelRunner(
-            model=grok_1_model,
+            model=model,
             bs_per_device=0.125,
-            checkpoint_path=CKPT_PATH,
+            checkpoint_path=checkpoint_path,
         ),
         name="local",
-        load=CKPT_PATH,
-        tokenizer_path="./tokenizer.model",
+        load=checkpoint_path,
+        tokenizer_path=tokenizer_path,
         local_mesh_config=(1, 8),
         between_hosts_config=(1, 1),
     )
-    inference_runner.initialize()
+
+
+def generate_text(inference_runner: InferenceRunner, prompt: str, max_len: int = 100, temperature: float = 0.01) -> str:
     gen = inference_runner.run()
+    return sample_from_model(gen, prompt, max_len=max_len, temperature=temperature)
+
+
+def main():
+    grok_1_model = create_grok_1_model()
+    inference_runner = create_inference_runner(grok_1_model, CKPT_PATH, "./tokenizer.model")
+    inference_runner.initialize()
 
     inp = "The answer to life the universe and everything is of course"
-    print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
+    output = generate_text(inference_runner, inp)
+    print(f"Output for prompt: {inp}\n{output}")
 
 
 if __name__ == "__main__":
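The refactor makes the pieces reusable outside main(); a sketch with an illustrative prompt and sampling settings:

from run import CKPT_PATH, create_grok_1_model, create_inference_runner, generate_text

model = create_grok_1_model()
runner = create_inference_runner(model, CKPT_PATH, "./tokenizer.model")
runner.initialize()
print(generate_text(runner, "Hello, Grok.", max_len=64, temperature=0.7))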