diff --git a/mmcv/ops/multi_scale_deform_attn.py b/mmcv/ops/multi_scale_deform_attn.py index 92759a1d8f..793e350c25 100644 --- a/mmcv/ops/multi_scale_deform_attn.py +++ b/mmcv/ops/multi_scale_deform_attn.py @@ -332,7 +332,10 @@ def forward(self, bs, num_query, _ = query.shape bs, num_value, _ = value.shape - assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value + if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != num_value: + raise ValueError('The sequence length of `value` must be equal to' + 'size of the flattened features maps summed over' + 'all levels.') value = self.value_proj(value) if key_padding_mask is not None: