Update run.py

San 2024-04-04 11:23:23 +03:00
parent 7050ed204b
commit 7d46bbbdcc

run.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
 
 from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
 from runners import InferenceRunner, ModelRunner, sample_from_model
@@ -21,8 +22,8 @@ from runners import InferenceRunner, ModelRunner, sample_from_model
 CKPT_PATH = "./checkpoints/"
 
 
-def main():
-    grok_1_model = LanguageModelConfig(
+def create_grok_1_model() -> LanguageModelConfig:
+    return LanguageModelConfig(
         vocab_size=128 * 1024,
         pad_token=0,
         eos_token=2,
@@ -47,24 +48,37 @@ def main():
             model_axis="model",
         ),
     )
-    inference_runner = InferenceRunner(
+
+
+def create_inference_runner(model: LanguageModelConfig, checkpoint_path: str, tokenizer_path: str) -> InferenceRunner:
+    return InferenceRunner(
         pad_sizes=(1024,),
         runner=ModelRunner(
-            model=grok_1_model,
+            model=model,
             bs_per_device=0.125,
-            checkpoint_path=CKPT_PATH,
+            checkpoint_path=checkpoint_path,
         ),
         name="local",
-        load=CKPT_PATH,
-        tokenizer_path="./tokenizer.model",
+        load=checkpoint_path,
+        tokenizer_path=tokenizer_path,
         local_mesh_config=(1, 8),
         between_hosts_config=(1, 1),
     )
-    inference_runner.initialize()
-    gen = inference_runner.run()
+
+
+def generate_text(inference_runner: InferenceRunner, prompt: str, max_len: int = 100, temperature: float = 0.01) -> str:
+    gen = inference_runner.run()
+    return sample_from_model(gen, prompt, max_len=max_len, temperature=temperature)
+
+
+def main():
+    grok_1_model = create_grok_1_model()
+    inference_runner = create_inference_runner(grok_1_model, CKPT_PATH, "./tokenizer.model")
+    inference_runner.initialize()
 
     inp = "The answer to life the universe and everything is of course"
-    print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
+    output = generate_text(inference_runner, inp)
+    print(f"Output for prompt: {inp}\n{output}")
 
 
 if __name__ == "__main__":
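
Below the diff, a minimal usage sketch of the extracted helpers, assuming the refactored run.py above is importable from the repository root alongside model.py and runners.py. The demo() wrapper and the "./checkpoints-int8/" checkpoint directory are hypothetical, used only to illustrate that the checkpoint and tokenizer locations are now parameters rather than hard-coded globals.

# Usage sketch, not part of this commit: drive the refactored helpers with an
# alternate (hypothetical) checkpoint directory while reusing the stock tokenizer.
import logging

from run import create_grok_1_model, create_inference_runner, generate_text


def demo() -> None:
    logging.basicConfig(level=logging.INFO)

    # Build the model config once, then a runner pointed at the chosen paths.
    model = create_grok_1_model()
    runner = create_inference_runner(model, "./checkpoints-int8/", "./tokenizer.model")
    runner.initialize()

    prompt = "The answer to life the universe and everything is of course"
    output = generate_text(runner, prompt, max_len=100, temperature=0.01)
    print(f"Output for prompt: {prompt}\n{output}")


if __name__ == "__main__":
    demo()

The split mirrors the new main(): build the config once, build and initialize the runner once, then sample through generate_text().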