Mirror of https://github.com/xai-org/grok-1.git, synced 2024-11-24 12:39:54 +03:00

Update checkpoint.py

This commit is contained in:
  parent 7050ed204b
  commit 8aac0cea69
@@ -42,7 +42,6 @@ sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit
 @contextlib.contextmanager
 def copy_to_shm(file: str):
     if file.startswith("/dev/shm/"):
-        # Nothing to do, the file is already in shared memory.
         yield file
         return
 
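The hunk only shows the fast path of copy_to_shm: a path already under /dev/shm/ is yielded back unchanged. For context, a minimal sketch of how the slow path of such a helper is typically structured; the tempfile staging below is an assumption, not the repository's exact body:

    import contextlib
    import os
    import shutil
    import tempfile

    @contextlib.contextmanager
    def copy_to_shm_sketch(file: str):
        if file.startswith("/dev/shm/"):
            # Already in shared memory: hand the path straight back.
            yield file
            return
        # Assumed slow path: stage a copy under /dev/shm and remove it afterwards.
        fd, tmp_path = tempfile.mkstemp(dir="/dev/shm/")
        os.close(fd)
        try:
            shutil.copyfile(file, tmp_path)
            yield tmp_path
        finally:
            os.remove(tmp_path)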
@@ -81,7 +80,6 @@ def fast_pickle(obj: Any, path: str) -> None:
 
 
 def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
-    """Loads a set of arrays."""
     pool = ThreadPoolExecutor(max_workers=32)
     fs = list()
     num_tensors = 0
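The context lines show load_tensors fanning work out over a ThreadPoolExecutor with 32 workers and accumulating futures in fs. A minimal sketch of that submit-and-gather idiom, with paths and load_one as hypothetical stand-ins for the per-tensor inputs and loader:

    from concurrent.futures import ThreadPoolExecutor

    def load_all_sketch(paths, load_one):
        # Submit one blocking load per tensor, then gather results
        # in submission order so ordering stays deterministic.
        with ThreadPoolExecutor(max_workers=32) as pool:
            futures = [pool.submit(load_one, p) for p in paths]
            return [f.result() for f in futures]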
@@ -124,13 +122,11 @@ def get_load_path_str(
     load_rename_rules: Optional[list[tuple[str, str]]] = None,
     load_exclude_rules: Optional[list[str]] = None,
 ) -> Optional[str]:
-    # Exclusion
     if load_exclude_rules is not None:
         for search_pattern in load_exclude_rules:
             if re.search(search_pattern, init_path_str):
                 return None
 
-    # Renaming
     load_path_str = init_path_str
     if load_rename_rules is not None:
         for search_pattern, replacement_pattern in load_rename_rules:
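The two deleted comments labeled the passes visible in the context lines: exclude patterns short-circuit the whole lookup to None, and rename rules rewrite the tensor path string. A small usage sketch with hypothetical rules; rewriting via re.sub is an assumption, since the hunk ends before the rename loop body:

    import re

    exclude = [r"step"]                    # hypothetical: drop step counters
    rename = [(r"^params/", "weights/")]   # hypothetical: remap a prefix

    path = "params/decoder_layer_0/w"
    if not any(re.search(p, path) for p in exclude):
        for pat, repl in rename:
            path = re.sub(pat, repl, path)
    print(path)  # -> weights/decoder_layer_0/w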
@@ -197,7 +193,6 @@ def restore(
 
     state = jax.tree_util.tree_unflatten(structure, loaded_tensors)
 
-    # Sanity check to give a better error message.
     ckpt_keys = set(state.params.keys())
     code_keys = set(state_sharding.params.keys())
 
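The deleted comment introduced the key comparison on the two context lines below it, which diffs the checkpoint's parameter keys against the keys the model code expects so a mismatch fails with a clearer error. The tree_unflatten call pairs loaded leaves with a treedef captured earlier; a minimal round-trip sketch:

    import jax

    params = {"layer0": {"b": 0.0, "w": 1.0}}
    leaves, structure = jax.tree_util.tree_flatten(params)
    # The checkpoint loader would produce `leaves` from disk here.
    restored = jax.tree_util.tree_unflatten(structure, leaves)
    assert restored == params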
@@ -219,3 +214,71 @@ def restore(
     if params_only:
         state = state.params
     return state
+
+# Database and machine learning integration
+import sqlite3
+from sklearn.model_selection import train_test_split
+from sklearn.linear_model import LinearRegression
+import pandas as pd
+
+def create_database():
+    conn = sqlite3.connect('data_analysis.db')
+    c = conn.cursor()
+    c.execute('''CREATE TABLE IF NOT EXISTS data (
+                 id INTEGER PRIMARY KEY AUTOINCREMENT,
+                 timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
+                 latency REAL,
+                 packet_loss REAL)''')
+    conn.commit()
+    conn.close()
+
+def record_data(latency, packet_loss):
+    conn = sqlite3.connect('data_analysis.db')
+    c = conn.cursor()
+    c.execute('INSERT INTO data (latency, packet_loss) VALUES (?, ?)', (latency, packet_loss))
+    conn.commit()
+    conn.close()
+
+def train_model():
+    conn = sqlite3.connect('data_analysis.db')
+    data = pd.read_sql_query("SELECT * FROM data", conn)
+    conn.close()
+
+    X = data[['latency']]
+    y = data['packet_loss']
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+
+    model = LinearRegression()
+    model.fit(X_train, y_train)
+    return model
+
+def analyze_task_startup(latency, model):
+    predicted_packet_loss = model.predict([[latency]])[0]
+    if predicted_packet_loss > 10:
+        print("High packet loss predicted: ", predicted_packet_loss)
+    else:
+        print("Packet loss within acceptable range: ", predicted_packet_loss)
+
+def join_data_with_external_source():
+    external_data = pd.DataFrame({
+        'external_id': [1, 2, 3],
+        'external_info': ['info1', 'info2', 'info3']
+    })
+
+    conn = sqlite3.connect('data_analysis.db')
+    data = pd.read_sql_query("SELECT * FROM data", conn)
+    conn.close()
+
+    joined_data = data.merge(external_data, left_on='id', right_on='external_id')
+    return joined_data
+
+if __name__ == "__main__":
+    create_database()
+    record_data(50, 5) # Example data
+    record_data(100, 20) # Example data
+
+    model = train_model()
+    analyze_task_startup(70, model)
+    joined_data = join_data_with_external_source()
+    print(joined_data)
+
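Two notes on the added script. With only the two example rows, train_test_split(test_size=0.2) holds out one row and trains on one, so the fitted line is degenerate until more data is recorded; and calling model.predict([[latency]]) on a model fitted from a DataFrame may emit a feature-name warning on recent scikit-learn releases (passing a one-row DataFrame avoids it). The final merge defaults to an inner join, so unmatched external_id values are dropped, as a small worked example shows:

    import pandas as pd

    data = pd.DataFrame({"id": [1, 2], "latency": [50, 100], "packet_loss": [5, 20]})
    external = pd.DataFrame({"external_id": [1, 2, 3],
                             "external_info": ["info1", "info2", "info3"]})
    # Inner join by default: external_id 3 has no matching id and is dropped.
    print(data.merge(external, left_on="id", right_on="external_id"))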