-
Notifications
You must be signed in to change notification settings - Fork 7
fix: Gracefully shutdown multiprocessing.Manager and fix proxy resource leaks #83
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
56ba8ea
fcb41de
363835b
2319060
23d6cad
c630a8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -265,6 +265,21 @@ def remove_checkpoint(self, path: _PATH) -> None: | |
| else: | ||
| self.fallback_checkpoint_io.remove_checkpoint(path) | ||
|
|
||
| @log_execution_time(logger=_LOGGER, name="MLFlashpointCheckpointIO.teardown") | ||
| def teardown(self) -> None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. where is this invoked btw?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. inside the teardown of MLFlashpointAsyncFinalizableCheckpointIO |
||
| """Tears down the CheckpointIO instance and its strategies/fallbacks.""" | ||
| if hasattr(super(), "teardown"): | ||
| super().teardown() | ||
|
|
||
| if hasattr(self, "save_strategy") and self.save_strategy and hasattr(self.save_strategy, "teardown"): | ||
| self.save_strategy.teardown() | ||
| if ( | ||
| hasattr(self, "fallback_checkpoint_io") | ||
| and self.fallback_checkpoint_io | ||
| and hasattr(self.fallback_checkpoint_io, "teardown") | ||
| ): | ||
| self.fallback_checkpoint_io.teardown() | ||
|
|
||
|
|
||
| class MLFlashpointAsyncFinalizableCheckpointIO(AsyncFinalizableCheckpointIO): | ||
| """CheckpointIO wrapper for async checkpoint saving and synchronous finalization | ||
|
|
@@ -391,8 +406,16 @@ def maybe_finalize_save_checkpoint(self, blocking: bool = False) -> bool: | |
| @override | ||
| @log_execution_time(logger=_LOGGER, name="MLFlashpointAsyncFinalizeCheckpointIO.teardown") | ||
| def teardown(self) -> None: | ||
| """Warns if there are any pending checkpoint saves.""" | ||
| """Warns if there are any pending checkpoint saves and cleans up resources.""" | ||
| super().teardown() | ||
|
|
||
| if ( | ||
| hasattr(self, "mlf_checkpoint_io") | ||
| and self.mlf_checkpoint_io | ||
| and hasattr(self.mlf_checkpoint_io, "teardown") | ||
| ): | ||
| self.mlf_checkpoint_io.teardown() | ||
|
|
||
| if ( | ||
| self._mlf_async_calls_queue.get_num_unfinalized_calls() | ||
| + self._alt_async_calls_queue.get_num_unfinalized_calls() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -388,6 +388,21 @@ def finish_checkpoint( | |
| checkpoint_id, | ||
| ) | ||
| self._write_results_per_checkpoint_id.pop(checkpoint_id, None) | ||
| self._write_events_per_checkpoint_id.pop(checkpoint_id, None) | ||
|
|
||
| @log_execution_time(logger=_LOGGER, name="teardown", level=logging.INFO) | ||
| def teardown(self) -> None: | ||
| """Tears down the StorageWriter, including shutting down the torch_mp Manager.""" | ||
| if self._main_process_torchmp_manager_future is not None: | ||
| try: | ||
| manager = self._main_process_torchmp_manager_future.result(timeout=1.0) | ||
| _LOGGER.info("Shutting down torch_mp Manager...") | ||
| manager.shutdown() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what if the call above times out, this will fail right? what value will
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if it timeout, will get timeout exception and caught by line 403, then the manager.shutdown() would be skipped |
||
| _LOGGER.info("Successfully shut down torch_mp Manager.") | ||
| except Exception as e: | ||
| _LOGGER.warning("Failed to shutdown torch_mp Manager: %s", e) | ||
| finally: | ||
| self._main_process_torchmp_manager_future = None | ||
|
|
||
| @classmethod | ||
| @override | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For improved readability, consider using the
old_storage_writervariable when creating the newMemoryStorageWriterinstance. While the current code is functionally correct becauseself._storage_writerstill refers to the old instance at that point, explicitly usingold_storage_writermakes the intent clearer and less prone to misinterpretation during future maintenance.