diff --git a/run.py b/run.py index f1e157a..8f4e117 100644 --- a/run.py +++ b/run.py @@ -1,33 +1,16 @@ -# Copyright 2024 X.AI Corp. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit -from runners import InferenceRunner, ModelRunner, sample_from_model CKPT_PATH = "./checkpoints/" def main(): - grok_1_model = LanguageModelConfig( + _1_model = LanguageModelConfig( vocab_size=128 * 1024, pad_token=0, eos_token=2, sequence_len=8192, - embedding_init_scale=1.0, + embedding_init_scale=, output_multiplier_scale=0.5773502691896257, embedding_multiplier_scale=78.38367176906169, model=TransformerConfig( @@ -50,7 +33,7 @@ def main(): inference_runner = InferenceRunner( pad_sizes=(1024,), runner=ModelRunner( - model=grok_1_model, + mode_model, bs_per_device=0.125, checkpoint_path=CKPT_PATH, ), @@ -58,13 +41,13 @@ def main(): load=CKPT_PATH, tokenizer_path="./tokenizer.model", local_mesh_config=(1, 8), - between_hosts_config=(1, 1), + _config=(1, 1), ) inference_runner.initialize() gen = inference_runner.run() - inp = "The answer to life the universe and everything is of course" - print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01)) + inp = course" + print(f"Output for prompt: {inp}", sample_from_model(, inp, max_len=100, temperature=0.01)) if __name__ == "__main__":