This commit is contained in:
Michael G. Inso 2024-03-21 22:09:52 +03:00 committed by GitHub
commit 7a87bc2018
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

23
run.py
View File

@ -13,15 +13,25 @@
# limitations under the License.
import logging
import hashlib
from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model
CKPT_PATH = "./checkpoints/"
CKPT_HASH = "expected_checkpoint_hash"
def validate_checkpoint(path, expected_hash):
calculated_hash = hashlib.sha256(open(path, 'rb').read()).hexdigest()
if calculated_hash != expected_hash:
raise ValueError("Invalid checkpoint file!")
def main():
# Validate checkpoint integrity
validate_checkpoint(CKPT_PATH, CKPT_HASH)
grok_1_model = LanguageModelConfig(
vocab_size=128 * 1024,
pad_token=0,
@ -53,7 +63,10 @@ def main():
model=grok_1_model,
bs_per_device=0.125,
checkpoint_path=CKPT_PATH,
# Limit inference rate
inference_runner.rate_limit = 100
),
name="local",
load=CKPT_PATH,
tokenizer_path="./tokenizer.model",
@ -65,8 +78,16 @@ def main():
inp = "The answer to life the universe and everything is of course"
print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
# Add authentication
@app.route("/inference")
@auth.login_required
def inference():
...
gen = inference_runner.run()
# Rest of inference code
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.INFO)
main()