Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions qlib/workflow/expm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
# Licensed under the MIT License.

from urllib.parse import urlparse
from urllib.request import url2pathname
import mlflow
from filelock import FileLock
from mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCode
from mlflow.entities import ViewType
import os
from typing import Optional, Text
from pathlib import Path

Expand All @@ -19,6 +19,14 @@
logger = get_module_logger("workflow")


def _file_uri_to_path(uri: str) -> Path:
pr = urlparse(uri)
path = url2pathname(pr.path)
if pr.netloc and pr.netloc != "localhost":
path = f"//{pr.netloc}{path}"
return Path(path)


class ExpManager:
"""
This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow.
Expand Down Expand Up @@ -233,7 +241,7 @@ def _get_or_create_exp(self, experiment_id=None, experiment_name=None) -> (objec
# So we supported it in the interface wrapper
pr = urlparse(self.uri)
if pr.scheme == "file":
with FileLock(Path(os.path.join(pr.netloc, pr.path.lstrip("/"), "filelock"))): # pylint: disable=E0110
with FileLock(_file_uri_to_path(self.uri) / "filelock"): # pylint: disable=E0110
return self.create_exp(experiment_name), True
# NOTE: for other schemes like http, we double check to avoid create exp conflicts
try:
Expand Down
49 changes: 32 additions & 17 deletions qlib/workflow/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,23 +359,38 @@ def start_run(self):
) # Log necessary environment variables
return run

def _log_uncommitted_code(self):
"""
Mlflow only log the commit id of the current repo. But usually, user will have a lot of uncommitted changes.
So this tries to automatically to log them all.
"""
# TODO: the sub-directories maybe git repos.
# So it will be better if we can walk the sub-directories and log the uncommitted changes.
for cmd, fname in [
("git diff", "code_diff.txt"),
("git status", "code_status.txt"),
("git diff --cached", "code_cached.txt"),
]:
try:
out = subprocess.check_output(cmd, shell=True)
self.client.log_text(self.id, out.decode(), fname) # this behaves same as above
except subprocess.CalledProcessError:
logger.info(f"Fail to log the uncommitted code of $CWD({os.getcwd()}) when run {cmd}.")
def _log_uncommitted_code(self):
"""
Mlflow only log the commit id of the current repo. But usually, user will have a lot of uncommitted changes.
So this tries to automatically to log them all.
"""
try:
result = subprocess.run(
["git", "rev-parse", "--is-inside-work-tree"],
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
check=True,
)
except (subprocess.CalledProcessError, FileNotFoundError):
return

if result.stdout.decode().strip().lower() != "true":
return

# TODO: the sub-directories maybe git repos.
# So it will be better if we can walk the sub-directories and log the uncommitted changes.
for cmd, fname in [
(["git", "diff"], "code_diff.txt"),
(["git", "status"], "code_status.txt"),
(["git", "diff", "--cached"], "code_cached.txt"),
]:
try:
out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
self.client.log_text(self.id, out.decode(), fname) # this behaves same as above
except (subprocess.CalledProcessError, FileNotFoundError):
logger.debug(
f"Fail to log the uncommitted code of $CWD({os.getcwd()}) when run {cmd}."
)

def end_run(self, status: str = Recorder.STATUS_S):
assert status in [
Expand Down
76 changes: 76 additions & 0 deletions tests/test_workflow_cwd_safety.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import subprocess
from pathlib import Path
from unittest import mock

from qlib.workflow.expm import _file_uri_to_path
from qlib.workflow.recorder import MLflowRecorder


def _make_recorder() -> MLflowRecorder:
recorder = MLflowRecorder.__new__(MLflowRecorder)
recorder.id = "run"
recorder.client = mock.Mock()
return recorder


def test_file_uri_lock_path_stays_absolute(tmp_path: Path) -> None:
mlruns = (tmp_path / "mlruns").resolve()

assert _file_uri_to_path(mlruns.as_uri()) == mlruns
assert _file_uri_to_path("file:" + str(mlruns)) == mlruns


def test_log_uncommitted_code_skips_non_git_cwd() -> None:
recorder = _make_recorder()
not_git = subprocess.CalledProcessError(128, ["git", "rev-parse"])

with (
mock.patch(
"qlib.workflow.recorder.subprocess.run",
side_effect=not_git,
) as run,
mock.patch(
"qlib.workflow.recorder.subprocess.check_output"
) as check_output,
):
recorder._log_uncommitted_code()

run.assert_called_once_with(
["git", "rev-parse", "--is-inside-work-tree"],
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
check=True,
)
check_output.assert_not_called()
recorder.client.log_text.assert_not_called()


def test_log_uncommitted_code_uses_git_without_shell() -> None:
recorder = _make_recorder()
in_worktree = subprocess.CompletedProcess(["git"], 0, stdout=b"true\n")

with (
mock.patch(
"qlib.workflow.recorder.subprocess.run",
return_value=in_worktree,
),
mock.patch(
"qlib.workflow.recorder.subprocess.check_output",
side_effect=[b"diff", b"status", b"cached"],
) as check_output,
):
recorder._log_uncommitted_code()

assert [call.args[0] for call in check_output.call_args_list] == [
["git", "diff"],
["git", "status"],
["git", "diff", "--cached"],
]
assert all(
call.kwargs == {"stderr": subprocess.DEVNULL}
for call in check_output.call_args_list
)
assert recorder.client.log_text.call_count == 3