diff --git a/mujoco_urdf_loader/__init__.py b/mujoco_urdf_loader/__init__.py index 32f88a5..7e4db09 100644 --- a/mujoco_urdf_loader/__init__.py +++ b/mujoco_urdf_loader/__init__.py @@ -1,5 +1,6 @@ from .loader import ( ControlMode, + EqualityConstraintCfg, FrameQuatSensorCfg, GyroSensorCfg, URDFtoMuJoCoLoader, diff --git a/mujoco_urdf_loader/loader.py b/mujoco_urdf_loader/loader.py index 85b8f27..c924128 100644 --- a/mujoco_urdf_loader/loader.py +++ b/mujoco_urdf_loader/loader.py @@ -16,6 +16,7 @@ add_framequat_sensor, add_gyro_sensor, add_camera_to_site, + add_equality_constraints_for_sites, convert_hinge_to_ball_joints, ) from mujoco_urdf_loader.urdf_fcn import ( @@ -51,6 +52,13 @@ class CameraCfg: fovy: float name: str +@dataclasses.dataclass +class EqualityConstraintCfg: + """Configuration for a connect/weld equality constraint between two sites.""" + site1: str + site2: str + constraint_type: str = "connect" + @dataclasses.dataclass class URDFtoMuJoCoLoaderCfg: controlled_joints: List[str] @@ -62,6 +70,7 @@ class URDFtoMuJoCoLoaderCfg: framequat_sensors_cfg: Union[None, List[Union[FrameQuatSensorCfg, Dict[str, Any]]]] = None gyro_sensors_cfg: Union[None, List[Union[GyroSensorCfg, Dict[str, Any]]]] = None cameras_cfg: Union[None, List[Union[CameraCfg, Dict[str, Any]]]] = None + equality_constraints_cfg: Union[None, List[Union[EqualityConstraintCfg, Dict[str, Any]]]] = None ball_joint_damping: float = 0.0 ball_joint_armature: float = 0.0 ball_joint_frictionloss: float = 0.0 @@ -168,6 +177,7 @@ def load_urdf(urdf_path: str, mesh_path: str, cfg: URDFtoMuJoCoLoaderCfg): all_missing_joints_as_sites=cfg.all_missing_joints_as_sites, framequat_sensors_cfg=cfg.framequat_sensors_cfg, gyro_sensors_cfg=cfg.gyro_sensors_cfg, + equality_constraints_cfg=cfg.equality_constraints_cfg, ) else: mjcf_cfg = cfg @@ -189,6 +199,7 @@ def load_urdf(urdf_path: str, mesh_path: str, cfg: URDFtoMuJoCoLoaderCfg): loader.add_framequat_sensors(cfg.framequat_sensors_cfg) loader.add_gyro_sensors(cfg.gyro_sensors_cfg) loader.add_cameras(cfg.cameras_cfg) + loader.add_equality_constraints(cfg.equality_constraints_cfg) return loader @staticmethod @@ -306,6 +317,66 @@ def add_cameras( fovy=normalized_cfg.fovy, ) + @staticmethod + def _normalize_equality_constraint_cfg( + eq_cfg: Union[EqualityConstraintCfg, Dict[str, Any]], + ) -> EqualityConstraintCfg: + if isinstance(eq_cfg, EqualityConstraintCfg): + return eq_cfg + + if not isinstance(eq_cfg, dict): + raise TypeError( + "Each equality constraint configuration must be an " + "EqualityConstraintCfg or a dict with keys site1 and site2." + ) + + site1 = eq_cfg.get("site1") + site2 = eq_cfg.get("site2") + constraint_type = eq_cfg.get("constraint_type", "connect") + + if site1 is None or site2 is None: + raise ValueError( + "Each equality constraint configuration requires site1 and site2." + ) + + return EqualityConstraintCfg( + site1=site1, site2=site2, constraint_type=constraint_type, + ) + + def add_equality_constraints( + self, + equality_constraints_cfg: Union[ + None, List[Union[EqualityConstraintCfg, Dict[str, Any]]] + ] = None, + ): + """Add equality constraints (connect/weld) to the MJCF model. + + Uses the existing ``add_equality_constraints_for_sites`` helper to + create ```` or ```` elements inside ````. + + Args: + equality_constraints_cfg: List of ``EqualityConstraintCfg`` + dataclasses or dicts with keys ``site1``, ``site2``, and + optionally ``constraint_type`` (default ``"connect"``). + If ``None``, no constraints are added. + """ + if equality_constraints_cfg is None: + return + + # Group by constraint_type so we can call the helper once per type + by_type: Dict[str, List[tuple]] = {} + for cfg in equality_constraints_cfg: + normalized = self._normalize_equality_constraint_cfg(cfg) + ctype = normalized.constraint_type + by_type.setdefault(ctype, []).append( + (normalized.site1, normalized.site2) + ) + + for constraint_type, site_pairs in by_type.items(): + add_equality_constraints_for_sites( + self.mjcf, site_pairs, constraint_type=constraint_type, + ) + @staticmethod def get_missing_joint_sites( robot_urdf: ET.Element, diff --git a/mujoco_urdf_loader/mjcf_fcn.py b/mujoco_urdf_loader/mjcf_fcn.py index 2d8bd48..4fe6849 100644 --- a/mujoco_urdf_loader/mjcf_fcn.py +++ b/mujoco_urdf_loader/mjcf_fcn.py @@ -87,6 +87,7 @@ def add_position_actuator( return mjcf + def add_torque_actuator( mjcf: ET.Element, joint: str, @@ -281,12 +282,12 @@ def add_framequat_sensor(mjcf: ET.Element, objname: str, objtype: str = 'site', def add_joint_eq( - mjcf: ET.Element, - joint1: str, - joint2: str, - name: str = None, - multiplier: float = 1.0, - offset: float = 0.0 + mjcf: ET.Element, + joint1: str, + joint2: str, + name: str = None, + multiplier: float = 1.0, + offset: float = 0.0, ) -> ET.Element: """Add a joint equality constraint between two joints. @@ -575,3 +576,66 @@ def convert_hinge_to_ball_joints( joint_elem.set("armature", str(armature)) joint_elem.set("frictionloss", str(frictionloss)) return mjcf + +def add_equality_constraints_for_sites( + mjcf: ET.Element, site_pairs: List[tuple], constraint_type: str = "connect" +) -> ET.Element: + """ + Add equality constraints between pairs of sites in MJCF. + + Args: + mjcf (ET.Element): The MJCF file as ElementTree. + site_pairs (List[tuple]): List of tuples with (site1_name, site2_name) to connect. + constraint_type (str): Type of constraint - "connect" or "weld" (default: "connect"). + + Returns: + ET.Element: The modified MJCF file. + """ + # Find or create the equality element + equality = mjcf.find("equality") + if equality is None: + equality = ET.SubElement(mjcf, "equality") + + for site1, site2 in site_pairs: + # Verify both sites exist + site1_elem = mjcf.find(f".//site[@name='{site1}']") + site2_elem = mjcf.find(f".//site[@name='{site2}']") + + if site1_elem is None: + raise ValueError(f"Site {site1} not found in MJCF") + if site2_elem is None: + raise ValueError(f"Site {site2} not found in MJCF") + + # Create the equality constraint + if constraint_type == "connect": + # Connect constraint directly references sites (no anchor needed for sites) + constraint = ET.SubElement(equality, "connect") + constraint.set("site1", site1) + constraint.set("site2", site2) + elif constraint_type == "weld": + # Weld constraint references bodies + # Find parent bodies of the sites + body1 = None + body2 = None + for body in mjcf.findall(".//body"): + if body.find(f".//site[@name='{site1}']") is not None: + body1 = body.attrib.get("name") + if body.find(f".//site[@name='{site2}']") is not None: + body2 = body.attrib.get("name") + + if body1 is None or body2 is None: + raise ValueError( + f"Could not find parent bodies for sites {site1} and {site2}" + ) + + constraint = ET.SubElement(equality, "weld") + constraint.set("body1", body1) + constraint.set("body2", body2) + else: + raise ValueError(f"Unknown constraint type: {constraint_type}") + + print( + f"Created {constraint_type} equality constraint between {site1} and {site2}" + ) + + return mjcf diff --git a/tests/test_loader_equality_constraints_cfg.py b/tests/test_loader_equality_constraints_cfg.py new file mode 100644 index 0000000..c21daec --- /dev/null +++ b/tests/test_loader_equality_constraints_cfg.py @@ -0,0 +1,117 @@ +import xml.etree.ElementTree as ET + +import pytest + +from mujoco_urdf_loader.loader import ( + EqualityConstraintCfg, + URDFtoMuJoCoLoader, + URDFtoMuJoCoLoaderCfg, +) + + +def _make_empty_mjcf() -> ET.Element: + return ET.fromstring( + """ + + + + + + + + + + + """ + ) + + +def test_add_equality_constraints_none_keeps_model_unchanged(): + loader = URDFtoMuJoCoLoader( + _make_empty_mjcf(), URDFtoMuJoCoLoaderCfg(controlled_joints=[]) + ) + + loader.add_equality_constraints(None) + + assert loader.mjcf.find(".//equality") is None + + +def test_add_equality_constraints_accepts_list_of_dataclasses(): + loader = URDFtoMuJoCoLoader( + _make_empty_mjcf(), URDFtoMuJoCoLoaderCfg(controlled_joints=[]) + ) + + loader.add_equality_constraints( + [EqualityConstraintCfg(site1="site_a", site2="site_b")] + ) + + connect = loader.mjcf.find(".//equality/connect") + assert connect is not None + assert connect.attrib["site1"] == "site_a" + assert connect.attrib["site2"] == "site_b" + + +def test_add_equality_constraints_accepts_dict(): + loader = URDFtoMuJoCoLoader( + _make_empty_mjcf(), URDFtoMuJoCoLoaderCfg(controlled_joints=[]) + ) + + loader.add_equality_constraints( + [{"site1": "site_a", "site2": "site_b"}] + ) + + connect = loader.mjcf.find(".//equality/connect") + assert connect is not None + assert connect.attrib["site1"] == "site_a" + assert connect.attrib["site2"] == "site_b" + + +def test_add_equality_constraints_multiple(): + loader = URDFtoMuJoCoLoader( + _make_empty_mjcf(), URDFtoMuJoCoLoaderCfg(controlled_joints=[]) + ) + + loader.add_equality_constraints( + [ + EqualityConstraintCfg(site1="site_a", site2="site_b"), + EqualityConstraintCfg(site1="site_b", site2="site_a"), + ] + ) + + connects = loader.mjcf.findall(".//equality/connect") + assert len(connects) == 2 + assert connects[0].attrib["site1"] == "site_a" + assert connects[1].attrib["site1"] == "site_b" + + +def test_add_equality_constraints_weld_type(): + loader = URDFtoMuJoCoLoader( + _make_empty_mjcf(), URDFtoMuJoCoLoaderCfg(controlled_joints=[]) + ) + + loader.add_equality_constraints( + [EqualityConstraintCfg(site1="site_a", site2="site_b", constraint_type="weld")] + ) + + weld = loader.mjcf.find(".//equality/weld") + assert weld is not None + assert weld.attrib["body1"] == "body_a" + assert weld.attrib["body2"] == "body_b" + + +def test_add_equality_constraints_raises_on_missing_fields(): + loader = URDFtoMuJoCoLoader( + _make_empty_mjcf(), URDFtoMuJoCoLoaderCfg(controlled_joints=[]) + ) + + with pytest.raises(ValueError): + loader.add_equality_constraints([{"site1": "site_a"}]) + + +def test_add_equality_constraints_raises_on_invalid_type(): + loader = URDFtoMuJoCoLoader( + _make_empty_mjcf(), URDFtoMuJoCoLoaderCfg(controlled_joints=[]) + ) + + with pytest.raises(TypeError): + loader.add_equality_constraints(["not_a_valid_config"]) diff --git a/tests/test_mjcf_fcn.py b/tests/test_mjcf_fcn.py index c43815d..31b86f3 100644 --- a/tests/test_mjcf_fcn.py +++ b/tests/test_mjcf_fcn.py @@ -4,6 +4,7 @@ from mujoco_urdf_loader.mjcf_fcn import ( add_camera, + add_equality_constraints_for_sites, add_joint_eq, add_joint_pos_sensor, add_joint_vel_sensor, @@ -290,3 +291,87 @@ def test_convert_hinge_to_ball_empty_map(): # Original joints should be unchanged assert mjcf.find(".//joint[@name='spherical_rev_my_rod_x']") is not None assert mjcf.find(".//joint[@name='regular_hinge']") is not None + + +# --------------------------------------------------------------------------- +# Tests for add_equality_constraints_for_sites +# --------------------------------------------------------------------------- + +def _make_mjcf_with_sites(): + """Build a minimal MJCF with sites for equality constraint tests.""" + mjcf = ET.Element("mujoco") + worldbody = ET.SubElement(mjcf, "worldbody") + body_a = ET.SubElement(worldbody, "body") + body_a.set("name", "body_a") + site_a = ET.SubElement(body_a, "site") + site_a.set("name", "site_a") + site_a.set("pos", "0 0 0") + + body_b = ET.SubElement(worldbody, "body") + body_b.set("name", "body_b") + site_b = ET.SubElement(body_b, "site") + site_b.set("name", "site_b") + site_b.set("pos", "0.1 0.2 0.3") + + site_c = ET.SubElement(body_b, "site") + site_c.set("name", "site_c") + site_c.set("pos", "0.4 0.5 0.6") + + return mjcf + + +def test_add_equality_constraints_connect(): + mjcf = _make_mjcf_with_sites() + + mjcf = add_equality_constraints_for_sites(mjcf, [("site_a", "site_b")]) + + assert len(mjcf.findall(".//equality")) == 1 + connects = mjcf.findall(".//equality/connect") + assert len(connects) == 1 + assert connects[0].attrib["site1"] == "site_a" + assert connects[0].attrib["site2"] == "site_b" + + +def test_add_equality_constraints_multiple(): + mjcf = _make_mjcf_with_sites() + + mjcf = add_equality_constraints_for_sites( + mjcf, [("site_a", "site_b"), ("site_b", "site_c")] + ) + + # Only one element should exist + assert len(mjcf.findall(".//equality")) == 1 + connects = mjcf.findall(".//equality/connect") + assert len(connects) == 2 + assert connects[0].attrib["site1"] == "site_a" + assert connects[1].attrib["site1"] == "site_b" + assert connects[1].attrib["site2"] == "site_c" + + +def test_add_equality_constraints_weld(): + mjcf = _make_mjcf_with_sites() + + mjcf = add_equality_constraints_for_sites( + mjcf, [("site_a", "site_b")], constraint_type="weld" + ) + + welds = mjcf.findall(".//equality/weld") + assert len(welds) == 1 + assert welds[0].attrib["body1"] == "body_a" + assert welds[0].attrib["body2"] == "body_b" + + +def test_add_equality_constraints_missing_site(): + mjcf = _make_mjcf_with_sites() + + with pytest.raises(ValueError, match="nonexistent"): + add_equality_constraints_for_sites(mjcf, [("site_a", "nonexistent")]) + + +def test_add_equality_constraints_unknown_type(): + mjcf = _make_mjcf_with_sites() + + with pytest.raises(ValueError, match="Unknown constraint type"): + add_equality_constraints_for_sites( + mjcf, [("site_a", "site_b")], constraint_type="invalid" + )