diff --git a/kmipclient/dialer_cluster.go b/kmipclient/dialer_cluster.go index 44d1c98..13158a6 100644 --- a/kmipclient/dialer_cluster.go +++ b/kmipclient/dialer_cluster.go @@ -15,6 +15,8 @@ import ( type connectionEntry struct { url string lastError time.Time + // mu protects concurrent access to lastError only + mu sync.Mutex } func WithRetryTimeout(retryTimeout time.Duration) Option { @@ -29,6 +31,10 @@ func DialCluster(addrs []string, options ...Option) (*Client, error) { } func DialClusterContext(ctx context.Context, addrs []string, options ...Option) (*Client, error) { + if len(addrs) == 0 { + return nil, fmt.Errorf("at least one server address is required") + } + opts := opts{} for _, o := range options { if err := o(&opts); err != nil { @@ -48,10 +54,10 @@ func DialClusterContext(ctx context.Context, addrs []string, options ...Option) return nil, err } - servers := make([]connectionEntry, 0, len(addrs)) + servers := make([]*connectionEntry, 0, len(addrs)) for _, url := range addrs { slog.Info("Add server to pool", "url", url) - servers = append(servers, connectionEntry{ + servers = append(servers, &connectionEntry{ url: url, }) } @@ -68,31 +74,49 @@ func DialClusterContext(ctx context.Context, addrs []string, options ...Option) Config: tlsCfg, } for _, s := range servers { - if !time.Now().After(s.lastError.Add(*opts.retryTimeout)) { - slog.Info("Skipping server because of recent last error", "url", s.url, "last error", s.lastError) + // TOCTOU: the lock guarding lastError is released before DialContext and + // re-acquired only for the error write, so two goroutines can + // dial the same server concurrently. This is intentional (holding the + // lock across I/O would serialize reconnections across cloned clients) + // and benign: the write is an idempotent time.Now() and the + // skip check is advisory.
+ s.mu.Lock() + if time.Since(s.lastError) < *opts.retryTimeout { + lastErr := s.lastError + s.mu.Unlock() + slog.Info("Skipping server because of recent last error", "url", s.url, "last_error", lastErr) continue } + s.mu.Unlock() conn, err := tlsDialer.DialContext(ctx, "tcp", s.url) if err != nil { - s.lastError = time.Now() + now := time.Now() + s.mu.Lock() + s.lastError = now + s.mu.Unlock() slog.Warn("TLS session initialization failed", "url", s.url, "error", err) } else { return conn, nil } } - // All server had an error since retryTimeout + // All servers have had an error within retryTimeout // Call the first server to check if it went back up - conn, err := tlsDialer.DialContext(ctx, "tcp", servers[0].url) + first := servers[0] + conn, err := tlsDialer.DialContext(ctx, "tcp", first.url) + first.mu.Lock() + if err == nil { // reset lastError since we had a success - servers[0].lastError = time.Date(0, 0, 0, 0, 0, 0, 0, time.UTC) + first.lastError = time.Time{} + first.mu.Unlock() return conn, nil } - servers[0].lastError = time.Now() - slog.Warn("TLS session initialization failed", "url", servers[0].url, "error", err) - return nil, fmt.Errorf("Failed to connect to servers in the connection pool") + first.lastError = time.Now() + first.mu.Unlock() + slog.Warn("TLS session initialization failed", "url", first.url, "error", err) + return nil, fmt.Errorf("failed to connect to servers in the connection pool (last attempt %q): %w", first.url, err) } } diff --git a/kmipclient/dialer_cluster_test.go b/kmipclient/dialer_cluster_test.go index d8d486c..079d497 100644 --- a/kmipclient/dialer_cluster_test.go +++ b/kmipclient/dialer_cluster_test.go @@ -3,6 +3,7 @@ package kmipclient_test import ( "context" "log/slog" + "sync" "testing" "time" @@ -199,6 +200,65 @@ func TestDialCluster_DefaultRetryTimeout(t *testing.T) { defer client.Close() } +// TestClientConnectionPool_ConcurrentDial exercises the dialer closure from many +// goroutines at once to catch data races on the 
shared servers slice. The first +// server is shut down so that error-path writes to lastError happen in parallel +// with skip-check reads from the other goroutines. Run with -race to catch regressions. +func TestClientConnectionPool_ConcurrentDial(t *testing.T) { + addrs := make([]string, 0, 3) + var combinedCA []byte + servers := make([]*kmipserver.Server, 3) + for i := 0; i < 3; i++ { + h := &simpleHandler{} + addr, ca, srv := kmiptest.NewServerWithHandle(t, h) + h.id = addr + addrs = append(addrs, addr) + combinedCA = append(combinedCA, []byte(ca)...) + servers[i] = srv + } + + client, err := kmipclient.DialCluster(addrs, + kmipclient.WithRootCAPem(combinedCA), + kmipclient.WithRetryTimeout(50*time.Millisecond)) + require.NoError(t, err) + defer client.Close() + + // Shut the first server so every clone's initial dial attempt fails on it, + // forcing concurrent writes to servers[0].lastError. + require.NoError(t, servers[0].Shutdown()) + time.Sleep(20 * time.Millisecond) + + const N = 32 + var wg sync.WaitGroup + errCh := make(chan error, N) + for i := 0; i < N; i++ { + wg.Add(1) + go func() { + defer wg.Done() + c, err := client.Clone() + if err != nil { + errCh <- err + return + } + _ = c.Close() + }() + } + wg.Wait() + close(errCh) + for err := range errCh { + require.NoError(t, err) + } +} + +// TestDialCluster_EmptyAddrs verifies DialCluster rejects an empty address list +// instead of panicking later on servers[0] access. +func TestDialCluster_EmptyAddrs(t *testing.T) { + _, err := kmipclient.DialCluster(nil) + require.Error(t, err) + _, err = kmipclient.DialCluster([]string{}) + require.Error(t, err) +} + func TestClientConnectionPool_IntermittentRecovery(t *testing.T) { // Start server with a reusable TLS listener (generate cert here so we can restart on same addr) // generate cert