diff --git a/client.go b/client.go index 6d16696def..17cf27a4df 100644 --- a/client.go +++ b/client.go @@ -752,6 +752,9 @@ type HostClient struct { // and whether to reset the request timeout—should be determined // based on the return value of this field. // This field is only effective within the range of MaxIdemponentCallAttempts. + // + // Check errors matches with errors.Is/errors.As, since errors are wrapped with upstream information. + // To get upstream information from the error, check ErrWithUpstream. RetryIfErr RetryIfErrFunc connsWait *wantConnQueue @@ -1295,7 +1298,7 @@ func (c *HostClient) DoTimeout(req *Request, resp *Response, timeout time.Durati func (c *HostClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error { req.timeout = time.Until(deadline) if req.timeout <= 0 { - return ErrTimeout + return wrapErrWithUpstream(ErrTimeout, c.Addr) } return c.Do(req, resp) } @@ -1374,12 +1377,15 @@ func (c *HostClient) Do(req *Request, resp *Response) error { if timeout > 0 { req.timeout = time.Until(deadline) if req.timeout <= 0 { - err = ErrTimeout + err = wrapErrWithUpstream(ErrTimeout, c.Addr) break } } retry, err = c.do(req, resp) + + err = wrapErrWithUpstream(err, c.Addr) + if err == nil || !retry { break } @@ -1393,6 +1399,7 @@ func (c *HostClient) Do(req *Request, resp *Response) error { if attempts >= maxAttempts { break } + if c.RetryIfErr != nil { resetTimeout, retry = c.RetryIfErr(req, attempts, err) } else { @@ -1530,6 +1537,42 @@ func (e *timeoutError) Timeout() bool { return true } +// ErrWithUpstream wraps errors with upstream information where upstream info exists. +// Root error can be obtained via errors.Unwrap. Use errors.Is to check if root error matches. +// +// Should use errors.As to get upstream information from error: +// +// hc := fasthttp.HostClient{Addr: "foo.com,bar.com"} +// err := hc.Do(req, res) +// +// var upstreamErr *fasthttp.ErrWithUpstream +// if errors.As(err, &upstreamErr) { +// upstream = upstreamErr.Upstream // 34.206.39.153:80 +// } +type ErrWithUpstream struct { + wrapErr error + Upstream string +} + +func (e *ErrWithUpstream) Error() string { + return fmt.Sprintf("error on upstream %s: %s", e.Upstream, e.wrapErr.Error()) +} + +func (e *ErrWithUpstream) Unwrap() error { + return e.wrapErr +} + +func wrapErrWithUpstream(err error, upstream string) error { + if err == nil { + return nil + } + + return &ErrWithUpstream{ + wrapErr: err, + Upstream: upstream, + } +} + // ErrTimeout is returned from timed out calls. var ErrTimeout = &timeoutError{} diff --git a/client_test.go b/client_test.go index af45b99f62..02dc4167ba 100644 --- a/client_test.go +++ b/client_test.go @@ -144,7 +144,7 @@ func TestHostClientNegativeTimeout(t *testing.T) { if err := c.DoTimeout(req, nil, -time.Second); err != ErrTimeout { t.Fatalf("expected ErrTimeout error got: %+v", err) } - if err := c.DoDeadline(req, nil, time.Now().Add(-time.Second)); err != ErrTimeout { + if err := c.DoDeadline(req, nil, time.Now().Add(-time.Second)); !errors.Is(err, ErrTimeout) { t.Fatalf("expected ErrTimeout error got: %+v", err) } ln.Close() @@ -184,7 +184,7 @@ func TestDoDeadlineRetry(t *testing.T) { req := AcquireRequest() req.Header.SetMethod(MethodGet) req.SetRequestURI("http://example.com") - if err := c.DoDeadline(req, nil, time.Now().Add(time.Millisecond*200)); err != ErrTimeout { + if err := c.DoDeadline(req, nil, time.Now().Add(time.Millisecond*200)); !errors.Is(err, ErrTimeout) { t.Fatalf("expected ErrTimeout error got: %+v", err) } ln.Close() @@ -615,7 +615,7 @@ func TestClientRedirectHostClientChangingSchemaHttp2Https(t *testing.T) { } _, _, err = reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond) - if err != ErrHostClientRedirectToDifferentScheme { + if !errors.Is(err, ErrHostClientRedirectToDifferentScheme) { t.Fatal("expected HostClient error") } } @@ -781,7 +781,7 @@ func TestClientReadTimeout(t *testing.T) { req.SetRequestURI("http://localhost") req.SetConnectionClose() - if err := c.Do(req, res); err != ErrTimeout { + if err := c.Do(req, res); !errors.Is(err, ErrTimeout) { t.Errorf("expected ErrTimeout got %#v", err) } @@ -1429,7 +1429,7 @@ func TestHostClientMaxConnsWithDeadline(t *testing.T) { for { if err := c.DoDeadline(req, resp, time.Now().Add(timeout)); err != nil { - if err == ErrNoFreeConns { + if errors.Is(err, ErrNoFreeConns) { time.Sleep(time.Millisecond) continue } @@ -1714,7 +1714,7 @@ func TestClientFollowRedirects(t *testing.T) { if err == nil { t.Errorf("expecting error") } - if err != ErrTimeout { + if !errors.Is(err, ErrTimeout) { t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout) } @@ -1907,7 +1907,7 @@ func testClientDoTimeoutError(t *testing.T, c *Client, n int) { if err == nil { t.Errorf("expecting error") } - if err != ErrTimeout { + if !errors.Is(err, ErrTimeout) { t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout) } } @@ -1920,7 +1920,7 @@ func testClientGetTimeoutError(t *testing.T, c *Client, n int) { if err == nil { t.Errorf("expecting error") } - if err != ErrTimeout { + if !errors.Is(err, ErrTimeout) { t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout) } if statusCode != 0 { @@ -1939,7 +1939,7 @@ func testClientRequestSetTimeoutError(t *testing.T, c *Client, n int) { if err == nil { t.Errorf("expecting error") } - if err != ErrTimeout { + if !errors.Is(err, ErrTimeout) { t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout) } } @@ -2793,7 +2793,7 @@ func TestClientTLSHandshakeTimeout(t *testing.T) { t.Fatal("tlsClientHandshake completed successfully") } - if err != ErrTLSHandshakeTimeout { + if !errors.Is(err, ErrTLSHandshakeTimeout) { t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err) } } @@ -2950,7 +2950,7 @@ func TestHostClientMaxConnWaitTimeoutError(t *testing.T) { resp := AcquireResponse() if err := c.Do(req, resp); err != nil { - if err != ErrNoFreeConns { + if !errors.Is(err, ErrNoFreeConns) { t.Errorf("unexpected error: %v. Expecting %v", err, ErrNoFreeConns) } errNoFreeConnsCount.Add(1) @@ -3046,7 +3046,7 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) { resp := AcquireResponse() if err := c.DoDeadline(req, resp, time.Now().Add(timeout)); err != nil { - if err != ErrTimeout { + if !errors.Is(err, ErrTimeout) { t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout) } errTimeoutCount.Add(1) @@ -3147,10 +3147,10 @@ func TestHostClientErrConnPoolStrategyNotImpl(t *testing.T) { if err := client.Do(req, AcquireResponse()); err != nil { t.Fatalf("unexpected error: %v", err) } - if err := client.Do(req, &Response{}); err != ErrConnPoolStrategyNotImpl { + if err := client.Do(req, &Response{}); !errors.Is(err, ErrConnPoolStrategyNotImpl) { t.Errorf("expected ErrConnPoolStrategyNotImpl error, got %v", err) } - if err := client.Do(req, &Response{}); err != ErrConnPoolStrategyNotImpl { + if err := client.Do(req, &Response{}); !errors.Is(err, ErrConnPoolStrategyNotImpl) { t.Errorf("expected ErrConnPoolStrategyNotImpl error, got %v", err) } @@ -3485,7 +3485,7 @@ func TestClientHeadWithBody(t *testing.T) { err = c.Do(req, resp) if err == nil { t.Error("expected timeout error") - } else if err != ErrTimeout { + } else if !errors.Is(err, ErrTimeout) { t.Error(err) } }