Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 269 additions & 0 deletions catalog/internal/catalog/modelcatalog/service/catalog_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"strings"
"testing"
"time"

"github.com/kubeflow/model-registry/catalog/internal/catalog/modelcatalog/models"
"github.com/kubeflow/model-registry/internal/apiutils"
Expand Down Expand Up @@ -1023,6 +1024,274 @@ func TestCatalogModelRepository(t *testing.T) {
}
})

t.Run("TestNameOrderingCaseInsensitive", func(t *testing.T) {
// Create test models with mixed case names to verify case-insensitive ordering
testModels := []string{
"ALPHA-uppercase",
"alpha-lowercase",
"Alpha-Capitalized",
"BETA-uppercase",
"beta-lowercase",
"Beta-Capitalized",
}

for _, name := range testModels {
catalogModel := &models.CatalogModelImpl{
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of(name),
ExternalID: apiutils.Of(name + "-case-test"),
},
}

_, err := repo.Save(catalogModel)
require.NoError(t, err)
}

// Test NAME ordering ASC (case-insensitive)
listOptions := models.CatalogModelListOptions{
Pagination: dbmodels.Pagination{
OrderBy: apiutils.Of("NAME"),
SortOrder: apiutils.Of("ASC"),
},
}
result, err := repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)

// Extract our test model names from results
var foundNames []string
for _, model := range result.Items {
name := *model.GetAttributes().Name
if strings.HasSuffix(name, "-uppercase") || strings.HasSuffix(name, "-lowercase") || strings.HasSuffix(name, "-Capitalized") {
foundNames = append(foundNames, name)
}
}

require.GreaterOrEqual(t, len(foundNames), 6, "Should find all case test models")

// Find positions of alpha and beta variants
var alphaPositions, betaPositions []int
for i, name := range foundNames {
lowerName := strings.ToLower(name)
if strings.HasPrefix(lowerName, "alpha") {
alphaPositions = append(alphaPositions, i)
} else if strings.HasPrefix(lowerName, "beta") {
betaPositions = append(betaPositions, i)
}
}

// In case-insensitive ordering, all alpha variants should be grouped together
// (consecutive positions) and come before all beta variants
require.Len(t, alphaPositions, 3, "Should find 3 alpha variants")
require.Len(t, betaPositions, 3, "Should find 3 beta variants")

// Verify alpha variants are consecutive (grouped together)
alphaSpread := alphaPositions[len(alphaPositions)-1] - alphaPositions[0]
assert.Equal(t, 2, alphaSpread, "Alpha variants should be grouped together (spread=2)")

// Verify beta variants are consecutive (grouped together)
betaSpread := betaPositions[len(betaPositions)-1] - betaPositions[0]
assert.Equal(t, 2, betaSpread, "Beta variants should be grouped together (spread=2)")

// Verify all alphas come before all betas
maxAlphaPos := alphaPositions[len(alphaPositions)-1]
minBetaPos := betaPositions[0]
assert.Less(t, maxAlphaPos, minBetaPos, "All alpha variants should come before beta variants")
})

t.Run("TestNameOrderingRedHatAIModels", func(t *testing.T) {
// Create test models with realistic RedHatAI model names
testModels := []string{
"RedHatAI/whisper-large-v3-turbo-quantized.w4a16",
"RedHatAI/Voxtral-Mini-3B-2507-FP8-dynamic",
"RedHatAI/Qwen3-Coder-480B-A35B-Instruct-FP8",
"RedHatAI/Qwen3-8B-FP8-dynamic",
"RedHatAI/Qwen2.5-7B-Instruct-quantized.w8a8",
"RedHatAI/Qwen2.5-7B-Instruct-quantized.w4a16",
"RedHatAI/Qwen2.5-7B-Instruct-FP8-dynamic",
"RedHatAI/Qwen2.5-7B-Instruct",
}

for _, name := range testModels {
catalogModel := &models.CatalogModelImpl{
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of(name),
ExternalID: apiutils.Of(strings.ReplaceAll(name, "/", "-") + "-ext"),
},
}

_, err := repo.Save(catalogModel)
require.NoError(t, err)
}

// Test NAME ordering ASC
listOptions := models.CatalogModelListOptions{
Pagination: dbmodels.Pagination{
OrderBy: apiutils.Of("NAME"),
SortOrder: apiutils.Of("ASC"),
},
}
result, err := repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)

// Extract our RedHatAI test model names from results
var foundNames []string
for _, model := range result.Items {
name := *model.GetAttributes().Name
if strings.HasPrefix(name, "RedHatAI/") {
foundNames = append(foundNames, name)
}
}

require.GreaterOrEqual(t, len(foundNames), 8, "Should find all RedHatAI test models")

// Verify Qwen2.5 models come before Qwen3 models (case-insensitive alphabetical)
var qwen25Positions, qwen3Positions []int
for i, name := range foundNames {
if strings.Contains(name, "Qwen2.5") {
qwen25Positions = append(qwen25Positions, i)
} else if strings.Contains(name, "Qwen3") {
qwen3Positions = append(qwen3Positions, i)
}
}

// Qwen2.5 should come before Qwen3
if len(qwen25Positions) > 0 && len(qwen3Positions) > 0 {
maxQwen25Pos := qwen25Positions[len(qwen25Positions)-1]
minQwen3Pos := qwen3Positions[0]
assert.Less(t, maxQwen25Pos, minQwen3Pos, "Qwen2.5 models should come before Qwen3 models")
}

// Verify Voxtral comes after Qwen (V > Q in alphabet)
voxtralPos := -1
for i, name := range foundNames {
if strings.Contains(name, "Voxtral") {
voxtralPos = i
break
}
}

if voxtralPos != -1 && len(qwen3Positions) > 0 {
maxQwen3Pos := qwen3Positions[len(qwen3Positions)-1]
assert.Greater(t, voxtralPos, maxQwen3Pos, "Voxtral should come after Qwen models")
}

// Verify whisper comes after Voxtral (w > v in alphabet)
whisperPos := -1
for i, name := range foundNames {
if strings.Contains(name, "whisper") {
whisperPos = i
break
}
}

if whisperPos != -1 && voxtralPos != -1 {
assert.Greater(t, whisperPos, voxtralPos, "whisper should come after Voxtral (case-insensitive)")
}

// Log the actual order for visibility
t.Logf("RedHatAI models in ASC order:")
for i, name := range foundNames {
t.Logf(" %d: %s", i, name)
}
})

t.Run("TestNameOrderingCaseInsensitivePagination", func(t *testing.T) {
// Create test models with unique prefix that would be split across pages if case-sensitive
// but should be grouped together if case-insensitive
uniquePrefix := fmt.Sprintf("cipage-%d-", time.Now().UnixNano())
testModels := []string{
uniquePrefix + "AAA",
uniquePrefix + "aaa",
uniquePrefix + "Aaa",
uniquePrefix + "BBB",
uniquePrefix + "bbb",
uniquePrefix + "Bbb",
}

for _, name := range testModels {
catalogModel := &models.CatalogModelImpl{
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of(name),
ExternalID: apiutils.Of(name + "-ext"),
},
}

_, err := repo.Save(catalogModel)
require.NoError(t, err)
}

// Test pagination with small page size
listOptions := models.CatalogModelListOptions{
Pagination: dbmodels.Pagination{
OrderBy: apiutils.Of("NAME"),
SortOrder: apiutils.Of("ASC"),
PageSize: apiutils.Of(int32(2)),
},
}

// Collect all pagination test models across pages
var allPaginatedModels []string
currentToken := (*string)(nil)

for pageCount := 0; pageCount < 20; pageCount++ {
if currentToken != nil {
listOptions.Pagination.NextPageToken = currentToken
}

page, err := repo.List(listOptions)
require.NoError(t, err)

for _, model := range page.Items {
name := *model.GetAttributes().Name
if strings.HasPrefix(name, uniquePrefix) {
allPaginatedModels = append(allPaginatedModels, name)
}
}

if page.NextPageToken == "" {
break
}
currentToken = &page.NextPageToken
}

require.GreaterOrEqual(t, len(allPaginatedModels), 6, "Should find all pagination case test models")

// Verify that case variants are grouped together
// Find positions of AAA/aaa/Aaa variants
var aaaPositions, bbbPositions []int
for i, name := range allPaginatedModels {
lowerName := strings.ToLower(name)
if strings.HasSuffix(lowerName, "aaa") {
aaaPositions = append(aaaPositions, i)
} else if strings.HasSuffix(lowerName, "bbb") {
bbbPositions = append(bbbPositions, i)
}
}

// All AAA variants should be consecutive (case-insensitive grouping)
if len(aaaPositions) >= 3 {
aaaSpread := aaaPositions[len(aaaPositions)-1] - aaaPositions[0]
assert.Equal(t, 2, aaaSpread, "AAA variants should be grouped together across pagination")
}

// All BBB variants should be consecutive
if len(bbbPositions) >= 3 {
bbbSpread := bbbPositions[len(bbbPositions)-1] - bbbPositions[0]
assert.Equal(t, 2, bbbSpread, "BBB variants should be grouped together across pagination")
}

// All AAA should come before all BBB
if len(aaaPositions) > 0 && len(bbbPositions) > 0 {
maxAaaPos := aaaPositions[len(aaaPositions)-1]
minBbbPos := bbbPositions[0]
assert.Less(t, maxAaaPos, minBbbPos, "AAA variants should come before BBB variants")
}

t.Logf("Pagination case models in order: %v", allPaginatedModels)
})

t.Run("TestDeleteBySource", func(t *testing.T) {
// Setup: Create models with different source IDs
sourceID1 := "test_source_1"
Expand Down
11 changes: 6 additions & 5 deletions catalog/internal/db/pagination/pagination.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func CreateNamePaginationToken(entityID int32, name *string) string {
// ApplyNameOrdering applies NAME-based ordering with cursor pagination to a query.
// This handles the catalog-specific NAME ordering which requires string comparison
// in WHERE clauses (not integer casting like standard pagination).
// The ordering is case-insensitive using LOWER() for consistent alphabetical sorting.
//
// Parameters:
// - query: The GORM query to modify
Expand All @@ -46,8 +47,8 @@ func ApplyNameOrdering(query *gorm.DB, tableName string, sortOrder string, nextP
order = "DESC"
}

// Apply name-based ordering with ID as tie-breaker
query = query.Order(fmt.Sprintf("%s.name %s, %s.id ASC", tableName, order, tableName))
// Apply case-insensitive name-based ordering with ID as tie-breaker
query = query.Order(fmt.Sprintf("LOWER(%s.name) %s, %s.id ASC", tableName, order, tableName))

// Handle cursor-based pagination for NAME
if nextPageToken != "" {
Expand All @@ -56,13 +57,13 @@ func ApplyNameOrdering(query *gorm.DB, tableName string, sortOrder string, nextP
_ = query.AddError(fmt.Errorf("invalid nextPageToken: %w", err))
return query
}
// Cursor pagination based on name (string comparison)
// Cursor pagination based on name (case-insensitive string comparison)
cmp := ">"
if order == "DESC" {
cmp = "<"
}
// Use proper string comparison with name and ID as tie-breaker
query = query.Where(fmt.Sprintf("(%s.name %s ? OR (%s.name = ? AND %s.id > ?))", tableName, cmp, tableName, tableName),
// Use LOWER() for case-insensitive comparison with name and ID as tie-breaker
query = query.Where(fmt.Sprintf("(LOWER(%s.name) %s LOWER(?) OR (LOWER(%s.name) = LOWER(?) AND %s.id > ?))", tableName, cmp, tableName, tableName),
cursor.Value, cursor.Value, cursor.ID)
}

Expand Down
Loading