Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ type Client struct {
// This field is only effective within the range of MaxIdemponentCallAttempts.
RetryIfErr RetryIfErrFunc

// RetryIfErrUpstream works just like RetryIfErr but also provides information about which upstream caused the error, if known.
// Upstream information is a <host>:<port> format.
RetryIfErrUpstream RetryIfErrUpstreamFunc

// ConfigureClient configures the fasthttp.HostClient.
ConfigureClient func(hc *HostClient) error

Expand Down Expand Up @@ -565,6 +569,7 @@ func (c *Client) hostClient(host []byte, isTLS bool) (*HostClient, error) {
MaxConnWaitTimeout: c.MaxConnWaitTimeout,
RetryIf: c.RetryIf,
RetryIfErr: c.RetryIfErr,
RetryIfErrUpstream: c.RetryIfErrUpstream,
ConnPoolStrategy: c.ConnPoolStrategy,
StreamResponseBody: c.StreamResponseBody,
clientReaderPool: &c.readerPool,
Expand Down Expand Up @@ -693,6 +698,12 @@ type RetryIfFunc func(request *Request) bool
// the request function will immediately return with the `err`.
type RetryIfErrFunc func(request *Request, attempts int, err error) (resetTimeout bool, retry bool)

// RetryIfErrUpstreamFunc works just like a RetryIfErrFunc and also provides
// information about which upstream caused the error, if known.
//
// Upstream information is a <host>:<port> format.
type RetryIfErrUpstreamFunc func(request *Request, attempts int, err error, upstream string) (resetTimeout bool, retry bool)

// RoundTripper wraps every request/response.
type RoundTripper interface {
RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error)
Expand Down Expand Up @@ -754,6 +765,10 @@ type HostClient struct {
// This field is only effective within the range of MaxIdemponentCallAttempts.
RetryIfErr RetryIfErrFunc

// RetryIfErrUpstream works just like RetryIfErr but also provides information about which upstream causes the error, if known.
// Upstream information is a <host>:<port> format.
RetryIfErrUpstream RetryIfErrUpstreamFunc

connsWait *wantConnQueue

tlsConfigMap map[string]*tls.Config
Expand Down Expand Up @@ -1393,9 +1408,16 @@ func (c *HostClient) Do(req *Request, resp *Response) error {
if attempts >= maxAttempts {
break
}
if c.RetryIfErr != nil {
switch {
case c.RetryIfErrUpstream != nil:
upstream := ""
if resp.RemoteAddr() != nil {
upstream = resp.RemoteAddr().String()
}
resetTimeout, retry = c.RetryIfErrUpstream(req, attempts, err, upstream)
case c.RetryIfErr != nil:
resetTimeout, retry = c.RetryIfErr(req, attempts, err)
} else {
default:
retry = retryFunc(req)
}
if !retry {
Expand Down
76 changes: 76 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3615,3 +3615,79 @@ func (r *testResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPA
r.lookupCountByHost[host]++
return r.resolver.LookupIPAddr(ctx, host)
}

type TransportMock struct {
wrapperFunc func(hc *HostClient, req *Request, resp *Response) (retry bool, err error)
}

func (t *TransportMock) RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error) {
return t.wrapperFunc(hc, req, resp)
}

func TestClient_RetryIfErrUpstream(t *testing.T) {
t.Parallel()
upstreamErr := errors.New("upstream error")

t.Run("upstream_known", func(t *testing.T) {
retryIfErrCalled := false
c := &Client{
Transport: &TransportMock{
wrapperFunc: func(hc *HostClient, req *Request, resp *Response) (retry bool, err error) {
resp.raddr = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080}
return true, upstreamErr
},
},
RetryIfErrUpstream: func(request *Request, attempts int, err error, upstream string) (resetTimeout bool, retry bool) {
retryIfErrCalled = true
if upstream != "127.0.0.1:8080" {
t.Errorf("expected upstream to be 127.0.0.1:8080, got %s", upstream)
}

return false, false
},
}
req := AcquireRequest()
res := AcquireResponse()

req.SetRequestURI("http://example.com")

err := c.Do(req, res)
if !errors.Is(err, upstreamErr) {
t.Fatal(err)
}
if !retryIfErrCalled {
t.Fatal("RetryIfErrUpstream should be called")
}
})

t.Run("no_upstream", func(t *testing.T) {
retryIfErrCalled := false
c := &Client{
Transport: &TransportMock{
wrapperFunc: func(hc *HostClient, req *Request, resp *Response) (retry bool, err error) {
return true, upstreamErr
},
},
RetryIfErrUpstream: func(request *Request, attempts int, err error, upstream string) (resetTimeout bool, retry bool) {
retryIfErrCalled = true
if upstream != "" {
t.Errorf("expected upstream to be empty, got %s", upstream)
}

return false, false
},
}
req := AcquireRequest()
res := AcquireResponse()

req.SetRequestURI("http://example.com")

err := c.Do(req, res)
if !errors.Is(err, upstreamErr) {
t.Fatal(err)
}
if !retryIfErrCalled {
t.Fatal("RetryIfErrUpstream should be called")
}
})
}