diff --git a/run.py b/run.py index f1e157a..6f12332 100644 --- a/run.py +++ b/run.py @@ -13,15 +13,25 @@ # limitations under the License. import logging +import hashlib from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit from runners import InferenceRunner, ModelRunner, sample_from_model CKPT_PATH = "./checkpoints/" +CKPT_HASH = "expected_checkpoint_hash" + + +def validate_checkpoint(path, expected_hash): + calculated_hash = hashlib.sha256(open(path, 'rb').read()).hexdigest() + if calculated_hash != expected_hash: + raise ValueError("Invalid checkpoint file!") def main(): + # Validate checkpoint integrity + validate_checkpoint(CKPT_PATH, CKPT_HASH) grok_1_model = LanguageModelConfig( vocab_size=128 * 1024, pad_token=0, @@ -53,7 +63,10 @@ def main(): model=grok_1_model, bs_per_device=0.125, checkpoint_path=CKPT_PATH, + # Limit inference rate + inference_runner.rate_limit = 100 ), + name="local", load=CKPT_PATH, tokenizer_path="./tokenizer.model", @@ -65,8 +78,16 @@ def main(): inp = "The answer to life the universe and everything is of course" print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01)) + # Add authentication + @app.route("/inference") + @auth.login_required + def inference(): + ... + + gen = inference_runner.run() + # Rest of inference code if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=logging.INFO) main()