diff --git a/backends/arm/test/ops/test_split.py b/backends/arm/test/ops/test_split.py index 6af3362de7a..f655785410c 100644 --- a/backends/arm/test/ops/test_split.py +++ b/backends/arm/test/ops/test_split.py @@ -310,3 +310,38 @@ def test_split_tensor_vgf_quant(test_data: Tuple): quantize=True, ) pipeline.run() + + +a16w8_split_test_parameters = { + "a16w8_1d_split_2": lambda: (torch.rand(10), 2, 0), + "a16w8_2d_split_4": lambda: (torch.rand(8, 4), 4, 0), + "a16w8_3d_split_4": lambda: (torch.rand(4, 4, 8), 4, 2), +} + + +@common.parametrize("test_data", a16w8_split_test_parameters) +@common.XfailIfNoCorstone300 +def test_split_a16w8_u55_INT(test_data: input_t1): + pipeline = EthosU55PipelineINT[input_t1]( + Split(), + test_data(), + aten_ops=[], + exir_ops=exir_op, + a16w8_quantization=True, + symmetric_io_quantization=True, + ) + pipeline.run() + + +@common.parametrize("test_data", a16w8_split_test_parameters) +@common.XfailIfNoCorstone320 +def test_split_a16w8_u85_INT(test_data: input_t1): + pipeline = EthosU85PipelineINT[input_t1]( + Split(), + test_data(), + aten_ops=[], + exir_ops=exir_op, + a16w8_quantization=True, + symmetric_io_quantization=True, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_var.py b/backends/arm/test/ops/test_var.py index 35aa2857d3b..74ceb4b557d 100644 --- a/backends/arm/test/ops/test_var.py +++ b/backends/arm/test/ops/test_var.py @@ -32,6 +32,11 @@ class Var(torch.nn.Module): ), } + test_parameters_ethosu = { + "var_4d_keep_dim_0_correction": lambda: (torch.randn(1, 50, 10, 20), True, 0), + "var_4d_keep_dim_1_correction": lambda: (torch.randn(1, 30, 15, 20), True, 1), + } + def __init__(self, keepdim: bool = True, correction: int = 0): super().__init__() self.keepdim = keepdim diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index bd062188338..0a3faa6a074 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -43,6 +43,7 @@ def define_arm_tests(): "ops/test_conv1d.py", "ops/test_gelu.py", "ops/test_bmm.py", + "ops/test_split.py", ] # Quantization