This commit is contained in:
Michael G. Inso 2024-03-26 17:33:57 +08:00 committed by GitHub
commit d129df04a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

26
run.py
View File

@ -13,15 +13,26 @@
# limitations under the License.
import logging
import hashlib
from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model
CKPT_PATH = "./checkpoints/"
CKPT_HASH = "expected_checkpoint_hash"
def validate_checkpoint(path, expected_hash):
calculated_hash = hashlib.sha256(open(path, 'rb').read()).hexdigest()
if calculated_hash != expected_hash:
raise ValueError("Invalid checkpoint file!")
def main():
# Validate checkpoint integrity
validate_checkpoint(CKPT_PATH, CKPT_HASH)
grok_1_model = LanguageModelConfig(
vocab_size=128 * 1024,
pad_token=0,
@ -47,25 +58,40 @@ def main():
model_axis="model",
),
)
inference_runner = InferenceRunner(
pad_sizes=(1024,),
runner=ModelRunner(
model=grok_1_model,
bs_per_device=0.125,
checkpoint_path=CKPT_PATH,
# Limit inference rate
inference_runner.rate_limit = 100
),
name="local",
load=CKPT_PATH,
tokenizer_path="./tokenizer.model",
local_mesh_config=(1, 8),
between_hosts_config=(1, 1),
)
inference_runner.initialize()
gen = inference_runner.run()
inp = "The answer to life the universe and everything is of course"
print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
# Add authentication
@app.route("/inference")
@auth.login_required
def inference():
...
gen = inference_runner.run()
# Rest of inference code
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)