Update run.py

This commit is contained in:
San 2024-04-04 11:23:23 +03:00
parent 7050ed204b
commit 7d46bbbdcc

32
run.py
View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional
from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model from runners import InferenceRunner, ModelRunner, sample_from_model
@@ -21,8 +22,8 @@ from runners import InferenceRunner, ModelRunner, sample_from_model
CKPT_PATH = "./checkpoints/" CKPT_PATH = "./checkpoints/"
def main(): def create_grok_1_model() -> LanguageModelConfig:
grok_1_model = LanguageModelConfig( return LanguageModelConfig(
vocab_size=128 * 1024, vocab_size=128 * 1024,
pad_token=0, pad_token=0,
eos_token=2, eos_token=2,
@@ -47,24 +48,37 @@ def main():
model_axis="model", model_axis="model",
), ),
) )
inference_runner = InferenceRunner(
def create_inference_runner(model: LanguageModelConfig, checkpoint_path: str, tokenizer_path: str) -> InferenceRunner:
    """Construct a single-host InferenceRunner for the given model config.

    Args:
        model: The language-model configuration to serve.
        checkpoint_path: Directory holding the weights; also used as the load target.
        tokenizer_path: Path to the SentencePiece tokenizer model file.

    Returns:
        An InferenceRunner wired for a local 1x8 device mesh (not yet initialized).
    """
    # Build the underlying model runner first so the wiring reads top-down.
    model_runner = ModelRunner(
        model=model,
        bs_per_device=0.125,
        checkpoint_path=checkpoint_path,
    )
    return InferenceRunner(
        pad_sizes=(1024,),
        runner=model_runner,
        name="local",
        load=checkpoint_path,
        tokenizer_path=tokenizer_path,
        local_mesh_config=(1, 8),
        between_hosts_config=(1, 1),
    )
inference_runner.initialize()
def generate_text(inference_runner: InferenceRunner, prompt: str, max_len: int = 100, temperature: float = 0.01) -> str:
    """Sample a completion for *prompt* from an already-initialized runner.

    Args:
        inference_runner: Runner whose ``initialize()`` has been called.
        prompt: Text to condition generation on.
        max_len: Maximum number of tokens to generate.
        temperature: Sampling temperature (near-zero is close to greedy).

    Returns:
        The sampled continuation as a string.
    """
    generator = inference_runner.run()
    return sample_from_model(generator, prompt, max_len=max_len, temperature=temperature)
def main():
    """Build the model config, initialize inference, and print one sample."""
    grok_1_model = create_grok_1_model()
    inference_runner = create_inference_runner(grok_1_model, CKPT_PATH, "./tokenizer.model")
    inference_runner.initialize()
    prompt = "The answer to life the universe and everything is of course"
    completion = generate_text(inference_runner, prompt)
    print(f"Output for prompt: {prompt}\n{completion}")
if __name__ == "__main__": if __name__ == "__main__":