diff --git a/service/entityresolution/multi-strategy/providers/sql/sql_config.go b/service/entityresolution/multi-strategy/providers/sql/sql_config.go index 31e2439588..08c7df4bd4 100644 --- a/service/entityresolution/multi-strategy/providers/sql/sql_config.go +++ b/service/entityresolution/multi-strategy/providers/sql/sql_config.go @@ -6,6 +6,7 @@ import ( const ( // Default SQL configuration values + defaultPostgreSQLDriver = "postgres" defaultPostgreSQLPort = 5432 defaultMaxOpenConnections = 25 defaultMaxIdleConnections = 5 @@ -45,7 +46,7 @@ type Config struct { // DefaultConfig returns a default SQL configuration func DefaultConfig() Config { return Config{ - Driver: "postgres", + Driver: defaultPostgreSQLDriver, Port: defaultPostgreSQLPort, SSLMode: "require", MaxOpenConnections: defaultMaxOpenConnections, diff --git a/service/entityresolution/multi-strategy/providers/sql/sql_provider.go b/service/entityresolution/multi-strategy/providers/sql/sql_provider.go index 47f4328edc..1cbd3324f0 100644 --- a/service/entityresolution/multi-strategy/providers/sql/sql_provider.go +++ b/service/entityresolution/multi-strategy/providers/sql/sql_provider.go @@ -5,15 +5,57 @@ import ( "database/sql" "fmt" "strings" + "sync" - // Database drivers would be imported here: - // _ "github.com/lib/pq" // PostgreSQL driver - // _ "github.com/go-sql-driver/mysql" // MySQL driver - // _ "github.com/mattn/go-sqlite3" // SQLite driver - + "github.com/jackc/pgx/v5/stdlib" "github.com/opentdf/platform/service/entityresolution/multi-strategy/types" ) +var ( + // driverRegMu guards lazy driver registration to prevent duplicate-register panics. + driverRegMu sync.Mutex + registeredDrivers = make(map[string]struct{}) +) + +// ensureDriverRegistered lazily registers the named database/sql driver the first +// time a SQL provider for that driver is created. This avoids the need for +// consumers to add blank driver imports to their own binaries. +// +// Uses pgx/v5/stdlib for postgres (already a platform dependency). Other drivers +// (mysql, sqlite) are not currently auto-registered and must be imported by the +// consumer. Consumers that have already registered the driver themselves are +// handled gracefully via a sql.Drivers() pre-check. +func ensureDriverRegistered(driver string) { + // Normalize to lowercase so "Postgres", "POSTGRES", and "postgres" all resolve + // to the same registered driver name. sql.Register is case-sensitive. + driver = strings.ToLower(strings.TrimSpace(driver)) + + driverRegMu.Lock() + defer driverRegMu.Unlock() + + if _, ok := registeredDrivers[driver]; ok { + return + } + + // Check whether the driver was already registered externally (e.g. via a + // blank import in the consumer binary) before attempting to register it. + // database/sql driver names are case-sensitive, so only an exact canonical + // match can satisfy sql.Open after the config driver is normalized. + for _, d := range sql.Drivers() { + if d == driver { + registeredDrivers[driver] = struct{}{} + return + } + } + + if driver == defaultPostgreSQLDriver { + sql.Register(defaultPostgreSQLDriver, stdlib.GetDefaultDriver()) + registeredDrivers[driver] = struct{}{} + } + // mysql and sqlite require imports not present in this module's dependencies. + // Add cases here when those drivers are added to go.mod. +} + // Provider implements the Provider interface for SQL databases type Provider struct { name string @@ -24,6 +66,15 @@ type Provider struct { // NewProvider creates a new SQL provider func NewProvider(ctx context.Context, name string, config Config) (*Provider, error) { + // Normalize the driver name so "Postgres", "POSTGRES", and "postgres" all + // resolve correctly through ensureDriverRegistered and sql.Open, both of + // which use case-sensitive driver name matching. + config.Driver = strings.ToLower(strings.TrimSpace(config.Driver)) + + // Register the database/sql driver for this provider's configured driver name + // if it has not already been registered. + ensureDriverRegistered(config.Driver) + provider := &Provider{ name: name, config: config, @@ -254,7 +305,7 @@ func (p *Provider) Close() error { // buildConnectionString creates a connection string based on the driver func (p *Provider) buildConnectionString() (string, error) { switch strings.ToLower(p.config.Driver) { - case "postgres": + case defaultPostgreSQLDriver: return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", p.config.Host, p.config.Port, p.config.Username, p.config.Password, p.config.Database, p.config.SSLMode), nil