Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

const (
// Default SQL configuration values
defaultPostgreSQLDriver = "postgres"
defaultPostgreSQLPort = 5432
defaultMaxOpenConnections = 25
defaultMaxIdleConnections = 5
Expand Down Expand Up @@ -45,7 +46,7 @@ type Config struct {
// DefaultConfig returns a default SQL configuration
func DefaultConfig() Config {
return Config{
Driver: "postgres",
Driver: defaultPostgreSQLDriver,
Port: defaultPostgreSQLPort,
SSLMode: "require",
MaxOpenConnections: defaultMaxOpenConnections,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,56 @@ import (
"database/sql"
"fmt"
"strings"
"sync"

// Database drivers would be imported here:
// _ "github.com/lib/pq" // PostgreSQL driver
// _ "github.com/go-sql-driver/mysql" // MySQL driver
// _ "github.com/mattn/go-sqlite3" // SQLite driver

"github.com/jackc/pgx/v5/stdlib"
"github.com/opentdf/platform/service/entityresolution/multi-strategy/types"
)

var (
// driverRegMu guards lazy driver registration to prevent duplicate-register panics.
driverRegMu sync.Mutex
registeredDrivers = make(map[string]struct{})
)

// ensureDriverRegistered lazily registers the named database/sql driver the first
// time a SQL provider for that driver is created. This avoids the need for
// consumers to add blank driver imports to their own binaries.
//
// Uses pgx/v5/stdlib for postgres (already a platform dependency). Other drivers
// (mysql, sqlite) are not currently auto-registered and must be imported by the
// consumer. Consumers that have already registered the driver themselves are
// handled gracefully via a sql.Drivers() pre-check.
func ensureDriverRegistered(driver string) {
// Normalize to lowercase so "Postgres", "POSTGRES", and "postgres" all resolve
// to the same registered driver name. sql.Register is case-sensitive.
driver = strings.ToLower(strings.TrimSpace(driver))
Comment thread
coderabbitai[bot] marked this conversation as resolved.

driverRegMu.Lock()
defer driverRegMu.Unlock()

if _, ok := registeredDrivers[driver]; ok {
return
}

// Check whether the driver was already registered externally (e.g. via a
// blank import in the consumer binary) before attempting to register it.
// Use strings.EqualFold so the pre-check is also case-insensitive.
for _, d := range sql.Drivers() {
if strings.EqualFold(d, driver) {
registeredDrivers[driver] = struct{}{}
return
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
}

if driver == defaultPostgreSQLDriver {
sql.Register(defaultPostgreSQLDriver, stdlib.GetDefaultDriver())
registeredDrivers[driver] = struct{}{}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
// mysql and sqlite require imports not present in this module's dependencies.
// Add cases here when those drivers are added to go.mod.
}

// Provider implements the Provider interface for SQL databases
type Provider struct {
name string
Expand All @@ -24,6 +65,15 @@ type Provider struct {

// NewProvider creates a new SQL provider
func NewProvider(ctx context.Context, name string, config Config) (*Provider, error) {
// Normalize the driver name so "Postgres", "POSTGRES", and "postgres" all
// resolve correctly through ensureDriverRegistered and sql.Open, both of
// which use case-sensitive driver name matching.
config.Driver = strings.ToLower(strings.TrimSpace(config.Driver))

// Register the database/sql driver for this provider's configured driver name
// if it has not already been registered.
ensureDriverRegistered(config.Driver)
Comment thread
jp-ayyappan marked this conversation as resolved.

provider := &Provider{
name: name,
config: config,
Expand Down Expand Up @@ -254,7 +304,7 @@ func (p *Provider) Close() error {
// buildConnectionString creates a connection string based on the driver
func (p *Provider) buildConnectionString() (string, error) {
switch strings.ToLower(p.config.Driver) {
case "postgres":
case defaultPostgreSQLDriver:
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
p.config.Host, p.config.Port, p.config.Username, p.config.Password,
p.config.Database, p.config.SSLMode), nil
Expand Down
Loading