From fc0a62fd67a732d75d41e27482a9088dd5832fac Mon Sep 17 00:00:00 2001 From: Alessio Pragliola Date: Wed, 17 Dec 2025 21:30:46 +0100 Subject: [PATCH] fix: make NAME ordering case insensitive Signed-off-by: Alessio Pragliola --- .../service/catalog_model_test.go | 269 ++++++++++++++++++ catalog/internal/db/pagination/pagination.go | 11 +- 2 files changed, 275 insertions(+), 5 deletions(-) diff --git a/catalog/internal/catalog/modelcatalog/service/catalog_model_test.go b/catalog/internal/catalog/modelcatalog/service/catalog_model_test.go index 4549fd6b5e..0fcefbce6a 100644 --- a/catalog/internal/catalog/modelcatalog/service/catalog_model_test.go +++ b/catalog/internal/catalog/modelcatalog/service/catalog_model_test.go @@ -5,6 +5,7 @@ import ( "fmt" "strings" "testing" + "time" "github.com/kubeflow/model-registry/catalog/internal/catalog/modelcatalog/models" "github.com/kubeflow/model-registry/internal/apiutils" @@ -1023,6 +1024,274 @@ func TestCatalogModelRepository(t *testing.T) { } }) + t.Run("TestNameOrderingCaseInsensitive", func(t *testing.T) { + // Create test models with mixed case names to verify case-insensitive ordering + testModels := []string{ + "ALPHA-uppercase", + "alpha-lowercase", + "Alpha-Capitalized", + "BETA-uppercase", + "beta-lowercase", + "Beta-Capitalized", + } + + for _, name := range testModels { + catalogModel := &models.CatalogModelImpl{ + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of(name), + ExternalID: apiutils.Of(name + "-case-test"), + }, + } + + _, err := repo.Save(catalogModel) + require.NoError(t, err) + } + + // Test NAME ordering ASC (case-insensitive) + listOptions := models.CatalogModelListOptions{ + Pagination: dbmodels.Pagination{ + OrderBy: apiutils.Of("NAME"), + SortOrder: apiutils.Of("ASC"), + }, + } + result, err := repo.List(listOptions) + require.NoError(t, err) + require.NotNil(t, result) + + // Extract our test model names from results + var foundNames []string + for _, model := range result.Items { + name := *model.GetAttributes().Name + if strings.HasSuffix(name, "-uppercase") || strings.HasSuffix(name, "-lowercase") || strings.HasSuffix(name, "-Capitalized") { + foundNames = append(foundNames, name) + } + } + + require.GreaterOrEqual(t, len(foundNames), 6, "Should find all case test models") + + // Find positions of alpha and beta variants + var alphaPositions, betaPositions []int + for i, name := range foundNames { + lowerName := strings.ToLower(name) + if strings.HasPrefix(lowerName, "alpha") { + alphaPositions = append(alphaPositions, i) + } else if strings.HasPrefix(lowerName, "beta") { + betaPositions = append(betaPositions, i) + } + } + + // In case-insensitive ordering, all alpha variants should be grouped together + // (consecutive positions) and come before all beta variants + require.Len(t, alphaPositions, 3, "Should find 3 alpha variants") + require.Len(t, betaPositions, 3, "Should find 3 beta variants") + + // Verify alpha variants are consecutive (grouped together) + alphaSpread := alphaPositions[len(alphaPositions)-1] - alphaPositions[0] + assert.Equal(t, 2, alphaSpread, "Alpha variants should be grouped together (spread=2)") + + // Verify beta variants are consecutive (grouped together) + betaSpread := betaPositions[len(betaPositions)-1] - betaPositions[0] + assert.Equal(t, 2, betaSpread, "Beta variants should be grouped together (spread=2)") + + // Verify all alphas come before all betas + maxAlphaPos := alphaPositions[len(alphaPositions)-1] + minBetaPos := betaPositions[0] + assert.Less(t, maxAlphaPos, minBetaPos, "All alpha variants should come before beta variants") + }) + + t.Run("TestNameOrderingRedHatAIModels", func(t *testing.T) { + // Create test models with realistic RedHatAI model names + testModels := []string{ + "RedHatAI/whisper-large-v3-turbo-quantized.w4a16", + "RedHatAI/Voxtral-Mini-3B-2507-FP8-dynamic", + "RedHatAI/Qwen3-Coder-480B-A35B-Instruct-FP8", + "RedHatAI/Qwen3-8B-FP8-dynamic", + "RedHatAI/Qwen2.5-7B-Instruct-quantized.w8a8", + "RedHatAI/Qwen2.5-7B-Instruct-quantized.w4a16", + "RedHatAI/Qwen2.5-7B-Instruct-FP8-dynamic", + "RedHatAI/Qwen2.5-7B-Instruct", + } + + for _, name := range testModels { + catalogModel := &models.CatalogModelImpl{ + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of(name), + ExternalID: apiutils.Of(strings.ReplaceAll(name, "/", "-") + "-ext"), + }, + } + + _, err := repo.Save(catalogModel) + require.NoError(t, err) + } + + // Test NAME ordering ASC + listOptions := models.CatalogModelListOptions{ + Pagination: dbmodels.Pagination{ + OrderBy: apiutils.Of("NAME"), + SortOrder: apiutils.Of("ASC"), + }, + } + result, err := repo.List(listOptions) + require.NoError(t, err) + require.NotNil(t, result) + + // Extract our RedHatAI test model names from results + var foundNames []string + for _, model := range result.Items { + name := *model.GetAttributes().Name + if strings.HasPrefix(name, "RedHatAI/") { + foundNames = append(foundNames, name) + } + } + + require.GreaterOrEqual(t, len(foundNames), 8, "Should find all RedHatAI test models") + + // Verify Qwen2.5 models come before Qwen3 models (case-insensitive alphabetical) + var qwen25Positions, qwen3Positions []int + for i, name := range foundNames { + if strings.Contains(name, "Qwen2.5") { + qwen25Positions = append(qwen25Positions, i) + } else if strings.Contains(name, "Qwen3") { + qwen3Positions = append(qwen3Positions, i) + } + } + + // Qwen2.5 should come before Qwen3 + if len(qwen25Positions) > 0 && len(qwen3Positions) > 0 { + maxQwen25Pos := qwen25Positions[len(qwen25Positions)-1] + minQwen3Pos := qwen3Positions[0] + assert.Less(t, maxQwen25Pos, minQwen3Pos, "Qwen2.5 models should come before Qwen3 models") + } + + // Verify Voxtral comes after Qwen (V > Q in alphabet) + voxtralPos := -1 + for i, name := range foundNames { + if strings.Contains(name, "Voxtral") { + voxtralPos = i + break + } + } + + if voxtralPos != -1 && len(qwen3Positions) > 0 { + maxQwen3Pos := qwen3Positions[len(qwen3Positions)-1] + assert.Greater(t, voxtralPos, maxQwen3Pos, "Voxtral should come after Qwen models") + } + + // Verify whisper comes after Voxtral (w > v in alphabet) + whisperPos := -1 + for i, name := range foundNames { + if strings.Contains(name, "whisper") { + whisperPos = i + break + } + } + + if whisperPos != -1 && voxtralPos != -1 { + assert.Greater(t, whisperPos, voxtralPos, "whisper should come after Voxtral (case-insensitive)") + } + + // Log the actual order for visibility + t.Logf("RedHatAI models in ASC order:") + for i, name := range foundNames { + t.Logf(" %d: %s", i, name) + } + }) + + t.Run("TestNameOrderingCaseInsensitivePagination", func(t *testing.T) { + // Create test models with unique prefix that would be split across pages if case-sensitive + // but should be grouped together if case-insensitive + uniquePrefix := fmt.Sprintf("cipage-%d-", time.Now().UnixNano()) + testModels := []string{ + uniquePrefix + "AAA", + uniquePrefix + "aaa", + uniquePrefix + "Aaa", + uniquePrefix + "BBB", + uniquePrefix + "bbb", + uniquePrefix + "Bbb", + } + + for _, name := range testModels { + catalogModel := &models.CatalogModelImpl{ + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of(name), + ExternalID: apiutils.Of(name + "-ext"), + }, + } + + _, err := repo.Save(catalogModel) + require.NoError(t, err) + } + + // Test pagination with small page size + listOptions := models.CatalogModelListOptions{ + Pagination: dbmodels.Pagination{ + OrderBy: apiutils.Of("NAME"), + SortOrder: apiutils.Of("ASC"), + PageSize: apiutils.Of(int32(2)), + }, + } + + // Collect all pagination test models across pages + var allPaginatedModels []string + currentToken := (*string)(nil) + + for pageCount := 0; pageCount < 20; pageCount++ { + if currentToken != nil { + listOptions.Pagination.NextPageToken = currentToken + } + + page, err := repo.List(listOptions) + require.NoError(t, err) + + for _, model := range page.Items { + name := *model.GetAttributes().Name + if strings.HasPrefix(name, uniquePrefix) { + allPaginatedModels = append(allPaginatedModels, name) + } + } + + if page.NextPageToken == "" { + break + } + currentToken = &page.NextPageToken + } + + require.GreaterOrEqual(t, len(allPaginatedModels), 6, "Should find all pagination case test models") + + // Verify that case variants are grouped together + // Find positions of AAA/aaa/Aaa variants + var aaaPositions, bbbPositions []int + for i, name := range allPaginatedModels { + lowerName := strings.ToLower(name) + if strings.HasSuffix(lowerName, "aaa") { + aaaPositions = append(aaaPositions, i) + } else if strings.HasSuffix(lowerName, "bbb") { + bbbPositions = append(bbbPositions, i) + } + } + + // All AAA variants should be consecutive (case-insensitive grouping) + if len(aaaPositions) >= 3 { + aaaSpread := aaaPositions[len(aaaPositions)-1] - aaaPositions[0] + assert.Equal(t, 2, aaaSpread, "AAA variants should be grouped together across pagination") + } + + // All BBB variants should be consecutive + if len(bbbPositions) >= 3 { + bbbSpread := bbbPositions[len(bbbPositions)-1] - bbbPositions[0] + assert.Equal(t, 2, bbbSpread, "BBB variants should be grouped together across pagination") + } + + // All AAA should come before all BBB + if len(aaaPositions) > 0 && len(bbbPositions) > 0 { + maxAaaPos := aaaPositions[len(aaaPositions)-1] + minBbbPos := bbbPositions[0] + assert.Less(t, maxAaaPos, minBbbPos, "AAA variants should come before BBB variants") + } + + t.Logf("Pagination case models in order: %v", allPaginatedModels) + }) + t.Run("TestDeleteBySource", func(t *testing.T) { // Setup: Create models with different source IDs sourceID1 := "test_source_1" diff --git a/catalog/internal/db/pagination/pagination.go b/catalog/internal/db/pagination/pagination.go index b044c04ec9..50e2e8d584 100644 --- a/catalog/internal/db/pagination/pagination.go +++ b/catalog/internal/db/pagination/pagination.go @@ -30,6 +30,7 @@ func CreateNamePaginationToken(entityID int32, name *string) string { // ApplyNameOrdering applies NAME-based ordering with cursor pagination to a query. // This handles the catalog-specific NAME ordering which requires string comparison // in WHERE clauses (not integer casting like standard pagination). +// The ordering is case-insensitive using LOWER() for consistent alphabetical sorting. // // Parameters: // - query: The GORM query to modify @@ -46,8 +47,8 @@ func ApplyNameOrdering(query *gorm.DB, tableName string, sortOrder string, nextP order = "DESC" } - // Apply name-based ordering with ID as tie-breaker - query = query.Order(fmt.Sprintf("%s.name %s, %s.id ASC", tableName, order, tableName)) + // Apply case-insensitive name-based ordering with ID as tie-breaker + query = query.Order(fmt.Sprintf("LOWER(%s.name) %s, %s.id ASC", tableName, order, tableName)) // Handle cursor-based pagination for NAME if nextPageToken != "" { @@ -56,13 +57,13 @@ func ApplyNameOrdering(query *gorm.DB, tableName string, sortOrder string, nextP _ = query.AddError(fmt.Errorf("invalid nextPageToken: %w", err)) return query } - // Cursor pagination based on name (string comparison) + // Cursor pagination based on name (case-insensitive string comparison) cmp := ">" if order == "DESC" { cmp = "<" } - // Use proper string comparison with name and ID as tie-breaker - query = query.Where(fmt.Sprintf("(%s.name %s ? OR (%s.name = ? AND %s.id > ?))", tableName, cmp, tableName, tableName), + // Use LOWER() for case-insensitive comparison with name and ID as tie-breaker + query = query.Where(fmt.Sprintf("(LOWER(%s.name) %s LOWER(?) OR (LOWER(%s.name) = LOWER(?) AND %s.id > ?))", tableName, cmp, tableName, tableName), cursor.Value, cursor.Value, cursor.ID) }