diff --git a/checkpoint.py b/checkpoint.py index 1c6e878..4539aa4 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -104,7 +104,21 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None): else: fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype)) wait(fs) - return [f.result() for f in fs] + + # return [f.result() for f in fs] + """ + Improve error reporting in load_tensors by catching exceptions within the futures + and logging detailed information about the failure. + """ + results = [] + for future in fs: + try: + result = future.result() + results.append(result) + except Exception as e: + logger.error(f"Failed to load tensor: {e}") + raise + return results def path_tuple_to_string(path: tuple) -> str: