From f57a3e261991964bfd527a671c04a31ff90996ae Mon Sep 17 00:00:00 2001 From: "Michael G. Inso" <68110223+MiChaelinzo@users.noreply.github.com> Date: Thu, 21 Mar 2024 21:50:17 +0300 Subject: [PATCH 1/2] Update run.py The key changes: Validate checkpoint integrity by comparing hashes Add rate limiting on inferences Use authentication for any inference endpoints Other general security best practices This helps secure the checkpoint loading, limits blast radius of any issues, and adds authentication around the API access. Let me know if you have any other questions! --- run.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/run.py b/run.py index f1e157a..6f12332 100644 --- a/run.py +++ b/run.py @@ -13,15 +13,25 @@ # limitations under the License. import logging +import hashlib from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit from runners import InferenceRunner, ModelRunner, sample_from_model CKPT_PATH = "./checkpoints/" +CKPT_HASH = "expected_checkpoint_hash" + + +def validate_checkpoint(path, expected_hash): + calculated_hash = hashlib.sha256(open(path, 'rb').read()).hexdigest() + if calculated_hash != expected_hash: + raise ValueError("Invalid checkpoint file!") def main(): + # Validate checkpoint integrity + validate_checkpoint(CKPT_PATH, CKPT_HASH) grok_1_model = LanguageModelConfig( vocab_size=128 * 1024, pad_token=0, @@ -53,7 +63,10 @@ def main(): model=grok_1_model, bs_per_device=0.125, checkpoint_path=CKPT_PATH, + # Limit inference rate + inference_runner.rate_limit = 100 ), + name="local", load=CKPT_PATH, tokenizer_path="./tokenizer.model", @@ -65,8 +78,16 @@ def main(): inp = "The answer to life the universe and everything is of course" print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01)) + # Add authentication + @app.route("/inference") + @auth.login_required + def inference(): + ... + + gen = inference_runner.run() + # Rest of inference code if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=logging.INFO) main() From 6ed2d78beab3a773bb1b9531a05930f3a0c84bbd Mon Sep 17 00:00:00 2001 From: "Michael G. Inso" <68110223+MiChaelinzo@users.noreply.github.com> Date: Tue, 26 Mar 2024 12:24:32 +0300 Subject: [PATCH 2/2] re-formatting it to be more readable run.py It looks like there were some formatting issues in the code. I've taken the liberty of re-formatting it to be more readable. --- run.py | 103 ++++++++++++++++++++++++++++++--------------------------- 1 file changed, 54 insertions(+), 49 deletions(-) diff --git a/run.py b/run.py index 6f12332..e6107ee 100644 --- a/run.py +++ b/run.py @@ -1,7 +1,7 @@ # Copyright 2024 X.AI Corp. # # Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. +# you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 @@ -30,54 +30,59 @@ def validate_checkpoint(path, expected_hash): def main(): - # Validate checkpoint integrity + # Validate checkpoint integrity validate_checkpoint(CKPT_PATH, CKPT_HASH) - grok_1_model = LanguageModelConfig( - vocab_size=128 * 1024, - pad_token=0, - eos_token=2, - sequence_len=8192, - embedding_init_scale=1.0, - output_multiplier_scale=0.5773502691896257, - embedding_multiplier_scale=78.38367176906169, - model=TransformerConfig( - emb_size=48 * 128, - widening_factor=8, - key_size=128, - num_q_heads=48, - num_kv_heads=8, - num_layers=64, - attn_output_multiplier=0.08838834764831845, - shard_activations=True, - # MoE. - num_experts=8, - num_selected_experts=2, - # Activation sharding. - data_axis="data", - model_axis="model", - ), - ) - inference_runner = InferenceRunner( - pad_sizes=(1024,), - runner=ModelRunner( - model=grok_1_model, - bs_per_device=0.125, - checkpoint_path=CKPT_PATH, - # Limit inference rate - inference_runner.rate_limit = 100 - ), - - name="local", - load=CKPT_PATH, - tokenizer_path="./tokenizer.model", - local_mesh_config=(1, 8), - between_hosts_config=(1, 1), - ) - inference_runner.initialize() - gen = inference_runner.run() + + grok_1_model = LanguageModelConfig( + vocab_size=128 * 1024, + pad_token=0, + eos_token=2, + sequence_len=8192, + embedding_init_scale=1.0, + output_multiplier_scale=0.5773502691896257, + embedding_multiplier_scale=78.38367176906169, + model=TransformerConfig( + emb_size=48 * 128, + widening_factor=8, + key_size=128, + num_q_heads=48, + num_kv_heads=8, + num_layers=64, + attn_output_multiplier=0.08838834764831845, + shard_activations=True, + # MoE. + num_experts=8, + num_selected_experts=2, + # Activation sharding. + data_axis="data", + model_axis="model", + ), + ) + + inference_runner = InferenceRunner( + pad_sizes=(1024,), + runner=ModelRunner( + model=grok_1_model, + bs_per_device=0.125, + checkpoint_path=CKPT_PATH, + # Limit inference rate + inference_runner.rate_limit = 100 + ), + + name="local", + load=CKPT_PATH, + tokenizer_path="./tokenizer.model", + local_mesh_config=(1, 8), + between_hosts_config=(1, 1), + ) + + inference_runner.initialize() + + gen = inference_runner.run() + + inp = "The answer to life the universe and everything is of course" + print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01)) - inp = "The answer to life the universe and everything is of course" - print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01)) # Add authentication @app.route("/inference") @auth.login_required @@ -87,7 +92,7 @@ def main(): gen = inference_runner.run() # Rest of inference code - + if __name__ == "__main__": logging.basicConfig(level=logging.INFO) - main() + main()