Skip to content

Commit a881bb0

Browse files
committed
Honor context deadline in send operations to prevent indefinite hangs
Signed-off-by: novahe <heqianfly@gmail.com>
1 parent 9074e24 commit a881bb0

File tree

6 files changed

+152
-24
lines changed

6 files changed

+152
-24
lines changed

channel.go

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@ package ttrpc
1818

1919
import (
2020
"bufio"
21+
"context"
2122
"encoding/binary"
23+
"errors"
2224
"fmt"
2325
"io"
2426
"net"
2527
"sync"
28+
"time"
2629

2730
"google.golang.org/grpc/codes"
2831
"google.golang.org/grpc/status"
@@ -142,23 +145,85 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
142145
return mh, p, nil
143146
}
144147

145-
func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error {
148+
func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, flags uint8, p []byte) error {
146149
if len(p) > messageLengthMax {
147150
return OversizedMessageError(len(p))
148151
}
149152

150-
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
153+
if ctx == nil {
154+
ctx = context.Background()
155+
}
156+
157+
if err := ctx.Err(); err != nil {
151158
return err
152159
}
153160

154-
if len(p) > 0 {
155-
_, err := ch.bw.Write(p)
156-
if err != nil {
161+
if deadline, ok := ctx.Deadline(); ok {
162+
if err := ch.conn.SetWriteDeadline(deadline); err != nil {
157163
return err
158164
}
165+
} else {
166+
if err := ch.conn.SetWriteDeadline(time.Time{}); err != nil {
167+
return err
168+
}
169+
}
170+
171+
defer ch.conn.SetWriteDeadline(time.Time{})
172+
173+
if ctx.Done() != nil {
174+
done := make(chan struct{})
175+
go func() {
176+
select {
177+
case <-ctx.Done():
178+
ch.conn.SetWriteDeadline(time.Now())
179+
case <-done:
180+
}
181+
}()
182+
defer close(done)
183+
}
184+
185+
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
186+
return ch.failSend(ctx, err)
187+
}
188+
189+
if len(p) > 0 {
190+
if _, err := ch.bw.Write(p); err != nil {
191+
return ch.failSend(ctx, err)
192+
}
193+
}
194+
195+
if err := ch.bw.Flush(); err != nil {
196+
return ch.failSend(ctx, err)
197+
}
198+
199+
return nil
200+
}
201+
202+
func mapWriteTimeout(ctx context.Context, err error) error {
203+
if err == nil {
204+
return nil
159205
}
160206

161-
return ch.bw.Flush()
207+
var netErr net.Error
208+
if errors.As(err, &netErr) && netErr.Timeout() {
209+
if ctxErr := ctx.Err(); ctxErr != nil {
210+
return ctxErr
211+
}
212+
213+
// Fallback for race condition: check if we are actually past the deadline.
214+
if d, ok := ctx.Deadline(); ok && !time.Now().Before(d) {
215+
return context.DeadlineExceeded
216+
}
217+
}
218+
219+
return err
220+
}
221+
222+
func (ch *channel) failSend(ctx context.Context, err error) error {
223+
// Any write-side failure may leave buffered bytes in an indeterminate state.
224+
// Close the connection so later sends cannot corrupt the framing stream.
225+
_ = ch.conn.Close()
226+
return mapWriteTimeout(ctx, err)
162227
}
163228

164229
func (ch *channel) getmbuf(size int) []byte {

channel_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package ttrpc
1818

1919
import (
2020
"bytes"
21+
"context"
2122
"errors"
2223
"io"
2324
"net"
@@ -44,7 +45,7 @@ func TestReadWriteMessage(t *testing.T) {
4445

4546
go func() {
4647
for i, msg := range messages {
47-
if err := ch.send(uint32(i), 1, 0, msg); err != nil {
48+
if err := ch.send(context.Background(), uint32(i), 1, 0, msg); err != nil {
4849
errs <- err
4950
return
5051
}
@@ -96,7 +97,7 @@ func TestMessageOversize(t *testing.T) {
9697
)
9798

9899
go func() {
99-
errs <- wch.send(1, 1, 0, msg)
100+
errs <- wch.send(context.Background(), 1, 1, 0, msg)
100101
}()
101102

102103
err := <-errs

client.go

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,10 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
135135
return c
136136
}
137137

138-
func (c *Client) send(sid uint32, mt messageType, flags uint8, b []byte) error {
138+
func (c *Client) send(ctx context.Context, sid uint32, mt messageType, flags uint8, b []byte) error {
139139
c.sendLock.Lock()
140140
defer c.sendLock.Unlock()
141-
return c.channel.send(sid, mt, flags, b)
141+
return c.channel.send(ctx, sid, mt, flags, b)
142142
}
143143

144144
// Call makes a unary request and returns with response
@@ -214,7 +214,7 @@ func (cs *clientStream) CloseSend() error {
214214
if cs.localClosed {
215215
return ErrStreamClosed
216216
}
217-
err := cs.s.send(messageTypeData, flagRemoteClosed|flagNoData, nil)
217+
err := cs.s.send(cs.ctx, messageTypeData, flagRemoteClosed|flagNoData, nil)
218218
if err != nil {
219219
return filterCloseErr(err)
220220
}
@@ -241,7 +241,7 @@ func (cs *clientStream) SendMsg(m interface{}) error {
241241
}
242242
}
243243

244-
err = cs.s.send(messageTypeData, 0, payload)
244+
err = cs.s.send(cs.ctx, messageTypeData, 0, payload)
245245
if err != nil {
246246
return filterCloseErr(err)
247247
}
@@ -384,9 +384,9 @@ func (c *Client) receiveLoop() error {
384384
}
385385
}
386386

387-
// createStream creates a new stream and registers it with the client
388-
// Introduce stream types for multiple or single response
389-
func (c *Client) createStream(flags uint8, b []byte) (*stream, error) {
387+
// createStreamWithContext creates a new stream and registers it with the client.
388+
// Introduce stream types for multiple or single response.
389+
func (c *Client) createStreamWithContext(ctx context.Context, flags uint8, b []byte) (*stream, error) {
390390
// sendLock must be held across both allocation of the stream ID and sending it across the wire.
391391
// This ensures that new stream IDs sent on the wire are always increasing, which is a
392392
// requirement of the TTRPC protocol.
@@ -426,8 +426,12 @@ func (c *Client) createStream(flags uint8, b []byte) (*stream, error) {
426426
return nil, err
427427
}
428428

429-
if err := c.channel.send(uint32(s.id), messageTypeRequest, flags, b); err != nil {
430-
return s, filterCloseErr(err)
429+
if err := c.channel.send(ctx, uint32(s.id), messageTypeRequest, flags, b); err != nil {
430+
c.streamLock.Lock()
431+
delete(c.streams, s.id)
432+
c.streamLock.Unlock()
433+
s.closeWithError(err)
434+
return nil, filterCloseErr(err)
431435
}
432436

433437
return s, nil
@@ -517,7 +521,7 @@ func (c *Client) NewStream(ctx context.Context, desc *StreamDesc, service, metho
517521
} else {
518522
flags = flagRemoteClosed
519523
}
520-
s, err := c.createStream(flags, p)
524+
s, err := c.createStreamWithContext(ctx, flags, p)
521525
if err != nil {
522526
return nil, err
523527
}
@@ -536,7 +540,7 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
536540
return err
537541
}
538542

539-
s, err := c.createStream(0, p)
543+
s, err := c.createStreamWithContext(ctx, 0, p)
540544
if err != nil {
541545
return err
542546
}

client_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ package ttrpc
1818

1919
import (
2020
"context"
21+
"errors"
22+
"net"
2123
"testing"
2224
"time"
2325

@@ -70,3 +72,59 @@ func TestUserOnCloseWait(t *testing.T) {
7072
t.Fatalf("expected error nil , but got %v", err)
7173
}
7274
}
75+
76+
func TestCallSendBlocked(t *testing.T) {
77+
verifyCleanup := func(t *testing.T, client *Client) {
78+
t.Helper()
79+
client.streamLock.RLock()
80+
streamsLen := len(client.streams)
81+
client.streamLock.RUnlock()
82+
if streamsLen != 0 {
83+
t.Fatalf("expected no active streams after send failure, got %d", streamsLen)
84+
}
85+
86+
waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second)
87+
defer waitCancel()
88+
if err := client.UserOnCloseWait(waitCtx); err != nil {
89+
t.Fatalf("expected client to close after send failure, got %v", err)
90+
}
91+
}
92+
93+
t.Run("Timeout", func(t *testing.T) {
94+
serverConn, clientConn := net.Pipe()
95+
client := NewClient(clientConn)
96+
defer serverConn.Close()
97+
defer client.Close()
98+
99+
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
100+
defer cancel()
101+
102+
err := client.Call(ctx, "service", "method", &internal.TestPayload{}, &internal.TestPayload{})
103+
if !errors.Is(err, context.DeadlineExceeded) {
104+
t.Fatalf("expected error %v, got %v", context.DeadlineExceeded, err)
105+
}
106+
107+
verifyCleanup(t, client)
108+
})
109+
110+
t.Run("Cancel", func(t *testing.T) {
111+
serverConn, clientConn := net.Pipe()
112+
client := NewClient(clientConn)
113+
defer serverConn.Close()
114+
defer client.Close()
115+
116+
ctx, cancel := context.WithCancel(context.Background())
117+
defer cancel()
118+
go func() {
119+
time.Sleep(100 * time.Millisecond)
120+
cancel()
121+
}()
122+
123+
err := client.Call(ctx, "service", "method", &internal.TestPayload{}, &internal.TestPayload{})
124+
if !errors.Is(err, context.Canceled) {
125+
t.Fatalf("expected error %v, got %v", context.Canceled, err)
126+
}
127+
128+
verifyCleanup(t, client)
129+
})
130+
}

server.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ func (c *serverConn) run(sctx context.Context) {
525525
return
526526
}
527527

528-
if err := ch.send(response.id, messageTypeResponse, 0, p); err != nil {
528+
if err := ch.send(ctx, response.id, messageTypeResponse, 0, p); err != nil {
529529
log.G(ctx).WithError(err).Error("failed sending message on channel")
530530
return
531531
}
@@ -537,7 +537,7 @@ func (c *serverConn) run(sctx context.Context) {
537537
if response.data == nil {
538538
flags = flags | flagNoData
539539
}
540-
if err := ch.send(response.id, messageTypeData, flags, response.data); err != nil {
540+
if err := ch.send(ctx, response.id, messageTypeData, flags, response.data); err != nil {
541541
log.G(ctx).WithError(err).Error("failed sending message on channel")
542542
return
543543
}

stream.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ func (s *stream) closeWithError(err error) error {
5959
return nil
6060
}
6161

62-
func (s *stream) send(mt messageType, flags uint8, b []byte) error {
63-
return s.sender.send(uint32(s.id), mt, flags, b)
62+
func (s *stream) send(ctx context.Context, mt messageType, flags uint8, b []byte) error {
63+
return s.sender.send(ctx, uint32(s.id), mt, flags, b)
6464
}
6565

6666
func (s *stream) receive(ctx context.Context, msg *streamMessage) error {
@@ -80,5 +80,5 @@ func (s *stream) receive(ctx context.Context, msg *streamMessage) error {
8080
}
8181

8282
type sender interface {
83-
send(uint32, messageType, uint8, []byte) error
83+
send(context.Context, uint32, messageType, uint8, []byte) error
8484
}

0 commit comments

Comments
 (0)