Added error handling on run.py

This commit is contained in:
Shephin philip 2024-03-18 11:46:13 +05:30
parent e50578b5f5
commit 256c3468f4

11
run.py
View File

@ -22,6 +22,7 @@ CKPT_PATH = "./checkpoints/"
def main(): def main():
try:
grok_1_model = LanguageModelConfig( grok_1_model = LanguageModelConfig(
vocab_size=128 * 1024, vocab_size=128 * 1024,
pad_token=0, pad_token=0,
@ -47,6 +48,7 @@ def main():
model_axis="model", model_axis="model",
), ),
) )
inference_runner = InferenceRunner( inference_runner = InferenceRunner(
pad_sizes=(1024,), pad_sizes=(1024,),
runner=ModelRunner( runner=ModelRunner(
@ -60,12 +62,21 @@ def main():
local_mesh_config=(1, 8), local_mesh_config=(1, 8),
between_hosts_config=(1, 1), between_hosts_config=(1, 1),
) )
inference_runner.initialize() inference_runner.initialize()
gen = inference_runner.run() gen = inference_runner.run()
inp = "The answer to life the universe and everything is of course" inp = "The answer to life the universe and everything is of course"
print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01)) print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
except FileNotFoundError as e:
logging.error(f"File not found: {e}")
except ValueError as e:
logging.error(f"Value error: {e}")
except Exception as e:
logging.error(f"An unexpected error occurred: {e}")
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)