diff --git a/run.py b/run.py
index 6f12332..e6107ee 100644
--- a/run.py
+++ b/run.py
@@ -1,7 +1,7 @@
 # Copyright 2024 X.AI Corp.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License. 
+# you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
@@ -30,50 +30,54 @@ def validate_checkpoint(path, expected_hash):
 
 def main():
-    # Validate checkpoint integrity 
+    # Validate checkpoint integrity
     validate_checkpoint(CKPT_PATH, CKPT_HASH)
-    grok_1_model = LanguageModelConfig(
-        vocab_size=128 * 1024,
-        pad_token=0,
-        eos_token=2,
-        sequence_len=8192,
-        embedding_init_scale=1.0,
-        output_multiplier_scale=0.5773502691896257,
-        embedding_multiplier_scale=78.38367176906169,
-        model=TransformerConfig(
-            emb_size=48 * 128,
-            widening_factor=8,
-            key_size=128,
-            num_q_heads=48,
-            num_kv_heads=8,
-            num_layers=64,
-            attn_output_multiplier=0.08838834764831845,
-            shard_activations=True,
-            # MoE.
-            num_experts=8,
-            num_selected_experts=2,
-            # Activation sharding.
-            data_axis="data",
-            model_axis="model",
-        ),
-    )
-    inference_runner = InferenceRunner(
-        pad_sizes=(1024,),
-        runner=ModelRunner(
-            model=grok_1_model,
-            bs_per_device=0.125,
-            checkpoint_path=CKPT_PATH,
-            # Limit inference rate
-            inference_runner.rate_limit = 100
-        ),
-
-        name="local",
-        load=CKPT_PATH,
-        tokenizer_path="./tokenizer.model",
-        local_mesh_config=(1, 8),
-        between_hosts_config=(1, 1),
-    )
-    inference_runner.initialize()
-    gen = inference_runner.run()
-
-    inp = "The answer to life the universe and everything is of course"
-    print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
+
+    grok_1_model = LanguageModelConfig(
+        vocab_size=128 * 1024,
+        pad_token=0,
+        eos_token=2,
+        sequence_len=8192,
+        embedding_init_scale=1.0,
+        output_multiplier_scale=0.5773502691896257,
+        embedding_multiplier_scale=78.38367176906169,
+        model=TransformerConfig(
+            emb_size=48 * 128,
+            widening_factor=8,
+            key_size=128,
+            num_q_heads=48,
+            num_kv_heads=8,
+            num_layers=64,
+            attn_output_multiplier=0.08838834764831845,
+            shard_activations=True,
+            # MoE.
+            num_experts=8,
+            num_selected_experts=2,
+            # Activation sharding.
+            data_axis="data",
+            model_axis="model",
+        ),
+    )
+
+    inference_runner = InferenceRunner(
+        pad_sizes=(1024,),
+        runner=ModelRunner(
+            model=grok_1_model,
+            bs_per_device=0.125,
+            checkpoint_path=CKPT_PATH,
+        ),
+        name="local",
+        load=CKPT_PATH,
+        tokenizer_path="./tokenizer.model",
+        local_mesh_config=(1, 8),
+        between_hosts_config=(1, 1),
+    )
+
+    inference_runner.initialize()
+
+    # Limit inference rate
+    inference_runner.rate_limit = 100
+
+    gen = inference_runner.run()
+
+    inp = "The answer to life the universe and everything is of course"
+    print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
@@ -81,4 +85,4 @@ def main():
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
-    main()
+    main()
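
The diff calls validate_checkpoint(CKPT_PATH, CKPT_HASH), but the body of validate_checkpoint sits outside the hunks shown. A minimal sketch of what such an integrity check could look like, assuming CKPT_HASH is a hex SHA-256 digest computed over the checkpoint directory's files in sorted order; the helper body and the placeholder digest below are illustrative assumptions, not the actual implementation:

    import hashlib
    import os
    import sys

    CKPT_PATH = "./checkpoints/"
    # Hypothetical digest; the real value would be published alongside the weights.
    CKPT_HASH = "0" * 64

    def validate_checkpoint(path, expected_hash):
        """Hash every file under `path` in sorted order and compare to `expected_hash`."""
        digest = hashlib.sha256()
        for root, _, files in sorted(os.walk(path)):
            for name in sorted(files):
                with open(os.path.join(root, name), "rb") as f:
                    # Stream in chunks so multi-gigabyte shards never load into memory.
                    for chunk in iter(lambda: f.read(1 << 20), b""):
                        digest.update(chunk)
        if digest.hexdigest() != expected_hash:
            sys.exit(f"Checkpoint at {path} failed integrity check")

Walking the tree in sorted order keeps the digest stable across filesystems, which is what makes a single published hash usable for the whole checkpoint directory.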
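The diff assigns inference_runner.rate_limit = 100 after construction; whether InferenceRunner consumes that attribute is not visible in the hunks. If it does not, a client-side throttle around the sampling call is one way to get the same effect. A sketch, assuming "100 calls per minute" as the intended meaning and sample_from_model from runners as in run.py; the RateLimiter class is an illustrative assumption:

    import time

    from runners import sample_from_model

    class RateLimiter:
        """Allow at most `max_calls` calls per sliding window of `period` seconds."""

        def __init__(self, max_calls, period=60.0):
            self.max_calls = max_calls
            self.period = period
            self.calls = []

        def wait(self):
            now = time.monotonic()
            # Drop timestamps that have fallen out of the window.
            self.calls = [t for t in self.calls if now - t < self.period]
            if len(self.calls) >= self.max_calls:
                # Sleep until the oldest call in the window expires.
                time.sleep(self.period - (now - self.calls[0]))
            self.calls.append(time.monotonic())

    limiter = RateLimiter(max_calls=100)

    def rate_limited_sample(gen, prompt, **kwargs):
        limiter.wait()
        return sample_from_model(gen, prompt, **kwargs)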
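The change also sketches adding authentication to an HTTP /inference route (the @app.route("/inference") / @auth.login_required decorators), but run.py itself defines no web server. A minimal self-contained sketch of that idea, assuming Flask and Flask-HTTPAuth; the app and auth objects, the token store, and the request schema are illustrative assumptions:

    from flask import Flask, jsonify, request
    from flask_httpauth import HTTPTokenAuth

    from runners import sample_from_model

    app = Flask(__name__)
    auth = HTTPTokenAuth(scheme="Bearer")

    # Hypothetical token store; a real deployment would use a secrets manager.
    TOKENS = {"example-token": "local-user"}

    gen = None  # populated at startup with inference_runner.run(), as in main()

    @auth.verify_token
    def verify_token(token):
        return TOKENS.get(token)

    # Add authentication
    @app.route("/inference", methods=["POST"])
    @auth.login_required
    def inference():
        if gen is None:
            return jsonify({"error": "model not initialized"}), 503
        inp = request.get_json(force=True).get("prompt", "")
        output = sample_from_model(gen, inp, max_len=100, temperature=0.01)
        return jsonify({"output": output})

Keeping the auth check in a decorator means the request is rejected before any tokens are spent on sampling, which matters when each call ties up an 8-GPU mesh.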