# Copyright 2024 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import contextlib
import logging
import math
import os
import pickle
import re
import shutil
import sys
import tempfile
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Iterator, Optional

import jax
import numpy as np
from jax.experimental import multihost_utils

from model import QuantizedWeight8bit

logger = logging.getLogger(__name__)
rank_logger = logging.getLogger("rank")

# Needed for loading the checkpoint with pickle.
sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit

# Shared-memory directory used to stage checkpoint files.
_SHM_DIR = "/dev/shm/"


@contextlib.contextmanager
def copy_to_shm(file: str) -> Iterator[str]:
    """Yield a path under /dev/shm/ holding a copy of ``file``.

    If ``file`` already lives under /dev/shm/ it is yielded unchanged and
    nothing is copied or deleted. Otherwise a temporary copy is created and
    removed again when the context exits.
    """
    if file.startswith(_SHM_DIR):
        yield file
        return

    fd, tmp_path = tempfile.mkstemp(dir=_SHM_DIR)
    # Only the path is used below; close the descriptor immediately so it
    # cannot leak if the cleanup in ``finally`` raises.
    os.close(fd)
    try:
        shutil.copyfile(file, tmp_path)
        yield tmp_path
    finally:
        os.remove(tmp_path)


@contextlib.contextmanager
def copy_from_shm(file: str) -> Iterator[str]:
    """Yield a temporary path under /dev/shm/ and copy it to ``file`` on success.

    The copy to ``file`` only happens when the ``with`` body finishes without
    raising; the temporary file is removed in either case.
    """
    fd, tmp_path = tempfile.mkstemp(dir=_SHM_DIR)
    # Only the path is used below; close the descriptor immediately so it
    # cannot leak if the cleanup in ``finally`` raises.
    os.close(fd)
    try:
        yield tmp_path
        shutil.copyfile(tmp_path, file)
    finally:
        os.remove(tmp_path)


def fast_unpickle(path: str) -> Any:
    """Unpickle the object stored at ``path`` via a staged copy in /dev/shm/."""
    # NOTE(security): pickle.load can execute arbitrary code. Only call this
    # on checkpoint files from a trusted source.
    with copy_to_shm(path) as tmp_path:
        with open(tmp_path, "rb") as f:
            return pickle.load(f)


def fast_pickle(obj: Any, path: str) -> None:
    """Pickle ``obj`` into /dev/shm/ first, then copy the result to ``path``."""
    with copy_from_shm(path) as tmp_path:
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)


def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
    """Load one pickled shard per entry of ``shaped_arrays`` using a thread pool.

    Args:
        shaped_arrays: Iterable of objects exposing ``.shape`` and ``.dtype``.
        directory: Directory containing files named ``tensor{i:05d}_{idx:03d}``.
        mesh_config: Sequence of ints; its product is the number of
            data/model shards used to derive this process's shard index.
        tensor_indices: Optional indices paired with ``shaped_arrays``; when
            ``None`` the tensors are numbered 0, 1, 2, ...

    Returns:
        A list with one loaded object (or zero array) per input entry, in
        input order.
    """
    num_replicas = 1
    data_model_shards = math.prod(mesh_config)

    # These values depend only on the process, not on the tensor, so compute
    # them once instead of on every loop iteration.
    process_index = jax.process_index()
    replica_slot = (process_index // data_model_shards) % num_replicas
    idx = (
        process_index // (num_replicas * data_model_shards) * data_model_shards
        + process_index % data_model_shards
    )

    if tensor_indices is None:
        iterator = enumerate(shaped_arrays)
    else:
        iterator = zip(tensor_indices, shaped_arrays)

    # The ``with`` block shuts the pool down on exit so worker threads are
    # not leaked on every call.
    with ThreadPoolExecutor(max_workers=32) as pool:
        fs = []
        for i, t in iterator:
            # With num_replicas fixed at 1 this condition is always true; the
            # zeros branch only matters if num_replicas is ever raised.
            if (i % num_replicas) == replica_slot:
                fs.append(
                    pool.submit(
                        fast_unpickle,
                        os.path.join(directory, f"tensor{i:05d}_{idx:03d}"),
                    )
                )
            else:
                fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
        wait(fs)
        return [f.result() for f in fs]


def path_tuple_to_string(path: tuple) -> str:
    """Join the dict keys and attribute names of a JAX key path with "/".

    ``DictKey`` contributes its ``key`` and ``GetAttrKey`` its ``name``.
    ``SequenceKey`` and ``FlattenedIndexKey`` entries are skipped; any other
    entry type fails the assertion.
    """
    pieces = []
    for elem in path:
        if isinstance(elem, jax.tree_util.DictKey):
            pieces.append(elem.key)
        elif isinstance(elem, jax.tree_util.GetAttrKey):
            pieces.append(elem.name)
        else:
            assert isinstance(elem, (jax.tree_util.FlattenedIndexKey, jax.tree_util.SequenceKey))
    return "/".join(pieces)


def get_load_path_str(
    init_path_str: str,
    load_rename_rules: Optional[list[tuple[str, str]]] = None,
    load_exclude_rules: Optional[list[str]] = None,
) -> Optional[str]:
    """Map a parameter path in the init state to its path in the checkpoint.

    Args:
        init_path_str: Path string of the parameter in the init state.
        load_rename_rules: Optional ``(search_pattern, replacement)`` regex
            pairs. Only the first rule whose pattern matches is applied.
        load_exclude_rules: Optional regex patterns; a match on
            ``init_path_str`` excludes the parameter.

    Returns:
        ``None`` if the path is excluded, otherwise the (possibly renamed)
        path to look up in the checkpoint.
    """
    # Exclusion rules are checked before any renaming.
    if load_exclude_rules is not None:
        for search_pattern in load_exclude_rules:
            if re.search(search_pattern, init_path_str):
                return None

    load_path_str = init_path_str
    if load_rename_rules is not None:
        for search_pattern, replacement_pattern in load_rename_rules:
            if re.search(search_pattern, load_path_str):
                load_path_str = re.sub(search_pattern, replacement_pattern, load_path_str)
                # Stop after the first matching rule.
                break

    return load_path_str


def replace_with_load_state(
    init_state: Any,
    load_state: Any,
    load_rename_rules: Optional[list[tuple[str, str]]] = None,
    load_exclude_rules: Optional[list[str]] = None,
    mesh_config: tuple = (1, 1),
) -> Any:
    """Build a state shaped like ``init_state`` using tensors from ``load_state``.

    For each leaf of ``init_state`` the checkpoint path is resolved with
    ``get_load_path_str``. Excluded leaves and leaves missing from the
    checkpoint keep a value derived from ``init_state``; all others are
    replaced by the matching tensor of ``load_state``.

    Returns:
        A pytree with the structure of ``init_state``.
    """
    flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
    flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
    load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}

    replaced = []
    num_replicas = 1
    data_model_shards = math.prod(mesh_config)
    for i, (init_path, tensor) in enumerate(flatten_init):
        init_path_str = path_tuple_to_string(init_path)
        load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules)
        if load_path_str is None:
            rank_logger.info(f"Excluded from restore: {init_path_str}.")
            replaced.append(tensor)
        elif load_path_str in load_map:
            if load_path_str == init_path_str:
                rank_logger.info(f"Restored from ckpt: {init_path_str}.")
            else:
                rank_logger.info(f"Restored from ckpt: {init_path_str} <-- {load_path_str}.")
            replaced.append(load_map[load_path_str])
        else:
            rank_logger.info(f"Not found in ckpt: {init_path_str}.")
            # With num_replicas fixed at 1 this condition is always true, so
            # the init tensor is kept; the zeros branch only matters if
            # num_replicas is ever raised.
            if (i % num_replicas) == ((jax.process_index() // data_model_shards) % num_replicas):
                replaced.append(tensor)
            else:
                replaced.append(np.zeros_like(tensor))

    return jax.tree_util.tree_unflatten(structure_init, replaced)


def restore(
    checkpoint_path: str,
    state_shapes: Any,
    mesh,
    between_hosts_config,
    params_only,
    state_sharding,
    init_state: Optional[Any] = None,
) -> Any:
    """Load the checkpoint under ``<checkpoint_path>/ckpt-0`` as global arrays.

    Args:
        checkpoint_path: Directory that contains the ``ckpt-0`` subdirectory.
        state_shapes: Pytree describing the state; its leaves are passed to
            ``load_tensors`` and its structure is used to rebuild the state.
        mesh: Mesh passed to ``host_local_array_to_global_array``.
        between_hosts_config: Passed to ``load_tensors`` as ``mesh_config``.
        params_only: If truthy, return only ``state.params``.
        state_sharding: Pytree of partition specs; ``None`` leaves are
            replaced with an empty ``PartitionSpec``.
        init_state: In this function it only suppresses the parameter-key
            mismatch check when it is not ``None``.

    Returns:
        The restored state, or only its ``params`` when ``params_only`` is set.

    Raises:
        ValueError: If checkpoint and code parameter keys differ and
            ``init_state`` is ``None``.
    """
    ckpt_path = os.path.join(checkpoint_path, "ckpt-0")

    rank_logger.info("Loading checkpoint at {}".format(ckpt_path))
    ckpt_shapes = state_shapes
    ckpt_shapes_with_path, structure = jax.tree_util.tree_flatten_with_path(ckpt_shapes)

    ckpt_shapes_flat = [elem[1] for elem in ckpt_shapes_with_path]
    loaded_tensors = load_tensors(ckpt_shapes_flat, ckpt_path, between_hosts_config)

    state = jax.tree_util.tree_unflatten(structure, loaded_tensors)

    # Sanity check that the checkpoint and the code agree on parameter names.
    ckpt_keys = set(state.params.keys())
    code_keys = set(state_sharding.params.keys())

    if ckpt_keys != code_keys and init_state is None:
        missing_in_ckpt = code_keys - ckpt_keys
        missing_locally = ckpt_keys - code_keys
        raise ValueError(
            "Parameters in the code are not matching checkpoint parameters.\n"
            "Params missing in checkpoint: {}\nParams missing in code: {}".format(
                missing_in_ckpt, missing_locally
            )
        )

    # A ``None`` sharding leaf is replaced by an empty PartitionSpec before
    # the host-local arrays are converted to global arrays.
    state_sharding = jax.tree_util.tree_map(
        lambda x: jax.sharding.PartitionSpec() if x is None else x,
        state_sharding,
        is_leaf=lambda x: x is None,
    )
    state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)
    if params_only:
        state = state.params
    return state


# Database and machine learning integration
# NOTE(review): this section is unrelated to checkpoint loading and imports
# heavy optional dependencies at module import time; consider moving it to
# its own module.
import sqlite3
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
import pandas as pd

# SQLite file shared by all helpers below.
_DB_PATH = 'data_analysis.db'


def create_database() -> None:
    """Create the ``data`` table in the SQLite database if it does not exist."""
    # ``closing`` guarantees the connection is closed even if a statement raises.
    with contextlib.closing(sqlite3.connect(_DB_PATH)) as conn:
        conn.execute(
            '''CREATE TABLE IF NOT EXISTS data (
                   id INTEGER PRIMARY KEY AUTOINCREMENT,
                   timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                   latency REAL,
                   packet_loss REAL)'''
        )
        conn.commit()


def record_data(latency, packet_loss) -> None:
    """Insert one ``(latency, packet_loss)`` row into the ``data`` table."""
    with contextlib.closing(sqlite3.connect(_DB_PATH)) as conn:
        conn.execute(
            'INSERT INTO data (latency, packet_loss) VALUES (?, ?)',
            (latency, packet_loss),
        )
        conn.commit()


def train_model() -> LinearRegression:
    """Fit a linear regression predicting ``packet_loss`` from ``latency``.

    All rows of the ``data`` table are read, split 80/20 with a fixed seed,
    and the model is fitted on the training part only.

    Returns:
        The fitted ``LinearRegression`` model.
    """
    with contextlib.closing(sqlite3.connect(_DB_PATH)) as conn:
        data = pd.read_sql_query("SELECT * FROM data", conn)

    X = data[['latency']]
    y = data['packet_loss']
    # The held-out test split is not used here, so it is discarded.
    X_train, _, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=42)

    model = LinearRegression()
    model.fit(X_train, y_train)
    return model


def analyze_task_startup(latency, model, threshold=10) -> None:
    """Print whether the packet loss predicted for ``latency`` exceeds ``threshold``.

    Args:
        latency: Latency value passed to ``model.predict`` as a single sample.
        model: Object with a ``predict`` method, e.g. from ``train_model``.
        threshold: Predicted values strictly above this are reported as high.
    """
    predicted_packet_loss = model.predict([[latency]])[0]
    if predicted_packet_loss > threshold:
        print("High packet loss predicted: ", predicted_packet_loss)
    else:
        print("Packet loss within acceptable range: ", predicted_packet_loss)


def join_data_with_external_source() -> pd.DataFrame:
    """Inner-join the ``data`` table with a small hard-coded external frame.

    Rows are matched on ``data.id == external_id``.
    """
    external_data = pd.DataFrame({
        'external_id': [1, 2, 3],
        'external_info': ['info1', 'info2', 'info3'],
    })

    with contextlib.closing(sqlite3.connect(_DB_PATH)) as conn:
        data = pd.read_sql_query("SELECT * FROM data", conn)

    joined_data = data.merge(external_data, left_on='id', right_on='external_id')
    return joined_data


if __name__ == "__main__":
    create_database()
    record_data(50, 5)  # Example data
    record_data(100, 20)  # Example data
    model = train_model()
    analyze_task_startup(70, model)
    joined_data = join_data_with_external_source()
    print(joined_data)