diff --git a/huggingface_sb3/naming_schemes.py b/huggingface_sb3/naming_schemes.py index 5d15e57..fcdef8d 100644 --- a/huggingface_sb3/naming_schemes.py +++ b/huggingface_sb3/naming_schemes.py @@ -20,7 +20,17 @@ class EnvironmentName(str): """ def __new__(cls, gym_id: str): - normalized_name = super().__new__(cls, gym_id.replace("/", "-")) + normalized_str = gym_id.replace("/", "-") + if ":" in normalized_str: + split_by_colon = normalized_str.split(":") + if len(split_by_colon) == 2: + # split by colon and take the first part + normalized_str = split_by_colon[1] + else: + raise ValueError( + f"Environment name {gym_id} contains more than one colon!" + ) + normalized_name = super().__new__(cls, normalized_str) normalized_name._gym_id = gym_id return normalized_name diff --git a/setup.py b/setup.py index 28b685c..88aed0f 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ "pyyaml~=6.0", "wasabi", "numpy", - "cloudpickle>=1.6" + "cloudpickle>=1.6", ] extras = {} diff --git a/tests/test_naming_scheme.py b/tests/test_naming_scheme.py index dbac601..d30b7b0 100644 --- a/tests/test_naming_scheme.py +++ b/tests/test_naming_scheme.py @@ -3,7 +3,7 @@ from huggingface_sb3 import EnvironmentName, ModelName, ModelRepoId -@pytest.fixture(params=["seals/Walker2d-v0", "LunarLander-v2"]) +@pytest.fixture(params=["seals/Walker2d-v0", "LunarLander-v2", "seals:seals/Walker2d-v0"]) def env_id(request) -> str: return request.param @@ -23,12 +23,29 @@ def repo_id(model_name: ModelName) -> ModelRepoId: return ModelRepoId("orga", model_name) -def test_that_slashes_are_removed(env_id: str, env_name: EnvironmentName, model_name: ModelName, repo_id: ModelRepoId): +def test_that_slashes_are_removed(env_name: EnvironmentName, model_name: ModelName, repo_id: ModelRepoId): assert "/" not in env_name assert "/" not in model_name assert "/" not in model_name.filename assert repo_id.count("/") == 1 # note: repo id has exactly one slash separating org from repo name +def test_that_colon_is_removed(env_name: EnvironmentName, model_name: ModelName, repo_id: ModelRepoId): + assert ":" not in env_name + assert ":" not in model_name + assert ":" not in model_name.filename + assert ":" not in repo_id + + +def test_that_package_before_colon_is_removed(): + env_name = EnvironmentName("seals:seals/Walker2d-v0") + assert env_name == "seals-Walker2d-v0" + + +def test_that_double_colon_is_rejected(): + with pytest.raises(ValueError): + EnvironmentName("seals:seals:Walker2d-v0") + + def test_that_gym_id_is_preserved(env_id: str, env_name: EnvironmentName): assert env_name.gym_id == env_id