Update checkpoint.py

This commit is contained in:
Roy SALIBA 2024-05-29 14:00:16 +02:00 committed by GitHub
parent 7050ed204b
commit 8aac0cea69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -42,7 +42,6 @@ sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit
@contextlib.contextmanager
def copy_to_shm(file: str):
    if file.startswith("/dev/shm/"):
        # Nothing to do, the file is already in shared memory.
        yield file
        return
@ -81,7 +80,6 @@ def fast_pickle(obj: Any, path: str) -> None:
def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
    """Loads a set of arrays."""
    pool = ThreadPoolExecutor(max_workers=32)
    fs = list()
    num_tensors = 0
@ -124,13 +122,11 @@ def get_load_path_str(
    load_rename_rules: Optional[list[tuple[str, str]]] = None,
    load_exclude_rules: Optional[list[str]] = None,
) -> Optional[str]:
    # Exclusion
    if load_exclude_rules is not None:
        for search_pattern in load_exclude_rules:
            if re.search(search_pattern, init_path_str):
                return None
    # Renaming
    load_path_str = init_path_str
    if load_rename_rules is not None:
        for search_pattern, replacement_pattern in load_rename_rules:
@ -197,7 +193,6 @@ def restore(
    state = jax.tree_util.tree_unflatten(structure, loaded_tensors)
    # Sanity check to give a better error message.
    ckpt_keys = set(state.params.keys())
    code_keys = set(state_sharding.params.keys())
@ -219,3 +214,71 @@ def restore(
    if params_only:
        state = state.params
    return state
# Database and machine learning integration
import sqlite3
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
import pandas as pd
def create_database():
    """Create the SQLite database file and the `data` table if missing.

    Uses the hard-coded file `data_analysis.db` in the current working
    directory. Safe to call repeatedly (CREATE TABLE IF NOT EXISTS).
    """
    conn = sqlite3.connect('data_analysis.db')
    try:
        c = conn.cursor()
        c.execute('''CREATE TABLE IF NOT EXISTS data (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
            latency REAL,
            packet_loss REAL)''')
        conn.commit()
    finally:
        # Close even if execute/commit raises, so the connection never leaks.
        conn.close()
def record_data(latency, packet_loss):
    """Insert one (latency, packet_loss) measurement into the `data` table.

    Args:
        latency: Measured latency (stored in the REAL `latency` column).
        packet_loss: Measured packet loss (REAL `packet_loss` column).

    The table must already exist (see `create_database`).
    """
    conn = sqlite3.connect('data_analysis.db')
    try:
        c = conn.cursor()
        # Parameterized query: values are bound, never string-interpolated.
        c.execute('INSERT INTO data (latency, packet_loss) VALUES (?, ?)', (latency, packet_loss))
        conn.commit()
    finally:
        # Close even on failure, so the connection never leaks.
        conn.close()
def train_model():
    """Fit a linear regression predicting packet_loss from latency.

    Reads all rows from the `data` table of `data_analysis.db` and returns
    the fitted `LinearRegression` model.

    NOTE(review): `train_test_split` needs at least 2 rows, and the held-out
    test split (X_test, y_test) is currently unused — no evaluation is done.
    """
    conn = sqlite3.connect('data_analysis.db')
    try:
        data = pd.read_sql_query("SELECT * FROM data", conn)
    finally:
        # Close even if the query raises, so the connection never leaks.
        conn.close()
    X = data[['latency']]
    y = data['packet_loss']
    # Fixed random_state keeps the split (and thus the fit) reproducible.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    model = LinearRegression()
    model.fit(X_train, y_train)
    return model
def analyze_task_startup(latency, model, threshold=10):
    """Predict packet loss for `latency` and report whether it is acceptable.

    Args:
        latency: Latency value to evaluate.
        model: Fitted regressor exposing `predict` (e.g. from `train_model`).
        threshold: Packet-loss level above which the "high" warning is
            printed. Defaults to 10, matching the previous hard-coded limit.

    Returns:
        The predicted packet loss, so callers can act on it programmatically
        (the original discarded it).
    """
    # predict expects a 2-D feature array: one row with one feature.
    predicted_packet_loss = model.predict([[latency]])[0]
    if predicted_packet_loss > threshold:
        print("High packet loss predicted: ", predicted_packet_loss)
    else:
        print("Packet loss within acceptable range: ", predicted_packet_loss)
    return predicted_packet_loss
def join_data_with_external_source():
    """Join recorded measurements with a hard-coded external lookup table.

    Returns:
        DataFrame of rows from `data` inner-joined with the in-memory
        external frame on data.id == external.external_id. Only ids 1-3
        exist in the external frame, so other rows are dropped.
    """
    # Stand-in for a real external data source.
    external_data = pd.DataFrame({
        'external_id': [1, 2, 3],
        'external_info': ['info1', 'info2', 'info3']
    })
    conn = sqlite3.connect('data_analysis.db')
    try:
        data = pd.read_sql_query("SELECT * FROM data", conn)
    finally:
        # Close even if the query raises, so the connection never leaks.
        conn.close()
    # merge defaults to an inner join: unmatched ids are dropped.
    joined_data = data.merge(external_data, left_on='id', right_on='external_id')
    return joined_data
def _demo():
    """Demo run: seed two samples, train, evaluate one latency, show the join."""
    create_database()
    # Two example measurements so the regression has data to fit.
    record_data(50, 5)
    record_data(100, 20)
    model = train_model()
    analyze_task_startup(70, model)
    joined_data = join_data_with_external_source()
    print(joined_data)


if __name__ == "__main__":
    _demo()