Allows CPU-based execution

This commit is contained in:
louiehelm 2024-03-20 08:07:43 +05:00
parent 7050ed204b
commit 1101257c05
2 changed files with 27 additions and 1 deletions

4
requirements-cpu.txt Normal file
View File

@@ -0,0 +1,4 @@
dm_haiku==0.0.12
jax==0.4.25
numpy==1.26.4
sentencepiece==0.2.0

24
run.py
View File

@@ -12,11 +12,33 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging, os
from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model from runners import InferenceRunner, ModelRunner, sample_from_model
# Fall back to using CPU execution if less than 8 GPUs
# ONLY MEANT FOR DEVELOPERS WITH 384GB RAM
# CURRENTLY TOO SLOW FOR MEANINGFUL INFERENCE WORKLOADS
#
# Set True to run model on CPU only
USE_CPU_ONLY = False
if USE_CPU_ONLY:
# Simulate 8 devices via CPUs
xla_flags = os.environ.get("XLA_FLAGS", "")
xla_flags += " --xla_force_host_platform_device_count=8"
os.environ["XLA_FLAGS"] = xla_flags
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Suppress warnings about unused backends
logging.getLogger("jax._src.xla_bridge").addFilter(logging.Filter("Unable to initialize backend"))
# Suppress false warnings about stuck processes
logging.getLogger("collective_ops_utils").addFilter(logging.Filter("This thread has been waiting for"))
logging.getLogger("collective_ops_utils").addFilter(logging.Filter("Thread is unstuck"))
# Suppress warnings about slow compiling
logging.getLogger("slow_operation_alarm").addFilter(logging.Filter("Very slow compile"))
CKPT_PATH = "./checkpoints/" CKPT_PATH = "./checkpoints/"