@@ -135,10 +135,10 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
135135 return c
136136}
137137
138- func (c * Client ) send (sid uint32 , mt messageType , flags uint8 , b []byte ) error {
138+ func (c * Client ) send (ctx context. Context , sid uint32 , mt messageType , flags uint8 , b []byte ) error {
139139 c .sendLock .Lock ()
140140 defer c .sendLock .Unlock ()
141- return c .channel .send (sid , mt , flags , b )
141+ return c .channel .send (ctx , sid , mt , flags , b )
142142}
143143
144144// Call makes a unary request and returns with response
@@ -214,7 +214,7 @@ func (cs *clientStream) CloseSend() error {
214214 if cs .localClosed {
215215 return ErrStreamClosed
216216 }
217- err := cs .s .send (messageTypeData , flagRemoteClosed | flagNoData , nil )
217+ err := cs .s .send (cs . ctx , messageTypeData , flagRemoteClosed | flagNoData , nil )
218218 if err != nil {
219219 return filterCloseErr (err )
220220 }
@@ -241,7 +241,7 @@ func (cs *clientStream) SendMsg(m interface{}) error {
241241 }
242242 }
243243
244- err = cs .s .send (messageTypeData , 0 , payload )
244+ err = cs .s .send (cs . ctx , messageTypeData , 0 , payload )
245245 if err != nil {
246246 return filterCloseErr (err )
247247 }
@@ -384,9 +384,9 @@ func (c *Client) receiveLoop() error {
384384 }
385385}
386386
387- // createStream creates a new stream and registers it with the client
388- // Introduce stream types for multiple or single response
389- func (c * Client ) createStream ( flags uint8 , b []byte ) (* stream , error ) {
387+ // createStreamWithContext creates a new stream and registers it with the client.
388+ // Introduce stream types for multiple or single response.
389+ func (c * Client ) createStreamWithContext ( ctx context. Context , flags uint8 , b []byte ) (* stream , error ) {
390390 // sendLock must be held across both allocation of the stream ID and sending it across the wire.
391391 // This ensures that new stream IDs sent on the wire are always increasing, which is a
392392 // requirement of the TTRPC protocol.
@@ -426,8 +426,12 @@ func (c *Client) createStream(flags uint8, b []byte) (*stream, error) {
426426 return nil , err
427427 }
428428
429- if err := c .channel .send (uint32 (s .id ), messageTypeRequest , flags , b ); err != nil {
430- return s , filterCloseErr (err )
429+ if err := c .channel .send (ctx , uint32 (s .id ), messageTypeRequest , flags , b ); err != nil {
430+ c .streamLock .Lock ()
431+ delete (c .streams , s .id )
432+ c .streamLock .Unlock ()
433+ s .closeWithError (err )
434+ return nil , filterCloseErr (err )
431435 }
432436
433437 return s , nil
@@ -517,7 +521,7 @@ func (c *Client) NewStream(ctx context.Context, desc *StreamDesc, service, metho
517521 } else {
518522 flags = flagRemoteClosed
519523 }
520- s , err := c .createStream ( flags , p )
524+ s , err := c .createStreamWithContext ( ctx , flags , p )
521525 if err != nil {
522526 return nil , err
523527 }
@@ -536,7 +540,7 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
536540 return err
537541 }
538542
539- s , err := c .createStream ( 0 , p )
543+ s , err := c .createStreamWithContext ( ctx , 0 , p )
540544 if err != nil {
541545 return err
542546 }
0 commit comments