Added error handling in run.py

This commit is contained in:
Shephin philip 2024-03-18 11:46:13 +05:30
parent e50578b5f5
commit 256c3468f4

95
run.py
View File

@@ -22,49 +22,60 @@ CKPT_PATH = "./checkpoints/"
def main():
    """Build the Grok-1 model config, initialize the inference runner, and sample one prompt.

    All work happens inside a single try block so that setup/inference failures
    are logged rather than crashing with a raw traceback. Errors are reported
    via ``logging.error`` and the function returns normally afterwards.
    """
    try:
        # Model hyperparameters for Grok-1 (values taken verbatim from the checkpoint release).
        grok_1_model = LanguageModelConfig(
            vocab_size=128 * 1024,
            pad_token=0,
            eos_token=2,
            sequence_len=8192,
            embedding_init_scale=1.0,
            output_multiplier_scale=0.5773502691896257,
            embedding_multiplier_scale=78.38367176906169,
            model=TransformerConfig(
                emb_size=48 * 128,
                widening_factor=8,
                key_size=128,
                num_q_heads=48,
                num_kv_heads=8,
                num_layers=64,
                attn_output_multiplier=0.08838834764831845,
                shard_activations=True,
                # MoE.
                num_experts=8,
                num_selected_experts=2,
                # Activation sharding.
                data_axis="data",
                model_axis="model",
            ),
        )

        # Runner wiring: loads weights from CKPT_PATH and shards over a 1x8 local mesh.
        inference_runner = InferenceRunner(
            pad_sizes=(1024,),
            runner=ModelRunner(
                model=grok_1_model,
                bs_per_device=0.125,
                checkpoint_path=CKPT_PATH,
            ),
            name="local",
            load=CKPT_PATH,
            tokenizer_path="./tokenizer.model",
            local_mesh_config=(1, 8),
            between_hosts_config=(1, 1),
        )
        inference_runner.initialize()
        gen = inference_runner.run()

        # Sample a single completion; low temperature keeps output near-deterministic.
        inp = "The answer to life the universe and everything is of course"
        print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
    except FileNotFoundError as e:
        # Typically a missing checkpoint or tokenizer file.
        logging.error(f"File not found: {e}")
    except ValueError as e:
        logging.error(f"Value error: {e}")
    except Exception as e:
        # Top-level boundary: log anything unexpected instead of crashing.
        logging.error(f"An unexpected error occurred: {e}")
if __name__ == "__main__": if __name__ == "__main__":