diff --git a/src/datamodel_code_generator/parser/base.py b/src/datamodel_code_generator/parser/base.py index ecbad8e72..f6a0132db 100644 --- a/src/datamodel_code_generator/parser/base.py +++ b/src/datamodel_code_generator/parser/base.py @@ -1521,17 +1521,33 @@ def _create_discriminator_data_type( ) -> DataType: """Create a data type for discriminator field, using enum literals if available.""" if enum_source: - enum_class_name = enum_source.reference.short_name - enum_member_literals: list[tuple[str, str]] = [] - for value in discriminator_values: - member = enum_source.find_member(value) - if member and member.field.name: - enum_member_literals.append((enum_class_name, member.field.name)) - else: # pragma: no cover - enum_member_literals.append((enum_class_name, str(value))) - data_type = self.data_type(enum_member_literals=enum_member_literals) - if enum_source.module_path != discriminator_model.module_path: # pragma: no cover - imports.append(Import.from_full_path(enum_source.name)) + if self.use_enum_values_in_discriminator: + enum_class_name = enum_source.reference.short_name + enum_member_literals: list[tuple[str, DiscriminatorValue]] = [] + for value in discriminator_values: + member = enum_source.find_member(value) + if member and member.field.name: + enum_member_literals.append((enum_class_name, member.field.name)) + else: # pragma: no cover + enum_member_literals.append((enum_class_name, str(value))) + data_type = self.data_type(enum_member_literals=enum_member_literals) + if enum_source.module_path != discriminator_model.module_path: # pragma: no cover + imports.append(Import.from_full_path(enum_source.name)) + else: + # According to OpenAPI specification, mapping discriminators are always string values. + # However, if the mapped object is an enum, we want to use the real enum value instead of + # the string value. + # See: https://swagger.io/specification/#options-for-mapping-values-to-schemas + # Fix: https://github.com/koxudaxi/datamodel-code-generator/issues/3073 + for i, value in enumerate(discriminator_values): + if member := enum_source.find_member(value): + match member.field.default: + case str(): + discriminator_values[i] = member.field.default.strip("'\"") + case _ if isinstance(member.field.default, DiscriminatorValue): + discriminator_values[i] = member.field.default + + data_type = self.data_type(literals=discriminator_values) else: data_type = self.data_type(literals=discriminator_values) return data_type @@ -1637,28 +1653,27 @@ def get_discriminator_field_value( raise RuntimeError(msg) enum_from_base: Enum | None = None - if self.use_enum_values_in_discriminator: - for base_class in discriminator_model.base_classes: - if not base_class.reference or not base_class.reference.source: # pragma: no cover - continue - base_model = base_class.reference.source - if not isinstance( # pragma: no cover - base_model, - ( - pydantic_model_v2.BaseModel, - dataclass_model.DataClass, - msgspec_model.Struct, - ), - ): + for base_class in discriminator_model.base_classes: + if not base_class.reference or not base_class.reference.source: # pragma: no cover + continue + base_model = base_class.reference.source + if not isinstance( # pragma: no cover + base_model, + ( + pydantic_model_v2.BaseModel, + dataclass_model.DataClass, + msgspec_model.Struct, + ), + ): + continue + for base_field in base_model.fields: # pragma: no branch + if field_name not in {base_field.original_name, base_field.name}: # pragma: no cover continue - for base_field in base_model.fields: # pragma: no branch - if field_name not in {base_field.original_name, base_field.name}: # pragma: no cover - continue - enum_from_base = base_field.data_type.find_source(Enum) - if enum_from_base: # pragma: no branch - break + enum_from_base = base_field.data_type.find_source(Enum) if enum_from_base: # pragma: no branch break + if enum_from_base: # pragma: no branch + break has_one_literal = False for discriminator_field in discriminator_model.fields: @@ -1690,11 +1705,7 @@ def get_discriminator_field_value( discriminator_field.extras["is_classvar"] = True break - enum_source: Enum | None = None - if self.use_enum_values_in_discriminator: - enum_source = ( # pragma: no cover - discriminator_field.data_type.find_source(Enum) or enum_from_base - ) + enum_source = discriminator_field.data_type.find_source(Enum) or enum_from_base for field_data_type in discriminator_field.data_type.all_data_types: if field_data_type.reference: # pragma: no cover diff --git a/tests/data/expected/main/openapi/discriminator/integer_mapping.py b/tests/data/expected/main/openapi/discriminator/integer_mapping.py index 753634d2f..31e529a9f 100644 --- a/tests/data/expected/main/openapi/discriminator/integer_mapping.py +++ b/tests/data/expected/main/openapi/discriminator/integer_mapping.py @@ -20,10 +20,11 @@ class Foo(BaseModel): class Kind1(IntEnum): integer_2 = 2 + integer_3 = 3 class Bar(BaseModel): - kind: Literal[2] + kind: Literal[2, 3] class Base(RootModel[Foo | Bar]): diff --git a/tests/data/openapi/discriminator_integer_mapping.yaml b/tests/data/openapi/discriminator_integer_mapping.yaml index da518bffa..cb6dbb77b 100644 --- a/tests/data/openapi/discriminator_integer_mapping.yaml +++ b/tests/data/openapi/discriminator_integer_mapping.yaml @@ -14,6 +14,7 @@ components: mapping: '1': '#/components/schemas/Foo' '2': '#/components/schemas/Bar' + '3': '#/components/schemas/Bar' Foo: type: object properties: @@ -30,5 +31,6 @@ components: type: integer enum: - 2 + - 3 required: - kind