Resolve the MLflow file-lock path from the full file URI (so it stays absolute) and
log uncommitted code only when $CWD is inside a git work tree, without using a shell.

Note: this patch was re-flowed and hand-edited. The blob ids on the `index` lines are
those of the original change, and indentation was reconstructed; if context does not
match the target tree exactly, apply with `git apply --ignore-whitespace`.

diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py
index cb48d156acf..9e53cad43e6 100644
--- a/qlib/workflow/expm.py
+++ b/qlib/workflow/expm.py
@@ -2,11 +2,11 @@
 # Licensed under the MIT License.
 
 from urllib.parse import urlparse
+from urllib.request import url2pathname
 import mlflow
 from filelock import FileLock
 from mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCode
 from mlflow.entities import ViewType
-import os
 from typing import Optional, Text
 from pathlib import Path
 
@@ -19,6 +19,20 @@
 logger = get_module_logger("workflow")
 
 
+def _file_uri_to_path(uri: str) -> Path:
+    """Convert a ``file:`` URI into a filesystem path.
+
+    A host other than ``localhost`` is kept as a ``//host/...`` prefix so the
+    result still names that host instead of silently becoming a local path.
+    """
+    pr = urlparse(uri)
+    path = url2pathname(pr.path)
+    # Host names are case-insensitive, so "LOCALHOST" must count as local too.
+    if pr.netloc and pr.netloc.lower() != "localhost":
+        path = f"//{pr.netloc}{path}"
+    return Path(path)
+
+
 class ExpManager:
     """
     This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow.
@@ -233,7 +247,7 @@ def _get_or_create_exp(self, experiment_id=None, experiment_name=None) -> (objec
             # So we supported it in the interface wrapper
             pr = urlparse(self.uri)
             if pr.scheme == "file":
-                with FileLock(Path(os.path.join(pr.netloc, pr.path.lstrip("/"), "filelock"))):  # pylint: disable=E0110
+                with FileLock(_file_uri_to_path(self.uri) / "filelock"):  # pylint: disable=E0110
                     return self.create_exp(experiment_name), True
             # NOTE: for other schemes like http, we double check to avoid create exp conflicts
             try:
diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py
index 5fd99c0769f..668eb49d8ba 100644
--- a/qlib/workflow/recorder.py
+++ b/qlib/workflow/recorder.py
@@ -359,23 +359,43 @@ def start_run(self):
         )  # Log necessary environment variables
         return run
 
-    def _log_uncommitted_code(self):
-        """
-        Mlflow only log the commit id of the current repo. But usually, user will have a lot of uncommitted changes.
-        So this tries to automatically to log them all.
-        """
-        # TODO: the sub-directories maybe git repos.
-        # So it will be better if we can walk the sub-directories and log the uncommitted changes.
-        for cmd, fname in [
-            ("git diff", "code_diff.txt"),
-            ("git status", "code_status.txt"),
-            ("git diff --cached", "code_cached.txt"),
-        ]:
-            try:
-                out = subprocess.check_output(cmd, shell=True)
-                self.client.log_text(self.id, out.decode(), fname)  # this behaves same as above
-            except subprocess.CalledProcessError:
-                logger.info(f"Fail to log the uncommitted code of $CWD({os.getcwd()}) when run {cmd}.")
+    def _log_uncommitted_code(self):
+        """
+        MLflow only logs the commit id of the current repo, but users usually have many uncommitted changes.
+        This method tries to log them all automatically through `self.client.log_text`.
+
+        It is best-effort: nothing is logged when git cannot be run or $CWD is not inside a git work tree.
+        """
+        try:
+            result = subprocess.run(
+                ["git", "rev-parse", "--is-inside-work-tree"],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.DEVNULL,
+                check=True,
+            )
+        except (subprocess.CalledProcessError, OSError):
+            # git exited non-zero (e.g. not a repository), or git is missing / not executable.
+            return
+
+        if result.stdout.decode(errors="replace").strip().lower() != "true":
+            return
+
+        # TODO: the sub-directories may be git repos.
+        # So it will be better if we can walk the sub-directories and log the uncommitted changes.
+        for cmd, fname in [
+            (["git", "diff"], "code_diff.txt"),
+            (["git", "status"], "code_status.txt"),
+            (["git", "diff", "--cached"], "code_cached.txt"),
+        ]:
+            try:
+                out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
+            except (subprocess.CalledProcessError, OSError):
+                logger.debug(
+                    f"Fail to log the uncommitted code of $CWD({os.getcwd()}) when run {cmd}."
+                )
+                continue
+            # Git output may contain non-UTF-8 bytes; replace them instead of raising UnicodeDecodeError.
+            self.client.log_text(self.id, out.decode(errors="replace"), fname)
 
     def end_run(self, status: str = Recorder.STATUS_S):
         assert status in [
diff --git a/tests/test_workflow_cwd_safety.py b/tests/test_workflow_cwd_safety.py
new file mode 100644
index 00000000000..decc2adc696
--- /dev/null
+++ b/tests/test_workflow_cwd_safety.py
@@ -0,0 +1,81 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import subprocess
+from pathlib import Path
+from unittest import mock
+
+from qlib.workflow.expm import _file_uri_to_path
+from qlib.workflow.recorder import MLflowRecorder
+
+
+def _make_recorder() -> MLflowRecorder:
+    """Build a recorder without running ``__init__``; only ``id`` and ``client`` are set."""
+    recorder = MLflowRecorder.__new__(MLflowRecorder)
+    recorder.id = "run"
+    recorder.client = mock.Mock()
+    return recorder
+
+
+def test_file_uri_lock_path_stays_absolute(tmp_path: Path) -> None:
+    mlruns = (tmp_path / "mlruns").resolve()
+
+    assert _file_uri_to_path(mlruns.as_uri()) == mlruns
+    assert _file_uri_to_path("file:" + str(mlruns)) == mlruns
+
+
+def test_file_uri_localhost_is_case_insensitive() -> None:
+    assert _file_uri_to_path("file://LOCALHOST/tmp/mlruns") == _file_uri_to_path("file://localhost/tmp/mlruns")
+
+
+def test_log_uncommitted_code_skips_non_git_cwd() -> None:
+    recorder = _make_recorder()
+    not_git = subprocess.CalledProcessError(128, ["git", "rev-parse"])
+
+    # Nested ``with`` blocks: parenthesized context managers are only official syntax from Python 3.10.
+    with mock.patch("qlib.workflow.recorder.subprocess.run", side_effect=not_git) as run:
+        with mock.patch("qlib.workflow.recorder.subprocess.check_output") as check_output:
+            recorder._log_uncommitted_code()
+
+    run.assert_called_once_with(
+        ["git", "rev-parse", "--is-inside-work-tree"],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.DEVNULL,
+        check=True,
+    )
+    check_output.assert_not_called()
+    recorder.client.log_text.assert_not_called()
+
+
+def test_log_uncommitted_code_uses_git_without_shell() -> None:
+    recorder = _make_recorder()
+    in_worktree = subprocess.CompletedProcess(["git"], 0, stdout=b"true\n")
+
+    with mock.patch("qlib.workflow.recorder.subprocess.run", return_value=in_worktree):
+        with mock.patch(
+            "qlib.workflow.recorder.subprocess.check_output",
+            side_effect=[b"diff", b"status", b"cached"],
+        ) as check_output:
+            recorder._log_uncommitted_code()
+
+    assert [call.args[0] for call in check_output.call_args_list] == [
+        ["git", "diff"],
+        ["git", "status"],
+        ["git", "diff", "--cached"],
+    ]
+    assert all(
+        call.kwargs == {"stderr": subprocess.DEVNULL}
+        for call in check_output.call_args_list
+    )
+    assert recorder.client.log_text.call_count == 3
+
+
+def test_log_uncommitted_code_tolerates_non_utf8_output() -> None:
+    recorder = _make_recorder()
+    in_worktree = subprocess.CompletedProcess(["git"], 0, stdout=b"true\n")
+
+    with mock.patch("qlib.workflow.recorder.subprocess.run", return_value=in_worktree):
+        with mock.patch("qlib.workflow.recorder.subprocess.check_output", return_value=b"caf\xe9"):
+            recorder._log_uncommitted_code()
+
+    assert recorder.client.log_text.call_count == 3