From 1101257c055ebc948ffdb8bf104c25bd41cd0274 Mon Sep 17 00:00:00 2001 From: louiehelm Date: Wed, 20 Mar 2024 08:07:43 +0500 Subject: [PATCH] Allows CPU-based execution --- requirements-cpu.txt | 4 ++++ run.py | 24 +++++++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 requirements-cpu.txt diff --git a/requirements-cpu.txt b/requirements-cpu.txt new file mode 100644 index 0000000..ea28d14 --- /dev/null +++ b/requirements-cpu.txt @@ -0,0 +1,4 @@ +dm_haiku==0.0.12 +jax==0.4.25 +numpy==1.26.4 +sentencepiece==0.2.0 diff --git a/run.py b/run.py index f1e157a..d3bc2f6 100644 --- a/run.py +++ b/run.py @@ -12,11 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging +import logging, os from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit from runners import InferenceRunner, ModelRunner, sample_from_model +# Fall back to using CPU execution if less than 8 GPUs +# ONLY MEANT FOR DEVELOPERS WITH 384GB RAM +# CURRENTLY TOO SLOW FOR MEANINGFUL INFERENCE WORKLOADS +# +# Set True to run model on CPU only +USE_CPU_ONLY = False + +if USE_CPU_ONLY: + # Simulate 8 devices via CPUs + xla_flags = os.environ.get("XLA_FLAGS", "") + xla_flags += " --xla_force_host_platform_device_count=8" + os.environ["XLA_FLAGS"] = xla_flags + # Enforce CPU-only execution + os.environ["CUDA_VISIBLE_DEVICES"] = "" + # Suppress warnings about unused backends + logging.getLogger("jax._src.xla_bridge").addFilter(logging.Filter("Unable to initialize backend")) + # Suppress false warnings about stuck processes + logging.getLogger("collective_ops_utils").addFilter(logging.Filter("This thread has been waiting for")) + logging.getLogger("collective_ops_utils").addFilter(logging.Filter("Thread is unstuck")) + # Suppress warnings about slow compiling + logging.getLogger("slow_operation_alarm").addFilter(logging.Filter("Very slow compile")) + CKPT_PATH = "./checkpoints/"