mirror of
https://github.com/xai-org/grok-1.git
synced 2024-11-24 04:29:53 +03:00
Allows CPU-based execution
This commit is contained in:
parent
7050ed204b
commit
1101257c05
4
requirements-cpu.txt
Normal file
4
requirements-cpu.txt
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
dm_haiku==0.0.12
|
||||||
|
jax==0.4.25
|
||||||
|
numpy==1.26.4
|
||||||
|
sentencepiece==0.2.0
|
24
run.py
24
run.py
@ -12,11 +12,33 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging, os
|
||||||
|
|
||||||
from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
|
from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
|
||||||
from runners import InferenceRunner, ModelRunner, sample_from_model
|
from runners import InferenceRunner, ModelRunner, sample_from_model
|
||||||
|
|
||||||
|
# Optional CPU-only execution for machines with fewer than 8 GPUs (enabled manually below)
|
||||||
|
# ONLY MEANT FOR DEVELOPERS WITH 384GB RAM
|
||||||
|
# CURRENTLY TOO SLOW FOR MEANINGFUL INFERENCE WORKLOADS
|
||||||
|
#
|
||||||
|
# Set to True to run the model on CPU only
|
||||||
|
USE_CPU_ONLY = False


def suppress_log_messages(logger_name, *fragments):
    """Drop records logged on ``logger_name`` whose message contains a fragment.

    ``logging.Filter(name)`` matches on the *logger name*, not on the message
    text, so handing it a message fragment rejects every record of the logger
    it is attached to. A callable filter that inspects the message is used
    instead, so unrelated records on the same logger still get through.

    Args:
        logger_name: Name of the logger the filter is attached to.
        *fragments: Substrings; a record containing any of them is dropped.
    """

    def _keep(record):
        try:
            text = record.getMessage()
        except (TypeError, ValueError, KeyError):
            # Malformed log call (template/args mismatch): fall back to the
            # raw template so the filter itself never raises.
            text = str(record.msg)
        return not any(fragment in text for fragment in fragments)

    logging.getLogger(logger_name).addFilter(_keep)


if USE_CPU_ONLY:
    # Simulate 8 devices via CPUs
    os.environ["XLA_FLAGS"] = (
        os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
    )
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    # Suppress warnings about unused backends
    suppress_log_messages("jax._src.xla_bridge", "Unable to initialize backend")
    # Suppress false warnings about stuck processes
    # NOTE(review): "collective_ops_utils" and "slow_operation_alarm" look like
    # XLA-internal names -- confirm these messages are actually emitted through
    # Python's logging module, otherwise these filters have no effect.
    suppress_log_messages(
        "collective_ops_utils",
        "This thread has been waiting for",
        "Thread is unstuck",
    )
    # Suppress warnings about slow compiling
    suppress_log_messages("slow_operation_alarm", "Very slow compile")
|
||||||
|
|
||||||
|
|
||||||
CKPT_PATH = "./checkpoints/"
|
CKPT_PATH = "./checkpoints/"
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user