run.py: re-formatting it to be more readable

It looks like there were some formatting issues in the code. I've taken the liberty of re-formatting it to be more readable.
This commit is contained in:
Michael G. Inso 2024-03-26 12:24:32 +03:00 committed by GitHub
parent f57a3e2619
commit 6ed2d78bea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

103
run.py
View File

@@ -1,7 +1,7 @@
 # Copyright 2024 X.AI Corp.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
@ -30,54 +30,59 @@ def validate_checkpoint(path, expected_hash):
def main(): def main():
# Validate checkpoint integrity # Validate checkpoint integrity
validate_checkpoint(CKPT_PATH, CKPT_HASH) validate_checkpoint(CKPT_PATH, CKPT_HASH)
grok_1_model = LanguageModelConfig(
vocab_size=128 * 1024, grok_1_model = LanguageModelConfig(
pad_token=0, vocab_size=128 * 1024,
eos_token=2, pad_token=0,
sequence_len=8192, eos_token=2,
embedding_init_scale=1.0, sequence_len=8192,
output_multiplier_scale=0.5773502691896257, embedding_init_scale=1.0,
embedding_multiplier_scale=78.38367176906169, output_multiplier_scale=0.5773502691896257,
model=TransformerConfig( embedding_multiplier_scale=78.38367176906169,
emb_size=48 * 128, model=TransformerConfig(
widening_factor=8, emb_size=48 * 128,
key_size=128, widening_factor=8,
num_q_heads=48, key_size=128,
num_kv_heads=8, num_q_heads=48,
num_layers=64, num_kv_heads=8,
attn_output_multiplier=0.08838834764831845, num_layers=64,
shard_activations=True, attn_output_multiplier=0.08838834764831845,
# MoE. shard_activations=True,
num_experts=8, # MoE.
num_selected_experts=2, num_experts=8,
# Activation sharding. num_selected_experts=2,
data_axis="data", # Activation sharding.
model_axis="model", data_axis="data",
), model_axis="model",
) ),
inference_runner = InferenceRunner( )
pad_sizes=(1024,),
runner=ModelRunner( inference_runner = InferenceRunner(
model=grok_1_model, pad_sizes=(1024,),
bs_per_device=0.125, runner=ModelRunner(
checkpoint_path=CKPT_PATH, model=grok_1_model,
# Limit inference rate bs_per_device=0.125,
inference_runner.rate_limit = 100 checkpoint_path=CKPT_PATH,
), # Limit inference rate
inference_runner.rate_limit = 100
name="local", ),
load=CKPT_PATH,
tokenizer_path="./tokenizer.model", name="local",
local_mesh_config=(1, 8), load=CKPT_PATH,
between_hosts_config=(1, 1), tokenizer_path="./tokenizer.model",
) local_mesh_config=(1, 8),
inference_runner.initialize() between_hosts_config=(1, 1),
gen = inference_runner.run() )
inference_runner.initialize()
gen = inference_runner.run()
inp = "The answer to life the universe and everything is of course"
print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
inp = "The answer to life the universe and everything is of course"
print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
# Add authentication # Add authentication
@app.route("/inference") @app.route("/inference")
@auth.login_required @auth.login_required
@@ -87,7 +92,7 @@ def main():
        gen = inference_runner.run()
        # Rest of inference code

 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
     main()