Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,9 @@ type HostClient struct {
// and whether to reset the request timeout—should be determined
// based on the return value of this field.
// This field is only effective within the range of MaxIdemponentCallAttempts.
//
// Check errors matches with errors.Is/errors.As, since errors are wrapped with upstream information.
// To get upstream information from the error, check ErrWithUpstream.
RetryIfErr RetryIfErrFunc

connsWait *wantConnQueue
Expand Down Expand Up @@ -1295,7 +1298,7 @@ func (c *HostClient) DoTimeout(req *Request, resp *Response, timeout time.Durati
func (c *HostClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
req.timeout = time.Until(deadline)
if req.timeout <= 0 {
return ErrTimeout
return wrapErrWithUpstream(ErrTimeout, c.Addr)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious why you only wrap the error here? When someone uses this method they can just read c.Addr themselves? I thought this was about Client.RetryIfErr where you can't know the addr?

}
return c.Do(req, resp)
}
Expand Down Expand Up @@ -1374,12 +1377,15 @@ func (c *HostClient) Do(req *Request, resp *Response) error {
if timeout > 0 {
req.timeout = time.Until(deadline)
if req.timeout <= 0 {
err = ErrTimeout
err = wrapErrWithUpstream(ErrTimeout, c.Addr)
break
}
}

retry, err = c.do(req, resp)

err = wrapErrWithUpstream(err, c.Addr)

if err == nil || !retry {
break
}
Expand All @@ -1393,6 +1399,7 @@ func (c *HostClient) Do(req *Request, resp *Response) error {
if attempts >= maxAttempts {
break
}

if c.RetryIfErr != nil {
resetTimeout, retry = c.RetryIfErr(req, attempts, err)
} else {
Expand Down Expand Up @@ -1530,6 +1537,42 @@ func (e *timeoutError) Timeout() bool {
return true
}

// ErrWithUpstream wraps errors with upstream information where upstream info exists.
// Root error can be obtained via errors.Unwrap. Use errors.Is to check if root error matches.
//
// Should use errors.As to get upstream information from error:
//
// hc := fasthttp.HostClient{Addr: "foo.com,bar.com"}
// err := hc.Do(req, res)
//
// var upstreamErr *fasthttp.ErrWithUpstream
// if errors.As(err, &upstreamErr) {
// upstream = upstreamErr.Upstream // 34.206.39.153:80
Comment thread
erikdubbelboer marked this conversation as resolved.
// }
type ErrWithUpstream struct {
wrapErr error
Upstream string
Comment thread
mdenushev marked this conversation as resolved.
}

func (e *ErrWithUpstream) Error() string {
return fmt.Sprintf("error on upstream %s: %s", e.Upstream, e.wrapErr.Error())
}

func (e *ErrWithUpstream) Unwrap() error {
return e.wrapErr
}

func wrapErrWithUpstream(err error, upstream string) error {
if err == nil {
return nil
}

return &ErrWithUpstream{
wrapErr: err,
Upstream: upstream,
}
}

// ErrTimeout is returned from timed out calls.
var ErrTimeout = &timeoutError{}

Expand Down
30 changes: 15 additions & 15 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func TestHostClientNegativeTimeout(t *testing.T) {
if err := c.DoTimeout(req, nil, -time.Second); err != ErrTimeout {
t.Fatalf("expected ErrTimeout error got: %+v", err)
}
if err := c.DoDeadline(req, nil, time.Now().Add(-time.Second)); err != ErrTimeout {
if err := c.DoDeadline(req, nil, time.Now().Add(-time.Second)); !errors.Is(err, ErrTimeout) {
t.Fatalf("expected ErrTimeout error got: %+v", err)
}
ln.Close()
Expand Down Expand Up @@ -184,7 +184,7 @@ func TestDoDeadlineRetry(t *testing.T) {
req := AcquireRequest()
req.Header.SetMethod(MethodGet)
req.SetRequestURI("http://example.com")
if err := c.DoDeadline(req, nil, time.Now().Add(time.Millisecond*200)); err != ErrTimeout {
if err := c.DoDeadline(req, nil, time.Now().Add(time.Millisecond*200)); !errors.Is(err, ErrTimeout) {
t.Fatalf("expected ErrTimeout error got: %+v", err)
}
ln.Close()
Expand Down Expand Up @@ -615,7 +615,7 @@ func TestClientRedirectHostClientChangingSchemaHttp2Https(t *testing.T) {
}

_, _, err = reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
if err != ErrHostClientRedirectToDifferentScheme {
if !errors.Is(err, ErrHostClientRedirectToDifferentScheme) {
t.Fatal("expected HostClient error")
}
}
Expand Down Expand Up @@ -781,7 +781,7 @@ func TestClientReadTimeout(t *testing.T) {
req.SetRequestURI("http://localhost")
req.SetConnectionClose()

if err := c.Do(req, res); err != ErrTimeout {
if err := c.Do(req, res); !errors.Is(err, ErrTimeout) {
t.Errorf("expected ErrTimeout got %#v", err)
}

Expand Down Expand Up @@ -1430,7 +1430,7 @@ func TestHostClientMaxConnsWithDeadline(t *testing.T) {

for {
if err := c.DoDeadline(req, resp, time.Now().Add(timeout)); err != nil {
if err == ErrNoFreeConns {
if errors.Is(err, ErrNoFreeConns) {
time.Sleep(time.Millisecond)
continue
}
Expand Down Expand Up @@ -1715,7 +1715,7 @@ func TestClientFollowRedirects(t *testing.T) {
if err == nil {
t.Errorf("expecting error")
}
if err != ErrTimeout {
if !errors.Is(err, ErrTimeout) {
t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
}

Expand Down Expand Up @@ -1908,7 +1908,7 @@ func testClientDoTimeoutError(t *testing.T, c *Client, n int) {
if err == nil {
t.Errorf("expecting error")
}
if err != ErrTimeout {
if !errors.Is(err, ErrTimeout) {
t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
}
}
Expand All @@ -1921,7 +1921,7 @@ func testClientGetTimeoutError(t *testing.T, c *Client, n int) {
if err == nil {
t.Errorf("expecting error")
}
if err != ErrTimeout {
if !errors.Is(err, ErrTimeout) {
t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
}
if statusCode != 0 {
Expand All @@ -1940,7 +1940,7 @@ func testClientRequestSetTimeoutError(t *testing.T, c *Client, n int) {
if err == nil {
t.Errorf("expecting error")
}
if err != ErrTimeout {
if !errors.Is(err, ErrTimeout) {
t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
}
}
Expand Down Expand Up @@ -2794,7 +2794,7 @@ func TestClientTLSHandshakeTimeout(t *testing.T) {
t.Fatal("tlsClientHandshake completed successfully")
}

if err != ErrTLSHandshakeTimeout {
if !errors.Is(err, ErrTLSHandshakeTimeout) {
t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
}
}
Expand Down Expand Up @@ -2953,7 +2953,7 @@ func TestHostClientMaxConnWaitTimeoutError(t *testing.T) {
resp := AcquireResponse()

if err := c.Do(req, resp); err != nil {
if err != ErrNoFreeConns {
if !errors.Is(err, ErrNoFreeConns) {
t.Errorf("unexpected error: %v. Expecting %v", err, ErrNoFreeConns)
}
errNoFreeConnsCount.Add(1)
Expand Down Expand Up @@ -3050,7 +3050,7 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
resp := AcquireResponse()

if err := c.DoDeadline(req, resp, time.Now().Add(timeout)); err != nil {
if err != ErrTimeout {
if !errors.Is(err, ErrTimeout) {
t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
}
errTimeoutCount.Add(1)
Expand Down Expand Up @@ -3151,10 +3151,10 @@ func TestHostClientErrConnPoolStrategyNotImpl(t *testing.T) {
if err := client.Do(req, AcquireResponse()); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err := client.Do(req, &Response{}); err != ErrConnPoolStrategyNotImpl {
if err := client.Do(req, &Response{}); !errors.Is(err, ErrConnPoolStrategyNotImpl) {
t.Errorf("expected ErrConnPoolStrategyNotImpl error, got %v", err)
}
if err := client.Do(req, &Response{}); err != ErrConnPoolStrategyNotImpl {
if err := client.Do(req, &Response{}); !errors.Is(err, ErrConnPoolStrategyNotImpl) {
t.Errorf("expected ErrConnPoolStrategyNotImpl error, got %v", err)
}

Expand Down Expand Up @@ -3489,7 +3489,7 @@ func TestClientHeadWithBody(t *testing.T) {
err = c.Do(req, resp)
if err == nil {
t.Error("expected timeout error")
} else if err != ErrTimeout {
} else if !errors.Is(err, ErrTimeout) {
t.Error(err)
}
}
Expand Down