diff --git a/client.go b/client.go index 6d16696def..6b1fd7cff6 100644 --- a/client.go +++ b/client.go @@ -216,6 +216,10 @@ type Client struct { // This field is only effective within the range of MaxIdemponentCallAttempts. RetryIfErr RetryIfErrFunc + // RetryIfErrUpstream works just like RetryIfErr but also provides information about which upstream caused the error, if known. + // Upstream information is a : format. + RetryIfErrUpstream RetryIfErrUpstreamFunc + // ConfigureClient configures the fasthttp.HostClient. ConfigureClient func(hc *HostClient) error @@ -565,6 +569,7 @@ func (c *Client) hostClient(host []byte, isTLS bool) (*HostClient, error) { MaxConnWaitTimeout: c.MaxConnWaitTimeout, RetryIf: c.RetryIf, RetryIfErr: c.RetryIfErr, + RetryIfErrUpstream: c.RetryIfErrUpstream, ConnPoolStrategy: c.ConnPoolStrategy, StreamResponseBody: c.StreamResponseBody, clientReaderPool: &c.readerPool, @@ -693,6 +698,12 @@ type RetryIfFunc func(request *Request) bool // the request function will immediately return with the `err`. type RetryIfErrFunc func(request *Request, attempts int, err error) (resetTimeout bool, retry bool) +// RetryIfErrUpstreamFunc works just like a RetryIfErrFunc and also provides +// information about which upstream caused the error, if known. +// +// Upstream information is a : format. +type RetryIfErrUpstreamFunc func(request *Request, attempts int, err error, upstream string) (resetTimeout bool, retry bool) + // RoundTripper wraps every request/response. type RoundTripper interface { RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error) @@ -754,6 +765,10 @@ type HostClient struct { // This field is only effective within the range of MaxIdemponentCallAttempts. RetryIfErr RetryIfErrFunc + // RetryIfErrUpstream works just like RetryIfErr but also provides information about which upstream causes the error, if known. + // Upstream information is a : format. + RetryIfErrUpstream RetryIfErrUpstreamFunc + connsWait *wantConnQueue tlsConfigMap map[string]*tls.Config @@ -1393,9 +1408,16 @@ func (c *HostClient) Do(req *Request, resp *Response) error { if attempts >= maxAttempts { break } - if c.RetryIfErr != nil { + switch { + case c.RetryIfErrUpstream != nil: + upstream := "" + if resp.RemoteAddr() != nil { + upstream = resp.RemoteAddr().String() + } + resetTimeout, retry = c.RetryIfErrUpstream(req, attempts, err, upstream) + case c.RetryIfErr != nil: resetTimeout, retry = c.RetryIfErr(req, attempts, err) - } else { + default: retry = retryFunc(req) } if !retry { diff --git a/client_test.go b/client_test.go index af45b99f62..6e93a19e47 100644 --- a/client_test.go +++ b/client_test.go @@ -3615,3 +3615,79 @@ func (r *testResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPA r.lookupCountByHost[host]++ return r.resolver.LookupIPAddr(ctx, host) } + +type TransportMock struct { + wrapperFunc func(hc *HostClient, req *Request, resp *Response) (retry bool, err error) +} + +func (t *TransportMock) RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error) { + return t.wrapperFunc(hc, req, resp) +} + +func TestClient_RetryIfErrUpstream(t *testing.T) { + t.Parallel() + upstreamErr := errors.New("upstream error") + + t.Run("upstream_known", func(t *testing.T) { + retryIfErrCalled := false + c := &Client{ + Transport: &TransportMock{ + wrapperFunc: func(hc *HostClient, req *Request, resp *Response) (retry bool, err error) { + resp.raddr = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080} + return true, upstreamErr + }, + }, + RetryIfErrUpstream: func(request *Request, attempts int, err error, upstream string) (resetTimeout bool, retry bool) { + retryIfErrCalled = true + if upstream != "127.0.0.1:8080" { + t.Errorf("expected upstream to be 127.0.0.1:8080, got %s", upstream) + } + + return false, false + }, + } + req := AcquireRequest() + res := AcquireResponse() + + req.SetRequestURI("http://example.com") + + err := c.Do(req, res) + if !errors.Is(err, upstreamErr) { + t.Fatal(err) + } + if !retryIfErrCalled { + t.Fatal("RetryIfErrUpstream should be called") + } + }) + + t.Run("no_upstream", func(t *testing.T) { + retryIfErrCalled := false + c := &Client{ + Transport: &TransportMock{ + wrapperFunc: func(hc *HostClient, req *Request, resp *Response) (retry bool, err error) { + return true, upstreamErr + }, + }, + RetryIfErrUpstream: func(request *Request, attempts int, err error, upstream string) (resetTimeout bool, retry bool) { + retryIfErrCalled = true + if upstream != "" { + t.Errorf("expected upstream to be empty, got %s", upstream) + } + + return false, false + }, + } + req := AcquireRequest() + res := AcquireResponse() + + req.SetRequestURI("http://example.com") + + err := c.Do(req, res) + if !errors.Is(err, upstreamErr) { + t.Fatal(err) + } + if !retryIfErrCalled { + t.Fatal("RetryIfErrUpstream should be called") + } + }) +}