diff --git a/checkpoint.py b/checkpoint.py index 1c6e878..68677e4 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -83,7 +83,7 @@ def fast_pickle(obj: Any, path: str) -> None: def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None): """Loads a set of arrays.""" pool = ThreadPoolExecutor(max_workers=32) - fs = list() + fs = [] num_tensors = 0 num_replicas = 1 data_model_shards = math.prod(mesh_config)