Skip to content

Commit 5086005

Browse files
committed
wip: add upstream heartbeats
1 parent f693650 commit 5086005

6 files changed

Lines changed: 281 additions & 48 deletions

File tree

v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,8 @@ func NewGraphQLSubscriptionClient(ctx context.Context, opts ...SubscriptionClien
114114
UpgradeClient: cfg.UpgradeClient,
115115
StreamingClient: cfg.StreamingClient,
116116
Logger: cfg.Logger,
117+
PingInterval: cfg.PingInterval,
118+
PingTimeout: cfg.PingTimeout,
117119
}),
118120
}
119121
}

v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client/client.go

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"net/http"
7+
"time"
78

89
"github.com/jensneuse/abstractlogger"
910

@@ -35,6 +36,8 @@ type Config struct {
3536
UpgradeClient *http.Client
3637
StreamingClient *http.Client
3738
Logger abstractlogger.Logger
39+
PingInterval time.Duration
40+
PingTimeout time.Duration
3841
}
3942

4043
// New creates a new subscription client with the provided config.
@@ -53,7 +56,7 @@ func New(ctx context.Context, cfg Config) *Client {
5356
ctx: ctx,
5457
log: cfg.Logger,
5558

56-
ws: transport.NewWSTransport(ctx, cfg.UpgradeClient, cfg.Logger),
59+
ws: transport.NewWSTransport(ctx, cfg.UpgradeClient, cfg.Logger, cfg.PingInterval, cfg.PingTimeout),
5760
sse: transport.NewSSETransport(ctx, cfg.StreamingClient, cfg.Logger),
5861
}
5962

v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go

Lines changed: 42 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -27,8 +27,6 @@ type WSConnection struct {
2727
protocol protocol.Protocol
2828
log abstractlogger.Logger
2929

30-
writeMu sync.Mutex
31-
3230
subsMu sync.RWMutex
3331
subs map[string]chan<- *common.Message
3432

@@ -39,14 +37,19 @@ type WSConnection struct {
3937
onEmpty func()
4038

4139
WriteTimeout time.Duration
40+
41+
// Ping/pong tracking for client-initiated heartbeats.
42+
// Values stored as UnixNano timestamps.
43+
lastPingSentAt atomic.Int64
44+
lastPongAt atomic.Int64
4245
}
4346

4447
func NewWSConnection(ctx context.Context, conn *websocket.Conn, protocol protocol.Protocol, log abstractlogger.Logger, onEmpty func()) *WSConnection {
4548
if log == nil {
4649
log = abstractlogger.NoopLogger
4750
}
4851

49-
return &WSConnection{
52+
c := &WSConnection{
5053
ctx: ctx,
5154
conn: conn,
5255
protocol: protocol,
@@ -57,6 +60,10 @@ func NewWSConnection(ctx context.Context, conn *websocket.Conn, protocol protoco
5760

5861
WriteTimeout: DefaultWriteTimeout,
5962
}
63+
64+
c.lastPongAt.Store(time.Now().UnixNano())
65+
66+
return c
6067
}
6168

6269
func (c *WSConnection) Subscribe(ctx context.Context, id string, req *common.Request) (<-chan *common.Message, func(), error) {
@@ -77,9 +84,7 @@ func (c *WSConnection) Subscribe(ctx context.Context, id string, req *common.Req
7784
c.subs[id] = ch
7885
c.subsMu.Unlock()
7986

80-
if err := c.withWriteLock(func() error {
81-
return c.protocol.Subscribe(ctx, c.conn, id, req)
82-
}); err != nil {
87+
if err := c.protocol.Subscribe(ctx, c.conn, id, req); err != nil {
8388
c.log.Error("wsConnection.Subscribe",
8489
abstractlogger.String("id", id),
8590
abstractlogger.Error(err),
@@ -110,9 +115,6 @@ func (c *WSConnection) removeSub(id string) {
110115
}
111116

112117
if isEmpty {
113-
if c.onEmpty != nil {
114-
c.onEmpty()
115-
}
116118
c.Close()
117119
}
118120
}
@@ -131,24 +133,11 @@ func (c *WSConnection) unsubscribe(id string) {
131133
unsubscribeCtx, cancel := context.WithTimeout(context.Background(), c.WriteTimeout)
132134
defer cancel()
133135

134-
_ = c.withWriteLock(func() error {
135-
return c.protocol.Unsubscribe(unsubscribeCtx, c.conn, id)
136-
})
136+
_ = c.protocol.Unsubscribe(unsubscribeCtx, c.conn, id)
137137

138138
c.removeSub(id)
139139
}
140140

141-
func (c *WSConnection) withWriteLock(f func() error) error {
142-
c.writeMu.Lock()
143-
defer c.writeMu.Unlock()
144-
145-
if c.closed.Load() {
146-
return common.ErrConnectionClosed
147-
}
148-
149-
return f()
150-
}
151-
152141
func (c *WSConnection) ReadLoop() {
153142
defer c.shutdown(errors.New("read loop exited"))
154143

@@ -171,12 +160,10 @@ func (c *WSConnection) ReadLoop() {
171160
case protocol.MessagePing:
172161
c.log.Debug("wsConnection.ReadLoop", abstractlogger.String("message", "ping"))
173162
pongCtx, cancel := context.WithTimeout(c.ctx, c.WriteTimeout)
174-
_ = c.withWriteLock(func() error {
175-
return c.protocol.Pong(pongCtx, c.conn)
176-
})
163+
_ = c.protocol.Pong(pongCtx, c.conn)
177164
cancel()
178165
case protocol.MessagePong:
179-
// Do nothing, pongs can sometimes be used as unidirectional heartbeats
166+
c.lastPongAt.Store(time.Now().UnixNano())
180167
c.log.Debug("wsConnection.ReadLoop", abstractlogger.String("message", "pong"))
181168
case protocol.MessageData, protocol.MessageError, protocol.MessageComplete:
182169
c.dispatch(msg)
@@ -229,6 +216,10 @@ func (c *WSConnection) shutdown(err error) {
229216
}
230217
close(ch)
231218
}
219+
220+
if c.onEmpty != nil {
221+
c.onEmpty()
222+
}
232223
}
233224

234225
func (c *WSConnection) Close() error {
@@ -250,6 +241,30 @@ func (c *WSConnection) SubCount() int {
250241
return len(c.subs)
251242
}
252243

244+
// SendPing sends a protocol-level ping message and records the timestamp.
245+
func (c *WSConnection) SendPing(timeout time.Duration) error {
246+
pingCtx, cancel := context.WithTimeout(c.ctx, timeout)
247+
defer cancel()
248+
249+
err := c.protocol.Ping(pingCtx, c.conn)
250+
if err != nil {
251+
return err
252+
}
253+
254+
c.lastPingSentAt.Store(time.Now().UnixNano())
255+
return nil
256+
}
257+
258+
// PongOverdue returns true if a pong has not been received since the last ping
259+
// and the ping timeout has elapsed.
260+
func (c *WSConnection) PongOverdue(timeout time.Duration) bool {
261+
pingSent := c.lastPingSentAt.Load()
262+
if pingSent == 0 {
263+
return false
264+
}
265+
return c.lastPongAt.Load() < pingSent && time.Since(time.Unix(0, pingSent)) > timeout
266+
}
267+
253268
// IsClosed reports the current value of the connection's closed flag.
func (c *WSConnection) IsClosed() bool {
	closed := c.closed.Load()
	return closed
}

v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go

Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -297,6 +297,52 @@ func TestWSConnection_OnEmpty(t *testing.T) {
297297
t.Error("onEmpty should be called after last subscription removed")
298298
}
299299
})
300+
301+
t.Run("calls callback on direct Close", func(t *testing.T) {
302+
t.Parallel()
303+
304+
conn, _ := newTestConn(t)
305+
proto := newMockProtocol()
306+
307+
emptyCalled := make(chan struct{}, 1)
308+
wsc := transport.NewWSConnection(t.Context(), conn, proto, nil, func() {
309+
emptyCalled <- struct{}{}
310+
})
311+
312+
wsc.Close()
313+
314+
select {
315+
case <-emptyCalled:
316+
// success
317+
case <-time.After(100 * time.Millisecond):
318+
t.Error("onEmpty callback not called on Close")
319+
}
320+
})
321+
322+
t.Run("calls callback on read loop exit", func(t *testing.T) {
323+
t.Parallel()
324+
325+
conn, _ := newTestConn(t)
326+
proto := newMockProtocol()
327+
328+
emptyCalled := make(chan struct{}, 1)
329+
ctx, cancel := context.WithCancel(context.Background())
330+
wsc := transport.NewWSConnection(ctx, conn, proto, nil, func() {
331+
emptyCalled <- struct{}{}
332+
})
333+
334+
go wsc.ReadLoop()
335+
336+
// Cancel context to cause the read loop to exit
337+
cancel()
338+
339+
select {
340+
case <-emptyCalled:
341+
// success
342+
case <-time.After(time.Second):
343+
t.Error("onEmpty callback not called on read loop exit")
344+
}
345+
})
300346
}
301347

302348
func TestWSConnection_Close(t *testing.T) {

v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go

Lines changed: 55 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"net/http"
88
"sync"
9+
"time"
910

1011
"github.com/coder/websocket"
1112
"github.com/jensneuse/abstractlogger"
@@ -36,6 +37,9 @@ type WSTransport struct {
3637
upgradeClient *http.Client
3738
log abstractlogger.Logger
3839

40+
pingInterval time.Duration
41+
pingTimeout time.Duration
42+
3943
mu sync.Mutex
4044
dialing map[uint64]*dialResult
4145
conns map[uint64]*WSConnection
@@ -50,7 +54,11 @@ type dialResult struct {
5054
// NewWSTransport creates a new WSTransport with the provided http.Client
5155
// for WebSocket upgrade requests. The transport will automatically close
5256
// all connections when ctx is cancelled.
53-
func NewWSTransport(ctx context.Context, upgradeClient *http.Client, log abstractlogger.Logger) *WSTransport {
57+
//
58+
// If pingInterval > 0, a single goroutine sends protocol-level pings to all
59+
// connections at that cadence. If pingTimeout > 0, connections that fail to
60+
// respond with a pong within that window after a ping are shut down.
61+
func NewWSTransport(ctx context.Context, upgradeClient *http.Client, log abstractlogger.Logger, pingInterval, pingTimeout time.Duration) *WSTransport {
5462
if log == nil {
5563
log = abstractlogger.NoopLogger
5664
}
@@ -59,12 +67,18 @@ func NewWSTransport(ctx context.Context, upgradeClient *http.Client, log abstrac
5967
ctx: ctx,
6068
upgradeClient: upgradeClient,
6169
log: log,
70+
pingInterval: pingInterval,
71+
pingTimeout: pingTimeout,
6272
conns: make(map[uint64]*WSConnection),
6373
dialing: make(map[uint64]*dialResult),
6474
}
6575

6676
context.AfterFunc(ctx, t.closeAll)
6777

78+
if pingInterval > 0 {
79+
go t.pingLoop()
80+
}
81+
6882
return t
6983
}
7084

@@ -81,9 +95,6 @@ func (t *WSTransport) Subscribe(ctx context.Context, req *common.Request, opts c
8195
// closeAll closes all connections. Called automatically when context is cancelled.
8296
func (t *WSTransport) closeAll() {
8397
t.mu.Lock()
84-
85-
// Copy because conn.Close -> shutdown -> onEmpty -> t.removeConn -> t.mu.Lock
86-
// would cause a deadlock
8798
conns := make([]*WSConnection, 0, len(t.conns))
8899
for _, conn := range t.conns {
89100
conns = append(conns, conn)
@@ -102,6 +113,46 @@ func (t *WSTransport) closeAll() {
102113
}
103114
}
104115

116+
// pingLoop sends periodic pings to all active connections and shuts down
117+
// any that have not responded with a pong in time.
118+
func (t *WSTransport) pingLoop() {
119+
tick := time.Tick(t.pingInterval)
120+
for {
121+
select {
122+
case <-t.ctx.Done():
123+
return
124+
case <-tick:
125+
t.mu.Lock()
126+
conns := make([]*WSConnection, 0, len(t.conns))
127+
for _, conn := range t.conns {
128+
conns = append(conns, conn)
129+
}
130+
t.mu.Unlock()
131+
132+
for _, conn := range conns {
133+
if conn.IsClosed() {
134+
continue
135+
}
136+
137+
if t.pingTimeout > 0 && conn.PongOverdue(t.pingTimeout) {
138+
t.log.Debug("wsTransport.pingLoop",
139+
abstractlogger.String("action", "pong_timeout"),
140+
)
141+
conn.Close()
142+
continue
143+
}
144+
145+
if err := conn.SendPing(DefaultWriteTimeout); err != nil {
146+
t.log.Debug("wsTransport.pingLoop",
147+
abstractlogger.String("action", "ping_failed"),
148+
abstractlogger.Error(err),
149+
)
150+
}
151+
}
152+
}
153+
}
154+
}
155+
105156
func (t *WSTransport) ConnCount() int {
106157
t.mu.Lock()
107158
defer t.mu.Unlock()

0 commit comments

Comments
 (0)