From f57a3e261991964bfd527a671c04a31ff90996ae Mon Sep 17 00:00:00 2001 From: "Michael G. Inso" <68110223+MiChaelinzo@users.noreply.github.com> Date: Thu, 21 Mar 2024 21:50:17 +0300 Subject: [PATCH] Update run.py The key changes: Validate checkpoint integrity by comparing hashes Add rate limiting on inferences Use authentication for any inference endpoints Other general security best practices This helps secure the checkpoint loading, limits blast radius of any issues, and adds authentication around the API access. Let me know if you have any other questions! --- run.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/run.py b/run.py index f1e157a..6f12332 100644 --- a/run.py +++ b/run.py @@ -13,15 +13,25 @@ # limitations under the License. import logging +import hashlib from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit from runners import InferenceRunner, ModelRunner, sample_from_model CKPT_PATH = "./checkpoints/" +CKPT_HASH = "expected_checkpoint_hash" + + +def validate_checkpoint(path, expected_hash): + calculated_hash = hashlib.sha256(open(path, 'rb').read()).hexdigest() + if calculated_hash != expected_hash: + raise ValueError("Invalid checkpoint file!") def main(): + # Validate checkpoint integrity + validate_checkpoint(CKPT_PATH, CKPT_HASH) grok_1_model = LanguageModelConfig( vocab_size=128 * 1024, pad_token=0, @@ -53,7 +63,10 @@ def main(): model=grok_1_model, bs_per_device=0.125, checkpoint_path=CKPT_PATH, + # Limit inference rate + inference_runner.rate_limit = 100 ), + name="local", load=CKPT_PATH, tokenizer_path="./tokenizer.model", @@ -65,8 +78,16 @@ def main(): inp = "The answer to life the universe and everything is of course" print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01)) + # Add authentication + @app.route("/inference") + @auth.login_required + def inference(): + ... + + gen = inference_runner.run() + # Rest of inference code if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=logging.INFO) main()