From defd415f9a96da34ab4e2e8b4a3ec669b0065561 Mon Sep 17 00:00:00 2001 From: Aareon Sullivan Date: Fri, 22 Mar 2024 19:22:04 -0500 Subject: [PATCH] Improve tempfile handling in checkpoint.py Using more context managers, as well as dynamic temp dir creation, improve temp file handling, error handling, and logging --- checkpoint.py | 39 ++++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/checkpoint.py b/checkpoint.py index 1c6e878..ef57a8b 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -46,26 +46,35 @@ def copy_to_shm(file: str): yield file return - tmp_dir = "/dev/shm/" - fd, tmp_path = tempfile.mkstemp(dir=tmp_dir) - try: - shutil.copyfile(file, tmp_path) - yield tmp_path - finally: - os.remove(tmp_path) - os.close(fd) + with tempfile.NamedTemporaryFile(dir="/dev/shm", delete=False) as tmp_file: + tmp_path = tmp_file.name + try: + shutil.copyfile(file, tmp_path) + yield tmp_path + finally: + try: + os.remove(tmp_path) + except OSError as e: + # Handle file deletion error gracefully + logger.error(f"Error deleting temporary file: {e}") + raise @contextlib.contextmanager def copy_from_shm(file: str): tmp_dir = "/dev/shm/" - fd, tmp_path = tempfile.mkstemp(dir=tmp_dir) - try: - yield tmp_path - shutil.copyfile(tmp_path, file) - finally: - os.remove(tmp_path) - os.close(fd) + with tempfile.NamedTemporaryFile(dir=tmp_dir, delete=False) as tmp_file: + tmp_path = tmp_file.name + try: + yield tmp_path + shutil.copyfile(tmp_path, file) + finally: + try: + os.remove(tmp_path) + except OSError as e: + # Handle file deletion error gracefully + logger.error(f"Error deleting temporary file: {e}") + raise def fast_unpickle(path: str) -> Any: