Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 46 additions & 35 deletions src/datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,17 +1521,33 @@ def _create_discriminator_data_type(
) -> DataType:
"""Create a data type for discriminator field, using enum literals if available."""
if enum_source:
enum_class_name = enum_source.reference.short_name
enum_member_literals: list[tuple[str, str]] = []
for value in discriminator_values:
member = enum_source.find_member(value)
if member and member.field.name:
enum_member_literals.append((enum_class_name, member.field.name))
else: # pragma: no cover
enum_member_literals.append((enum_class_name, str(value)))
data_type = self.data_type(enum_member_literals=enum_member_literals)
if enum_source.module_path != discriminator_model.module_path: # pragma: no cover
imports.append(Import.from_full_path(enum_source.name))
if self.use_enum_values_in_discriminator:
enum_class_name = enum_source.reference.short_name
enum_member_literals: list[tuple[str, DiscriminatorValue]] = []
for value in discriminator_values:
member = enum_source.find_member(value)
if member and member.field.name:
enum_member_literals.append((enum_class_name, member.field.name))
else: # pragma: no cover
enum_member_literals.append((enum_class_name, str(value)))
data_type = self.data_type(enum_member_literals=enum_member_literals)
if enum_source.module_path != discriminator_model.module_path: # pragma: no cover
imports.append(Import.from_full_path(enum_source.name))
else:
# According to OpenAPI specification, mapping discriminators are always string values.
# However, if the mapped object is an enum, we want to use the real enum value instead of
# the string value.
# See: https://swagger.io/specification/#options-for-mapping-values-to-schemas
# Fix: https://github.com/koxudaxi/datamodel-code-generator/issues/3073
for i, value in enumerate(discriminator_values):
if member := enum_source.find_member(value):
match member.field.default:
case str():
discriminator_values[i] = member.field.default.strip("'\"")
case _ if isinstance(member.field.default, DiscriminatorValue):
discriminator_values[i] = member.field.default

data_type = self.data_type(literals=discriminator_values)
else:
data_type = self.data_type(literals=discriminator_values)
return data_type
Expand Down Expand Up @@ -1637,28 +1653,27 @@ def get_discriminator_field_value(
raise RuntimeError(msg)

enum_from_base: Enum | None = None
if self.use_enum_values_in_discriminator:
for base_class in discriminator_model.base_classes:
if not base_class.reference or not base_class.reference.source: # pragma: no cover
continue
base_model = base_class.reference.source
if not isinstance( # pragma: no cover
base_model,
(
pydantic_model_v2.BaseModel,
dataclass_model.DataClass,
msgspec_model.Struct,
),
):
for base_class in discriminator_model.base_classes:
if not base_class.reference or not base_class.reference.source: # pragma: no cover
continue
base_model = base_class.reference.source
if not isinstance( # pragma: no cover
base_model,
(
pydantic_model_v2.BaseModel,
dataclass_model.DataClass,
msgspec_model.Struct,
),
):
continue
for base_field in base_model.fields: # pragma: no branch
if field_name not in {base_field.original_name, base_field.name}: # pragma: no cover
continue
for base_field in base_model.fields: # pragma: no branch
if field_name not in {base_field.original_name, base_field.name}: # pragma: no cover
continue
enum_from_base = base_field.data_type.find_source(Enum)
if enum_from_base: # pragma: no branch
break
enum_from_base = base_field.data_type.find_source(Enum)
if enum_from_base: # pragma: no branch
break
if enum_from_base: # pragma: no branch
break

has_one_literal = False
for discriminator_field in discriminator_model.fields:
Expand Down Expand Up @@ -1690,11 +1705,7 @@ def get_discriminator_field_value(
discriminator_field.extras["is_classvar"] = True
break

enum_source: Enum | None = None
if self.use_enum_values_in_discriminator:
enum_source = ( # pragma: no cover
discriminator_field.data_type.find_source(Enum) or enum_from_base
)
enum_source = discriminator_field.data_type.find_source(Enum) or enum_from_base

for field_data_type in discriminator_field.data_type.all_data_types:
if field_data_type.reference: # pragma: no cover
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ class Foo(BaseModel):

class Kind1(IntEnum):
integer_2 = 2
integer_3 = 3


class Bar(BaseModel):
kind: Literal[2]
kind: Literal[2, 3]


class Base(RootModel[Foo | Bar]):
Expand Down
2 changes: 2 additions & 0 deletions tests/data/openapi/discriminator_integer_mapping.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ components:
mapping:
'1': '#/components/schemas/Foo'
'2': '#/components/schemas/Bar'
'3': '#/components/schemas/Bar'
Foo:
type: object
properties:
Expand All @@ -30,5 +31,6 @@ components:
type: integer
enum:
- 2
- 3
required:
- kind
Loading