diff --git a/src/pydantify_common/__init__.py b/src/pydantify_common/__init__.py index c4d3fe2..f5b0dbc 100644 --- a/src/pydantify_common/__init__.py +++ b/src/pydantify_common/__init__.py @@ -1,3 +1,4 @@ -from .model import PydantifyModel +from .model import PydantifyModel, XMLPydantifyModel +from .helper import model_dump_xml_string, NETCONF_BASE_NS -__all__ = ["PydantifyModel"] +__all__ = ["PydantifyModel", "XMLPydantifyModel", "model_dump_xml_string", "NETCONF_BASE_NS"] diff --git a/src/pydantify_common/helper.py b/src/pydantify_common/helper.py index 689bbfc..2976dc4 100644 --- a/src/pydantify_common/helper.py +++ b/src/pydantify_common/helper.py @@ -1,14 +1,31 @@ from lxml import etree from .model import XMLPydantifyModel -from typing import Any + +NETCONF_BASE_NS = "urn:ietf:params:xml:ns:netconf:base:1.0" def model_dump_xml_string( model: XMLPydantifyModel, *, pretty_print: bool = False, data_root: bool = False ) -> str: - data = model.model_dump_xml() + """ + Serialize model to XML string. + + Args: + model: The XMLPydantifyModel to serialize + pretty_print: If True, format with indentation + data_root: If True, wrap in root + + Returns: + XML string representation + """ + xml_element = model.model_dump_xml() + if data_root: # Add `` root element - pass + data_elem = etree.Element( + f"{{{NETCONF_BASE_NS}}}data", nsmap={None: NETCONF_BASE_NS} + ) + data_elem.append(xml_element) + xml_element = data_elem - return etree.tostring(data, encoding=str, pretty_print=pretty_print) + return etree.tostring(xml_element, encoding=str, pretty_print=pretty_print) diff --git a/src/pydantify_common/model.py b/src/pydantify_common/model.py index 179cfb3..8f7f3c8 100644 --- a/src/pydantify_common/model.py +++ b/src/pydantify_common/model.py @@ -12,5 +12,152 @@ class XMLPydantifyModel(PydantifyModel): namespace: ClassVar[str] prefix: ClassVar[str] - def model_dump_xml(self) -> etree.Element: - pass + def model_dump_xml(self, parent_namespace: str | None = None) -> etree._Element: + """ + Serialize model to lxml Element with namespace support. + + Args: + parent_namespace: The namespace of the parent element. Used to determine + whether to declare xmlns attribute on this element. + + Returns: + lxml Element representing this model. + """ + # Check if this is a wrapper model (single XMLPydantifyModel field) + # If so, delegate to that child model + wrapper_child = self._get_wrapper_child() + if wrapper_child is not None: + return wrapper_child.model_dump_xml(parent_namespace=parent_namespace) + + # 1. Get local name from first field's alias + local_name = self._get_xml_local_name() + + # 2. Build qualified tag using Clark notation + tag = f"{{{self.namespace}}}{local_name}" + + # 3. Only add nsmap if namespace differs from parent + nsmap = None + if self.namespace != parent_namespace: + nsmap = {None: self.namespace} # Default namespace declaration + + # 4. Create element + element = etree.Element(tag, nsmap=nsmap) + + # 5. Iterate through model fields and add children + self._serialize_fields_to_element(element) + + return element + + def _get_wrapper_child(self) -> "XMLPydantifyModel | None": + """ + Check if this model is a "wrapper" with a single XMLPydantifyModel field. + + Returns: + The child XMLPydantifyModel if this is a wrapper, otherwise None. + """ + fields = self.__class__.model_fields + if len(fields) != 1: + return None + + field_name = next(iter(fields.keys())) + value = getattr(self, field_name) + + if isinstance(value, XMLPydantifyModel): + return value + + return None + + def _serialize_fields_to_element(self, element: etree._Element) -> None: + """ + Serialize all model fields as child elements. + + Args: + element: The parent element to add children to. + """ + for field_name, field_info in self.__class__.model_fields.items(): + value = getattr(self, field_name) + if value is None: + continue + + # Get local name from alias (e.g., "config:name" → "name") + alias = field_info.alias or field_name + child_local_name = alias.split(":")[-1] if ":" in alias else alias + + if isinstance(value, XMLPydantifyModel): + # Nested model - create element from field alias and namespace from child + child_elem = self._create_child_model_element( + child_local_name, value, self.namespace + ) + element.append(child_elem) + elif isinstance(value, list): + for item in value: + if isinstance(item, XMLPydantifyModel): + child_elem = self._create_child_model_element( + child_local_name, item, self.namespace + ) + element.append(child_elem) + else: + # Primitive list item + child_tag = f"{{{self.namespace}}}{child_local_name}" + child_elem = etree.SubElement(element, child_tag) + child_elem.text = str(item) + else: + # Primitive value + child_tag = f"{{{self.namespace}}}{child_local_name}" + child_elem = etree.SubElement(element, child_tag) + child_elem.text = str(value) + + def _create_child_model_element( + self, + local_name: str, + child_model: "XMLPydantifyModel", + parent_namespace: str + ) -> etree._Element: + """ + Create an element for a child XMLPydantifyModel using the parent's field alias + for the element name. + + Args: + local_name: The local name for the element (from parent's field alias). + child_model: The child XMLPydantifyModel to serialize. + parent_namespace: The namespace of the parent element. + + Returns: + lxml Element with the child's content. + """ + # Use child's namespace for the tag + tag = f"{{{child_model.namespace}}}{local_name}" + + # Only add nsmap if child's namespace differs from parent's + nsmap = None + if child_model.namespace != parent_namespace: + nsmap = {None: child_model.namespace} + + # Create element + element = etree.Element(tag, nsmap=nsmap) + + # Let the child model fill in its children + child_model._serialize_fields_to_element(element) + + return element + + def _get_xml_local_name(self) -> str: + """ + Extract the local name for this model's XML element. + + Uses the first field's alias to extract the local name. + Falls back to the lowercase class name if no fields exist. + """ + # Try to get local name from first field's alias + if self.__class__.model_fields: + first_field_info = next(iter(self.__class__.model_fields.values())) + alias = first_field_info.alias + if alias: + # Extract prefix from "prefix:localname" pattern + # e.g., "configuration:devicename" → "configuration" + parts = alias.split(":") + if len(parts) >= 1: + return parts[0] + + # Fallback to lowercase class name + return self.__class__.__name__.lower() diff --git a/tests/examples/with_augment/model.py b/tests/examples/with_augment/model.py index 5e30dde..a2f64bd 100644 --- a/tests/examples/with_augment/model.py +++ b/tests/examples/with_augment/model.py @@ -2,7 +2,7 @@ from typing import Annotated, List, ClassVar -from pydantic import ConfigDict, Field, RootModel +from pydantic import ConfigDict, Field from pydantify_common.model import XMLPydantifyModel diff --git a/tests/examples/with_import_uses/model.py b/tests/examples/with_import_uses/model.py index cd16020..04da6e3 100644 --- a/tests/examples/with_import_uses/model.py +++ b/tests/examples/with_import_uses/model.py @@ -2,7 +2,7 @@ from typing import Annotated, List, ClassVar -from pydantic import ConfigDict, Field, RootModel +from pydantic import ConfigDict, Field from pydantify_common.model import XMLPydantifyModel diff --git a/tests/test_examples.py b/tests/test_examples.py index d3e9a75..692e9eb 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,5 +1,4 @@ from pathlib import Path -from pydantify_common.model import XMLPydantifyModel from pydantify_common.helper import model_dump_xml_string from lxml import etree diff --git a/tests/test_model_dump_xml.py b/tests/test_model_dump_xml.py new file mode 100644 index 0000000..a4c5d07 --- /dev/null +++ b/tests/test_model_dump_xml.py @@ -0,0 +1,324 @@ +"""Unit tests for XMLPydantifyModel.model_dump_xml() and model_dump_xml_string().""" + +from typing import Annotated, ClassVar, List + +from lxml import etree +from pydantic import ConfigDict, Field + +from pydantify_common import XMLPydantifyModel, model_dump_xml_string, NETCONF_BASE_NS + + +class TestLocalNameExtraction: + """Test extracting local name from aliases.""" + + def test_local_name_from_alias_with_prefix(self): + """Local name should be extracted from prefix:name alias pattern.""" + + class TestModel(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/ns" + prefix: ClassVar[str] = "test" + name: Annotated[str, Field(alias="test:name")] + + model = TestModel(name="value") + element = model.model_dump_xml() + + assert element.tag == "{http://example.com/ns}test" + + def test_fallback_to_class_name(self): + """Should fallback to lowercase class name when no alias.""" + + class SimpleModel(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/ns" + prefix: ClassVar[str] = "simple" + value: str + + model = SimpleModel(value="test") + element = model.model_dump_xml() + + assert element.tag == "{http://example.com/ns}simplemodel" + + +class TestSimpleModel: + """Test single-level model serialization.""" + + def test_simple_string_field(self): + """String fields should serialize as text content.""" + + class Device(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/device" + prefix: ClassVar[str] = "device" + name: Annotated[str, Field(alias="device:name")] + + model = Device(name="router1") + element = model.model_dump_xml() + + assert element.tag == "{http://example.com/device}device" + assert element[0].tag == "{http://example.com/device}name" + assert element[0].text == "router1" + + def test_multiple_fields(self): + """Multiple fields should be serialized in order.""" + + class Interface(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/if" + prefix: ClassVar[str] = "if" + name: Annotated[str, Field(alias="if:name")] + ip: Annotated[str, Field(alias="if:ip")] + + model = Interface(name="eth0", ip="192.168.1.1") + element = model.model_dump_xml() + + assert len(element) == 2 + assert element[0].tag == "{http://example.com/if}name" + assert element[0].text == "eth0" + assert element[1].tag == "{http://example.com/if}ip" + assert element[1].text == "192.168.1.1" + + def test_namespace_declaration(self): + """Namespace should be declared in nsmap.""" + + class TestModel(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/test" + prefix: ClassVar[str] = "test" + value: Annotated[str, Field(alias="test:value")] + + model = TestModel(value="x") + element = model.model_dump_xml() + + assert element.nsmap.get(None) == "http://example.com/test" + + +class TestNestedModels: + """Test nested model serialization with same/different namespaces.""" + + def test_nested_different_namespace(self): + """Nested model with different namespace should declare xmlns.""" + + class Inner(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/inner" + prefix: ClassVar[str] = "inner" + value: Annotated[str, Field(alias="inner:value")] + + class Outer(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/outer" + prefix: ClassVar[str] = "outer" + name: Annotated[str, Field(alias="outer:name")] + inner: Annotated[Inner, Field(alias="outer:inner")] + + model = Outer(name="test", inner=Inner(value="nested")) + element = model.model_dump_xml() + + # Find the inner element + inner_elem = element.find("{http://example.com/inner}inner") + assert inner_elem is not None + assert inner_elem.nsmap.get(None) == "http://example.com/inner" + + def test_nested_same_namespace(self): + """Nested model with same namespace should NOT re-declare xmlns.""" + + class Inner(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/same" + prefix: ClassVar[str] = "same" + value: Annotated[str, Field(alias="same:value")] + + class Outer(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/same" + prefix: ClassVar[str] = "same" + name: Annotated[str, Field(alias="same:name")] + inner: Annotated[Inner, Field(alias="same:inner")] + + model = Outer(name="test", inner=Inner(value="nested")) + element = model.model_dump_xml() + + # Serialize to string and check that xmlns is only declared once + xml_string = etree.tostring(element, encoding="unicode") + + # Should contain exactly one xmlns declaration (on root element) + assert xml_string.count('xmlns="http://example.com/same"') == 1 + + +class TestListFields: + """Test list of primitives and list of models.""" + + def test_list_of_models(self): + """List of XMLPydantifyModel should create multiple elements.""" + + class Item(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/ns" + prefix: ClassVar[str] = "ns" + name: Annotated[str, Field(alias="ns:name")] + + class Container(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/ns" + prefix: ClassVar[str] = "ns" + items: Annotated[List[Item], Field(alias="ns:items")] + + model = Container(items=[Item(name="a"), Item(name="b"), Item(name="c")]) + element = model.model_dump_xml() + + items = element.findall("{http://example.com/ns}items") + assert len(items) == 3 + assert items[0][0].text == "a" + assert items[1][0].text == "b" + assert items[2][0].text == "c" + + def test_list_of_primitives(self): + """List of primitives should create multiple text elements.""" + + class Tags(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/ns" + prefix: ClassVar[str] = "ns" + tag: Annotated[List[str], Field(alias="ns:tag")] + + model = Tags(tag=["red", "green", "blue"]) + element = model.model_dump_xml() + + tags = element.findall("{http://example.com/ns}tag") + assert len(tags) == 3 + assert [t.text for t in tags] == ["red", "green", "blue"] + + +class TestOptionalFields: + """Test None values are skipped.""" + + def test_none_values_skipped(self): + """Fields with None value should not appear in XML.""" + + class Device(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/ns" + prefix: ClassVar[str] = "ns" + name: Annotated[str, Field(alias="ns:name")] + description: Annotated[str | None, Field(alias="ns:description")] = None + + model = Device(name="router1") + element = model.model_dump_xml() + + assert len(element) == 1 + assert element[0].tag == "{http://example.com/ns}name" + + def test_optional_with_value(self): + """Optional field with value should appear in XML.""" + + class Device(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/ns" + prefix: ClassVar[str] = "ns" + name: Annotated[str, Field(alias="ns:name")] + description: Annotated[str | None, Field(alias="ns:description")] = None + + model = Device(name="router1", description="Main router") + element = model.model_dump_xml() + + assert len(element) == 2 + assert element[1].tag == "{http://example.com/ns}description" + assert element[1].text == "Main router" + + +class TestDataRootWrapper: + """Test NETCONF data root wrapping.""" + + def test_data_root_wrapping(self): + """data_root=True should wrap in NETCONF data element.""" + + class Config(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/ns" + prefix: ClassVar[str] = "ns" + name: Annotated[str, Field(alias="ns:name")] + + model = Config(name="test") + xml_string = model_dump_xml_string(model, data_root=True) + + tree = etree.fromstring(xml_string.encode()) + assert tree.tag == f"{{{NETCONF_BASE_NS}}}data" + assert tree.nsmap.get(None) == NETCONF_BASE_NS + assert len(tree) == 1 + assert tree[0].tag == "{http://example.com/ns}ns" + + def test_without_data_root(self): + """data_root=False should not wrap.""" + + class Config(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/ns" + prefix: ClassVar[str] = "ns" + name: Annotated[str, Field(alias="ns:name")] + + model = Config(name="test") + xml_string = model_dump_xml_string(model, data_root=False) + + tree = etree.fromstring(xml_string.encode()) + assert tree.tag == "{http://example.com/ns}ns" + + +class TestWrapperModelPassthrough: + """Test wrapper model with single XMLPydantifyModel field.""" + + def test_wrapper_delegates_to_child(self): + """Wrapper model should delegate serialization to single child.""" + + class Inner(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/ns" + prefix: ClassVar[str] = "ns" + value: Annotated[str, Field(alias="ns:value")] + + class Wrapper(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/ns" + prefix: ClassVar[str] = "ns" + inner: Annotated[Inner, Field(alias="ns:inner")] + + model = Wrapper(inner=Inner(value="test")) + element = model.model_dump_xml() + + # Should produce (from Inner, not from Wrapper) + # with test as child + assert element.tag == "{http://example.com/ns}ns" + assert element[0].tag == "{http://example.com/ns}value" + assert element[0].text == "test" + + +class TestPrettyPrint: + """Test pretty printing output.""" + + def test_pretty_print_true(self): + """pretty_print=True should include newlines and indentation.""" + + class Config(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/ns" + prefix: ClassVar[str] = "ns" + name: Annotated[str, Field(alias="ns:name")] + + model = Config(name="test") + xml_string = model_dump_xml_string(model, pretty_print=True) + + assert "\n" in xml_string + + def test_pretty_print_false(self): + """pretty_print=False should be compact.""" + + class Config(XMLPydantifyModel): + model_config = ConfigDict(populate_by_name=True) + namespace: ClassVar[str] = "http://example.com/ns" + prefix: ClassVar[str] = "ns" + name: Annotated[str, Field(alias="ns:name")] + + model = Config(name="test") + xml_string = model_dump_xml_string(model, pretty_print=False) + + assert "\n" not in xml_string