Skip to content
21 changes: 14 additions & 7 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6155,9 +6155,13 @@ def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, I
result = self
indices = op.Constant(value_int=0)
else:
dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
result = op.ReduceMax(self, dims, keepdims=keepdim)
indices = op.ArgMax(self, axis=dim, keepdims=keepdim)
values, indices = op.TopK(self, K=[1], axis=dim, largest=1, sorted=0)
if keepdim:
result = values
else:
squeeze_axe = op.Constant(value_ints=[dim])
result = op.Squeeze(values, axes=squeeze_axe)
indices = op.Squeeze(indices, axes=squeeze_axe)
return result, indices


Expand Down Expand Up @@ -6242,10 +6246,13 @@ def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, T
result = self
indices = op.Constant(value_int=0)
else:
dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
result = op.ReduceMin(self, dims, keepdims=keepdim)
indices = op.ArgMin(self, axis=dim, keepdims=keepdim)

values, indices = op.TopK(self, K=[1], axis=dim, largest=0, sorted=0)
if keepdim:
result = values
else:
squeeze_axe = op.Constant(value_ints=[dim])
result = op.Squeeze(values, axes=squeeze_axe)
indices = op.Squeeze(indices, axes=squeeze_axe)
return result, indices


Expand Down
52 changes: 52 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,58 @@ def forward(self, x):
)
_testing.assert_onnx_program(onnx_program)

def test_max_dim_negative_dim_squeeze_stability(self):
"""Ensure max.dim(dim=-1, keepdim=False) exports and runs correctly.

TopK + Squeeze(axes=[dim]) receives dim=-1. Validates that ORT handles
negative axis in Squeeze and output shape matches PyTorch.
"""

class Model(torch.nn.Module):
def forward(self, x):
return torch.max(x, dim=-1, keepdim=False)

x = torch.randn(2, 3, 4)
onnx_program = torch.onnx.export(
Model(), (x,), dynamo=True, verbose=False
)
_testing.assert_onnx_program(onnx_program)

def test_min_dim_negative_dim_squeeze_stability(self):
"""Ensure min.dim(dim=-1, keepdim=False) exports and runs correctly.

Same as max_dim_negative_dim: TopK + Squeeze(axes=[dim]) with dim=-1.
"""

class Model(torch.nn.Module):
def forward(self, x):
return torch.min(x, dim=-1, keepdim=False)

x = torch.randn(2, 3, 4)
onnx_program = torch.onnx.export(
Model(), (x,), dynamo=True, verbose=False
)
_testing.assert_onnx_program(onnx_program)

def test_max_dim_chained_reduction(self):
"""Ensure x.max(dim=1).values.max(dim=0) exports and runs correctly.

Validates that TopK -> Squeeze -> next TopK -> Squeeze shape flow
is correct when chaining max.dim calls.
"""

class Model(torch.nn.Module):
def forward(self, x):
v1, _ = x.max(dim=1, keepdim=False)
v2, _ = v1.max(dim=0, keepdim=False)
return v2

x = torch.randn(2, 3, 4)
onnx_program = torch.onnx.export(
Model(), (x,), dynamo=True, verbose=False
)
_testing.assert_onnx_program(onnx_program)


if __name__ == "__main__":
unittest.main()
Loading