diff --git a/pyproject.toml b/pyproject.toml index 01711eed5..bb3803cd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,11 +86,10 @@ dev = [ test = [ "inline-snapshot>=0.31.1", "msgspec>=0.18", - "pytest>=6.1", "pytest>=8.3.4", - "pytest-cov>=2.12.1", "pytest-cov>=5", "pytest-mock>=3.14", + "pytest-timeout>=2.4", "pytest-xdist>=3.3.1", "time-machine>=3.1", "watchfiles>=1.1", diff --git a/src/datamodel_code_generator/model/base.py b/src/datamodel_code_generator/model/base.py index 7110f58bc..75bb109a7 100644 --- a/src/datamodel_code_generator/model/base.py +++ b/src/datamodel_code_generator/model/base.py @@ -223,7 +223,7 @@ def _build_union_type_hint(self) -> str | None: """Build Union[] type hint from data_type.data_types if forward reference requires it.""" if not (self._use_union_operator != self.data_type.use_union_operator and self.data_type.is_union): return None - parts = [dt.type_hint for dt in self.data_type.data_types if dt.type_hint] + parts = dict.fromkeys(dt.type_hint for dt in self.data_type.data_types if dt.type_hint).keys() if len(parts) > 1: return f"Union[{', '.join(parts)}]" return None # pragma: no cover @@ -232,7 +232,7 @@ def _build_base_union_type_hint(self) -> str | None: # pragma: no cover """Build Union[] base type hint from data_type.data_types if forward reference requires it.""" if not (self._use_union_operator != self.data_type.use_union_operator and self.data_type.is_union): return None - parts = [dt.base_type_hint for dt in self.data_type.data_types if dt.base_type_hint] + parts = dict.fromkeys(dt.base_type_hint for dt in self.data_type.data_types if dt.base_type_hint).keys() if len(parts) > 1: return f"Union[{', '.join(parts)}]" return None diff --git a/src/datamodel_code_generator/parser/base.py b/src/datamodel_code_generator/parser/base.py index ecbad8e72..363f9e060 100644 --- a/src/datamodel_code_generator/parser/base.py +++ b/src/datamodel_code_generator/parser/base.py @@ -493,7 +493,7 @@ def add_model_path_to_list( return paths -def sort_data_models( # noqa: PLR0912, PLR0915 +def sort_data_models( # noqa: PLR0912, PLR0914, PLR0915 unsorted_data_models: list[DataModel], sorted_data_models: SortedDataModels | None = None, require_update_action_models: list[str] | None = None, @@ -502,8 +502,10 @@ def sort_data_models( # noqa: PLR0912, PLR0915 """Sort data models by dependency order for correct forward references.""" if sorted_data_models is None: sorted_data_models = OrderedDict() + if require_update_action_models is None: require_update_action_models = [] + sorted_model_count: int = len(sorted_data_models) unresolved_references: list[DataModel] = [] @@ -521,6 +523,7 @@ def sort_data_models( # noqa: PLR0912, PLR0915 add_model_path_to_list(require_update_action_models, model) else: unresolved_references.append(model) + if unresolved_references: if sorted_model_count != len(sorted_data_models) and recursion_count: try: @@ -534,6 +537,7 @@ def sort_data_models( # noqa: PLR0912, PLR0915 pass # sort on base_class dependency + seen_orderings: set[tuple[str, ...]] = set() while True: ordered_models: list[tuple[int, DataModel]] = [] # Build lookup dict for O(1) index access instead of O(n) list.index() @@ -552,6 +556,7 @@ def sort_data_models( # noqa: PLR0912, PLR0915 for b in model.base_classes if b.reference and b.reference.path in path_to_index ] + if indexes: ordered_models.append(( max(indexes), @@ -562,9 +567,19 @@ def sort_data_models( # noqa: PLR0912, PLR0915 -1, model, )) + sorted_unresolved_models = [m[1] for m in sorted(ordered_models, key=operator.itemgetter(0))] if sorted_unresolved_models == unresolved_references: break + + sig = tuple(m.path for m in sorted_unresolved_models) + if sig in seen_orderings: + # Base-class dependency order has no fixed point (e.g. cyclic inheritance with + # discriminators). Further iterations only permute the list; use stable order. + unresolved_references.sort(key=lambda m: m.path) + break + + seen_orderings.add(sig) unresolved_references = sorted_unresolved_models # circular reference @@ -578,16 +593,19 @@ def sort_data_models( # noqa: PLR0912, PLR0915 if update_action_parent: add_model_path_to_list(require_update_action_models, model) continue + if not unresolved_model - unsorted_data_model_names: sorted_data_models[model.path] = model add_model_path_to_list(require_update_action_models, model) continue + # unresolved unresolved_classes = ", ".join( f"[class: {item.path} references: {item.reference_classes}]" for item in unresolved_references ) msg = f"A Parser can not resolve classes: {unresolved_classes}." raise Exception(msg) # noqa: TRY002 + return unresolved_references, sorted_data_models, require_update_action_models @@ -1576,11 +1594,14 @@ def __apply_discriminator_type( # noqa: PLR0912, PLR0914, PLR0915 discriminator_values: list[DiscriminatorValue] = [] def check_paths( - model: pydantic_model_v2.BaseModel | Reference, + model: pydantic_model_v2.BaseModel | Reference | None, mapping: dict[str, str], discriminator_values: list[DiscriminatorValue] = discriminator_values, ) -> None: """Validate discriminator mapping paths for a model.""" + if model is None: + return + for name, path in mapping.items(): if (model.path.split("#/")[-1] != path.split("#/")[-1]) and ( path.startswith("#/") or model.path[:-1] != path.split("/")[-1] @@ -1624,6 +1645,9 @@ def get_discriminator_field_value( if len(discriminator_values) == 0: for base_class in discriminator_model.base_classes: + if not base_class.reference: + continue + check_paths(base_class.reference, mapping) # ty: ignore if not discriminator_values: diff --git a/tests/data/expected/main/openapi/openapi_discriminated_oneof_allof_cycle.py b/tests/data/expected/main/openapi/openapi_discriminated_oneof_allof_cycle.py new file mode 100644 index 000000000..f4a2870df --- /dev/null +++ b/tests/data/expected/main/openapi/openapi_discriminated_oneof_allof_cycle.py @@ -0,0 +1,53 @@ +# generated by datamodel-codegen: +# filename: openapi_discriminated_oneof_allof_cycle.json +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from typing import Literal, Union + +from pydantic import BaseModel, Field, RootModel + + +class ASchema(BaseModel): + pass + + +class BSchema(BaseModel): + pass + + +class A1(BaseModel): + kind: Literal['a'] + + +class B1(BaseModel): + kind: Literal['b'] + + +class A(RootModel[Union["A2", "A3"]]): + root: Union["A2", "A3"] + + +class A2(A1): + pass + + +class A3(A1): + pass + + +class B(RootModel["B2"]): + root: "B2" + + +class B2(B1): + pass + + +class X(RootModel[A | B]): + root: A | B = Field(..., discriminator='kind') + + +A.model_rebuild() +B.model_rebuild() diff --git a/tests/data/openapi/openapi_discriminated_oneof_allof_cycle.json b/tests/data/openapi/openapi_discriminated_oneof_allof_cycle.json new file mode 100644 index 000000000..aa84a4f15 --- /dev/null +++ b/tests/data/openapi/openapi_discriminated_oneof_allof_cycle.json @@ -0,0 +1,74 @@ +{ + "openapi": "3.1.0", + "info": { + "title": "", + "version": "" + }, + "paths": {}, + "components": { + "schemas": { + "ASchema": { + "type": "object", + "description": "Schema referenced by discriminator mapping." + }, + "BSchema": { + "type": "object", + "description": "Schema referenced by discriminator mapping." + }, + "X": { + "type": "object", + "discriminator": { + "propertyName": "kind", + "mapping": { + "a": "#/components/schemas/ASchema", + "b": "#/components/schemas/BSchema" + } + }, + "oneOf": [ + { + "$ref": "#/components/schemas/A" + }, + { + "$ref": "#/components/schemas/B" + } + ] + }, + "A": { + "allOf": [ + { + "$ref": "#/components/schemas/X" + }, + { + "type": "object", + "required": [ + "kind" + ], + "properties": { + "kind": { + "const": "a" + } + } + } + ] + }, + "B": { + "allOf": [ + { + "$ref": "#/components/schemas/X" + }, + { + "type": "object", + "required": [ + "kind" + ], + "properties": { + "kind": { + "const": "b" + } + } + } + ] + } + } + } +} diff --git a/tests/main/openapi/test_main_openapi.py b/tests/main/openapi/test_main_openapi.py index 65121180d..dd4a5a757 100644 --- a/tests/main/openapi/test_main_openapi.py +++ b/tests/main/openapi/test_main_openapi.py @@ -5316,3 +5316,18 @@ def test_main_reuse_model_with_type_alias(output_file: Path) -> None: "--use-type-alias", ], ) + + +@pytest.mark.timeout(30) +def test_main_openapi_discriminated_oneof_allof_cycle(output_file: Path) -> None: + """Discriminated oneOf with variants that allOf the parent (circular graph). + + Covers sort_data_models ordering for cyclic base dependencies and discriminator + handling (mapping + RootModel) on a minimal OpenAPI spec. + """ + run_main_and_assert( + input_path=OPEN_API_DATA_PATH / "openapi_discriminated_oneof_allof_cycle.json", + output_path=output_file, + input_file_type="openapi", + assert_func=assert_file_content, + ) diff --git a/uv.lock b/uv.lock index 55a81c343..a12766c67 100644 --- a/uv.lock +++ b/uv.lock @@ -857,6 +857,7 @@ dev = [ { name = "pytest" }, { name = "pytest-cov" }, { name = "pytest-mock" }, + { name = "pytest-timeout" }, { name = "pytest-xdist" }, { name = "time-machine" }, { name = "twine" }, @@ -901,6 +902,7 @@ test = [ { name = "pytest" }, { name = "pytest-cov" }, { name = "pytest-mock" }, + { name = "pytest-timeout" }, { name = "pytest-xdist" }, { name = "time-machine" }, { name = "watchfiles" }, @@ -914,6 +916,7 @@ type = [ { name = "pytest" }, { name = "pytest-cov" }, { name = "pytest-mock" }, + { name = "pytest-timeout" }, { name = "pytest-xdist" }, { name = "time-machine" }, { name = "ty" }, @@ -975,11 +978,10 @@ dev = [ { name = "inline-snapshot", specifier = ">=0.31.1" }, { name = "msgspec", specifier = ">=0.18" }, { name = "prek", specifier = ">=0.2.22" }, - { name = "pytest", specifier = ">=6.1" }, { name = "pytest", specifier = ">=8.3.4" }, - { name = "pytest-cov", specifier = ">=2.12.1" }, { name = "pytest-cov", specifier = ">=5" }, { name = "pytest-mock", specifier = ">=3.14" }, + { name = "pytest-timeout", specifier = ">=2.4" }, { name = "pytest-xdist", specifier = ">=3.3.1" }, { name = "time-machine", specifier = ">=3.1" }, { name = "twine", specifier = ">=6.1" }, @@ -1009,11 +1011,10 @@ test = [ { name = "diff-cover", specifier = ">=9.7.2" }, { name = "inline-snapshot", specifier = ">=0.31.1" }, { name = "msgspec", specifier = ">=0.18" }, - { name = "pytest", specifier = ">=6.1" }, { name = "pytest", specifier = ">=8.3.4" }, - { name = "pytest-cov", specifier = ">=2.12.1" }, { name = "pytest-cov", specifier = ">=5" }, { name = "pytest-mock", specifier = ">=3.14" }, + { name = "pytest-timeout", specifier = ">=2.4" }, { name = "pytest-xdist", specifier = ">=3.3.1" }, { name = "time-machine", specifier = ">=3.1" }, { name = "watchfiles", specifier = ">=1.1" }, @@ -1024,11 +1025,10 @@ type = [ { name = "diff-cover", specifier = ">=9.7.2" }, { name = "inline-snapshot", specifier = ">=0.31.1" }, { name = "msgspec", specifier = ">=0.18" }, - { name = "pytest", specifier = ">=6.1" }, { name = "pytest", specifier = ">=8.3.4" }, - { name = "pytest-cov", specifier = ">=2.12.1" }, { name = "pytest-cov", specifier = ">=5" }, { name = "pytest-mock", specifier = ">=3.14" }, + { name = "pytest-timeout", specifier = ">=2.4" }, { name = "pytest-xdist", specifier = ">=3.3.1" }, { name = "time-machine", specifier = ">=3.1" }, { name = "ty", specifier = ">=0.0.8" }, @@ -2417,6 +2417,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, ] +[[package]] +name = "pytest-timeout" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" }, +] + [[package]] name = "pytest-xdist" version = "3.8.0"