diff --git a/docker-compose.yml b/docker-compose.yml index 5e33051..cb2154b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -18,7 +18,7 @@ services: PORT: 8080 TOKEN_HOUR_LIFESPAN: 1 REFRESH_TOKEN_HOUR_LIFESPAN: 720 - READ_API_AUTHENTICATION_ENABLED: false + READ_API_AUTHENTICATION_ENABLED: "false" ports: - "8080:8080" depends_on: diff --git a/pkg/api/licenses.go b/pkg/api/licenses.go index f71822c..6c9d12a 100644 --- a/pkg/api/licenses.go +++ b/pkg/api/licenses.go @@ -314,7 +314,9 @@ func CreateLicense(c *gin.Context) { } // Send notification email about license creation if email.Email != nil { - email.NotifyLicenseCreated(*lic.User.UserEmail, *lic.User.UserName, *lic.Shortname) + if lic.User.UserEmail != nil && lic.User.UserName != nil { + email.NotifyLicenseCreated(*lic.User.UserEmail, *lic.User.UserName, *lic.Shortname) + } } else { logger.LogInfo("Email service is not enabled; skipping notification email sending") } @@ -488,7 +490,9 @@ func UpdateLicense(c *gin.Context) { } // Send notification email about license update if email.Email != nil { - email.NotifyLicenseUpdated(*newLicense.User.UserEmail, *newLicense.User.UserName, *newLicense.Shortname) + if newLicense.User.UserEmail != nil && newLicense.User.UserName != nil { + email.NotifyLicenseUpdated(*newLicense.User.UserEmail, *newLicense.User.UserName, *newLicense.Shortname) + } } else { logger.LogInfo("Email service is not enabled; skipping notification email sending") } @@ -852,14 +856,22 @@ func getSimilarLicenses(c *gin.Context) { return } var results []models.SimilarLicense - utils.SetSimilarityThreshold() + threshold := utils.GetSimilarityThreshold() query := ` SELECT rf_id, rf_shortname, rf_text, similarity(rf_text, ?) AS similarity FROM license_dbs WHERE rf_text % ? ORDER BY similarity DESC ` - if err := db.DB.Raw(query, req.Text, req.Text).Scan(&results).Error; err != nil { + // SET LOCAL scopes the threshold to this transaction only, preventing race conditions + // when multiple requests adjust the threshold concurrently (unlike the previous SET which + // was session-scoped and leaked across concurrent connections in the pool). + if err := db.DB.Transaction(func(tx *gorm.DB) error { + if err := tx.Exec(fmt.Sprintf("SET LOCAL pg_trgm.similarity_threshold = %.4f", threshold)).Error; err != nil { + return err + } + return tx.Raw(query, req.Text, req.Text).Scan(&results).Error + }); err != nil { c.JSON(http.StatusBadRequest, models.LicenseError{ Status: http.StatusBadRequest, Message: "Database query failed", diff --git a/pkg/api/obligations.go b/pkg/api/obligations.go index b7130d9..5c7ce06 100644 --- a/pkg/api/obligations.go +++ b/pkg/api/obligations.go @@ -1034,14 +1034,22 @@ func getSimilarObligations(c *gin.Context) { return } var results []models.SimilarObligation - utils.SetSimilarityThreshold() + threshold := utils.GetSimilarityThreshold() rawQuery := ` - SELECT id, topic,text, similarity(text, ?) AS similarity + SELECT id, topic, text, similarity(text, ?) AS similarity FROM obligations WHERE text % ? ORDER BY similarity DESC ` - if err := db.DB.Raw(rawQuery, req.Text, req.Text).Scan(&results).Error; err != nil { + // SET LOCAL scopes the threshold to this transaction only, preventing race conditions + // when multiple requests adjust the threshold concurrently (unlike the previous SET which + // was session-scoped and leaked across concurrent connections in the pool). + if err := db.DB.Transaction(func(tx *gorm.DB) error { + if err := tx.Exec(fmt.Sprintf("SET LOCAL pg_trgm.similarity_threshold = %.4f", threshold)).Error; err != nil { + return err + } + return tx.Raw(rawQuery, req.Text, req.Text).Scan(&results).Error + }); err != nil { er := models.LicenseError{ Status: http.StatusBadRequest, Message: "Database query failed", diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 1cf1d31..7697949 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -716,15 +716,15 @@ func DeleteUser(c *gin.Context) { // @Security ApiKeyAuth // @Router /users [get] func GetAllUser(c *gin.Context) { - active, err := strconv.ParseBool(c.Query("active")) - if err != nil { - active = false - } - var users []models.User query := db.DB.Model(&models.User{}) _ = utils.PreparePaginateResponse(c, query, &models.UserResponse{}) - if err := query.Where(&models.User{Active: &active}).Find(&users).Error; err != nil { + if activeStr := c.Query("active"); activeStr != "" { + if active, err := strconv.ParseBool(activeStr); err == nil { + query = query.Where(&models.User{Active: &active}) + } + } + if err := query.Find(&users).Error; err != nil { er := models.LicenseError{ Status: http.StatusNotFound, Message: "Users not found", diff --git a/pkg/models/licenses.go b/pkg/models/licenses.go index c7c9653..3b292a9 100644 --- a/pkg/models/licenses.go +++ b/pkg/models/licenses.go @@ -11,6 +11,7 @@ package models import ( "encoding/json" "errors" + "log" "strconv" "time" @@ -252,7 +253,7 @@ func (dto *LicenseImportDTO) ConvertToLicenseDB() LicenseDB { bytes, _ := json.Marshal(dto.ExternalRef) if err := json.Unmarshal(bytes, &ext); err != nil { - panic(err) + log.Printf("failed to unmarshal external_ref for license %v: %v", dto.Shortname, err) } l.ExternalRef = datatypes.NewJSONType(ext) diff --git a/pkg/models/obligations.go b/pkg/models/obligations.go index 3d06213..a0c7059 100644 --- a/pkg/models/obligations.go +++ b/pkg/models/obligations.go @@ -12,6 +12,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "github.com/google/uuid" "gorm.io/datatypes" @@ -67,7 +68,7 @@ func (o *Obligation) BeforeCreate(tx *gorm.DB) (err error) { } } } - if o.Type.Id.String() == "" { + if o.Type.Id == uuid.Nil || o.Type.Id.String() == "" { return fmt.Errorf("obligation type must be one of the following values:%s", allTypes) } } else { @@ -88,7 +89,7 @@ func (o *Obligation) BeforeCreate(tx *gorm.DB) (err error) { } } } - if o.Classification.Id.String() == "" { + if o.Classification.Id == uuid.Nil || o.Classification.Id.String() == "" { return fmt.Errorf("obligation classification must be one of the following values:%s", allClassifications) } } else { @@ -109,7 +110,7 @@ func (o *Obligation) BeforeCreate(tx *gorm.DB) (err error) { } } } - if o.Classification.Id.String() == "" { + if o.Category.Id == uuid.Nil || o.Category.Id.String() == "" { return fmt.Errorf("obligation category must be one of the following values:%s", allCategories) } } else { @@ -138,7 +139,7 @@ func (o *Obligation) BeforeUpdate(tx *gorm.DB) (err error) { } } } - if o.Type.Id.String() == "" { + if o.Type.Id == uuid.Nil || o.Type.Id.String() == "" { return fmt.Errorf("obligation type must be one of the following values:%s", allTypes) } } @@ -159,7 +160,7 @@ func (o *Obligation) BeforeUpdate(tx *gorm.DB) (err error) { } } } - if o.Classification.Id.String() == "" { + if o.Classification.Id == uuid.Nil || o.Classification.Id.String() == "" { return fmt.Errorf("obligation classification must be one of the following values:%s", allClassifications) } } @@ -177,7 +178,7 @@ func (o *Obligation) BeforeUpdate(tx *gorm.DB) (err error) { } } } - if o.Classification.Id.String() == "" { + if o.Category.Id == uuid.Nil || o.Category.Id.String() == "" { return fmt.Errorf("obligation category must be one of the following values:%s", allCategories) } } @@ -336,7 +337,7 @@ func (obDto *ObligationFileDTO) ConvertToObligation() Obligation { bytes, _ := json.Marshal(obDto.ExternalRef) if err := json.Unmarshal(bytes, &ext); err != nil { - panic(err) + log.Printf("failed to unmarshal external_ref for obligation %v: %v", obDto.Topic, err) } o.ExternalRef = datatypes.NewJSONType(ext) diff --git a/pkg/utils/util.go b/pkg/utils/util.go index 61a52f3..0501927 100644 --- a/pkg/utils/util.go +++ b/pkg/utils/util.go @@ -672,26 +672,19 @@ func Populatedb(datafile string) { } } -// SetSimilarityThreshold parses the env var and sets the threshold in Postgres. -func SetSimilarityThreshold() { - defaultThreshold := 0.7 - thresholdStr := os.Getenv("SIMILARITY_THRESHOLD") - - threshold := defaultThreshold - if thresholdStr != "" { - if parsed, err := strconv.ParseFloat(thresholdStr, 64); err == nil { - threshold = parsed - } else { - log.Printf("Invalid SIMILARITY_THRESHOLD '%s', using default %.1f", thresholdStr, defaultThreshold) - } - } else { - log.Printf("SIMILARITY_THRESHOLD not set, using default %.1f", defaultThreshold) - } - - query := fmt.Sprintf("SET pg_trgm.similarity_threshold = %f", threshold) - if err := db.DB.Exec(query).Error; err != nil { - log.Println("Failed to set similarity threshold:", err) +// GetSimilarityThreshold parses the env var and returns the threshold value. +func GetSimilarityThreshold() float64 { + const defaultThreshold = 0.7 + thresholdStr := strings.TrimSpace(os.Getenv("SIMILARITY_THRESHOLD")) + if thresholdStr == "" { + return defaultThreshold + } + parsed, err := strconv.ParseFloat(thresholdStr, 64) + if err != nil { + log.Printf("Invalid SIMILARITY_THRESHOLD '%s', using default %.1f", thresholdStr, defaultThreshold) + return defaultThreshold } + return parsed } // GetAuditEntity is an utility function to fetch obligation or license associated with an audit @@ -792,7 +785,7 @@ func GetKid(token string) (string, error) { parts := strings.Split(token, ".") - decodedBytes, err := base64.StdEncoding.DecodeString(parts[0]) + decodedBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) if err != nil { return "", err }