Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions kmipclient/dialer_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
// connectionEntry tracks one server address of the cluster together with the
// time of its most recent failed dial attempt.
type connectionEntry struct {
// url is the address handed to the TLS dialer for this server.
url string
// lastError is the time of the last failed dial to url; it is reset to the
// zero time.Time after a successful fallback dial. The dialer skips this
// server while time.Since(lastError) is below the retry timeout.
lastError time.Time
// mu protects concurrent access to lastError only
mu sync.Mutex
}

func WithRetryTimeout(retryTimeout time.Duration) Option {
Expand All @@ -29,6 +31,10 @@ func DialCluster(addrs []string, options ...Option) (*Client, error) {
}

func DialClusterContext(ctx context.Context, addrs []string, options ...Option) (*Client, error) {
if len(addrs) == 0 {
return nil, fmt.Errorf("at least one server address is required")
}

opts := opts{}
for _, o := range options {
if err := o(&opts); err != nil {
Expand All @@ -48,10 +54,10 @@ func DialClusterContext(ctx context.Context, addrs []string, options ...Option)
return nil, err
}

servers := make([]connectionEntry, 0, len(addrs))
servers := make([]*connectionEntry, 0, len(addrs))
for _, url := range addrs {
slog.Info("Add server to pool", "url", url)
servers = append(servers, connectionEntry{
servers = append(servers, &connectionEntry{
url: url,
})
}
Expand All @@ -68,31 +74,49 @@ func DialClusterContext(ctx context.Context, addrs []string, options ...Option)
Config: tlsCfg,
}
for _, s := range servers {
if !time.Now().After(s.lastError.Add(*opts.retryTimeout)) {
slog.Info("Skipping server because of recent last error", "url", s.url, "last error", s.lastError)
// TOCTOU: the lock is released after the lastError check, before
// DialContext, and is re-acquired only for the error write. Two goroutines
// can therefore dial the same server concurrently. This is intentional:
// holding the lock across I/O would serialize reconnections across cloned
// clients. It is also benign, because the write is just a fresh time.Now()
// timestamp (last writer wins) and the skip check is advisory.
s.mu.Lock()
if time.Since(s.lastError) < *opts.retryTimeout {
lastErr := s.lastError
s.mu.Unlock()
slog.Info("Skipping server because of recent last error", "url", s.url, "last_error", lastErr)
continue
}
s.mu.Unlock()

conn, err := tlsDialer.DialContext(ctx, "tcp", s.url)
if err != nil {
s.lastError = time.Now()
now := time.Now()
s.mu.Lock()
s.lastError = now
s.mu.Unlock()
slog.Warn("TLS session initialization failed", "url", s.url, "error", err)
} else {
return conn, nil
}
}

// All server had an error since retryTimeout
// All servers have had an error within retryTimeout
// Call the first server to check if it went back up
conn, err := tlsDialer.DialContext(ctx, "tcp", servers[0].url)
first := servers[0]
conn, err := tlsDialer.DialContext(ctx, "tcp", first.url)
first.mu.Lock()

if err == nil {
// reset lastError since we had a success
servers[0].lastError = time.Date(0, 0, 0, 0, 0, 0, 0, time.UTC)
first.lastError = time.Time{}
first.mu.Unlock()
return conn, nil
}
servers[0].lastError = time.Now()
slog.Warn("TLS session initialization failed", "url", servers[0].url, "error", err)
return nil, fmt.Errorf("Failed to connect to servers in the connection pool")
first.lastError = time.Now()
first.mu.Unlock()
slog.Warn("TLS session initialization failed", "url", first.url, "error", err)
return nil, fmt.Errorf("failed to connect to servers in the connection pool (last attempt %q): %w", first.url, err)
}
}

Expand Down
60 changes: 60 additions & 0 deletions kmipclient/dialer_cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package kmipclient_test
import (
"context"
"log/slog"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -199,6 +200,65 @@ func TestDialCluster_DefaultRetryTimeout(t *testing.T) {
defer client.Close()
}

// TestClientConnectionPool_ConcurrentDial exercises the dialer closure from many
// goroutines at once to catch data races on the shared servers slice. The first
// server is shut down so that error-path writes to lastError happen in parallel
// with skip-check reads from the other goroutines. Run with -race to catch regressions.
func TestClientConnectionPool_ConcurrentDial(t *testing.T) {
addrs := make([]string, 0, 3)
var combinedCA []byte
servers := make([]*kmipserver.Server, 3)
for i := 0; i < 3; i++ {
h := &simpleHandler{}
addr, ca, srv := kmiptest.NewServerWithHandle(t, h)
h.id = addr
addrs = append(addrs, addr)
combinedCA = append(combinedCA, []byte(ca)...)
servers[i] = srv
}

client, err := kmipclient.DialCluster(addrs,
kmipclient.WithRootCAPem(combinedCA),
kmipclient.WithRetryTimeout(50*time.Millisecond))
require.NoError(t, err)
defer client.Close()

// Shut the first server so every clone's initial dial attempt fails on it,
// forcing concurrent writes to servers[0].lastError.
require.NoError(t, servers[0].Shutdown())
time.Sleep(20 * time.Millisecond)

const N = 32
var wg sync.WaitGroup
errCh := make(chan error, N)
for i := 0; i < N; i++ {
wg.Add(1)
go func() {
defer wg.Done()
c, err := client.Clone()
if err != nil {
errCh <- err
return
}
_ = c.Close()
}()
Comment thread
phsym marked this conversation as resolved.
}
wg.Wait()
close(errCh)
for err := range errCh {
require.NoError(t, err)
}
}

// TestDialCluster_EmptyAddrs checks that DialCluster returns an error for both
// a nil and an empty address slice, rather than panicking later when
// servers[0] is accessed.
func TestDialCluster_EmptyAddrs(t *testing.T) {
	for _, addrs := range [][]string{nil, {}} {
		_, err := kmipclient.DialCluster(addrs)
		require.Error(t, err)
	}
}

func TestClientConnectionPool_IntermittentRecovery(t *testing.T) {
// Start server with a reusable TLS listener (generate cert here so we can restart on same addr)
// generate cert
Expand Down
Loading