From 2abc61a55b017c4753e0c60d2918a702ef99b287 Mon Sep 17 00:00:00 2001 From: Jitendra Gangwar Date: Thu, 2 Apr 2026 04:03:20 +0530 Subject: [PATCH] improve Azure SDK VCR replay handling and polling behavior --- sdk/client/client.go | 104 +----- sdk/client/client_test.go | 347 ++++++------------ sdk/client/dataplane/poller_delete.go | 9 +- sdk/client/dataplane/poller_helpers_test.go | 9 + sdk/client/dataplane/poller_lro.go | 7 +- sdk/client/dataplane/poller_lro_test.go | 69 ++++ .../dataplane/poller_provisioning_state.go | 9 +- .../poller_provisioning_state_test.go | 70 ++++ sdk/client/interface.go | 4 +- sdk/client/pollers/interface.go | 1 + sdk/client/pollers/poller.go | 15 +- sdk/client/pollers/poller_test.go | 65 +++- sdk/client/resourcemanager/poller_delete.go | 5 +- .../resourcemanager/poller_helpers_test.go | 9 + sdk/client/resourcemanager/poller_lro.go | 7 +- sdk/client/resourcemanager/poller_lro_test.go | 70 ++++ .../poller_provisioning_state.go | 7 +- .../poller_provisioning_state_test.go | 72 ++++ sdk/client/transport.go | 50 +++ sdk/client/vcr_helper.go | 45 +++ sdk/client/vcr_helper_test.go | 201 ++++++++++ 21 files changed, 798 insertions(+), 377 deletions(-) create mode 100644 sdk/client/transport.go create mode 100644 sdk/client/vcr_helper.go create mode 100644 sdk/client/vcr_helper_test.go diff --git a/sdk/client/client.go b/sdk/client/client.go index c112237ce54..d403ad984c7 100644 --- a/sdk/client/client.go +++ b/sdk/client/client.go @@ -14,12 +14,9 @@ import ( "io" "log" "math" - "net" "net/http" "net/url" - "reflect" "regexp" - "runtime" "strconv" "strings" "time" @@ -325,6 +322,10 @@ type Client struct { // Transport allows overriding the http.RoundTripper used by the client. // When nil, a default transport will be used. Transport http.RoundTripper + + // TransportMode indicates the mode of the transport, this is being used to set vcr transport mode, + // but can be used by custom transports to modify their behavior based on the mode. + TransportMode TransportMode } // NewClient returns a new Client configured with sensible defaults @@ -339,9 +340,10 @@ func NewClient(baseUri string, serviceName, apiVersion string) *Client { } } -// SetTransport configures the transport to be used by the client -func (c *Client) SetTransport(transport http.RoundTripper) { +// SetTransport configures the transport and transport mode to be used by the client. +func (c *Client) SetTransport(transport http.RoundTripper, mode TransportMode) { c.Transport = transport + c.TransportMode = mode } // SetAuthorizer configures the request authorizer for the client @@ -349,70 +351,6 @@ func (c *Client) SetAuthorizer(authorizer auth.Authorizer) { c.Authorizer = authorizer } -// IsVcrReplaying returns true if the provided transport appears to be a VCR recorder in replay mode. -func IsVcrReplaying(transport http.RoundTripper) bool { - if transport == nil { - return false - } - - // We use reflection to avoid a hard dependency on the go-vcr package in the SDK. - // We check for "recorder" in the type name and then attempt to call the Mode() method. - t := reflect.TypeOf(transport) - if !strings.Contains(strings.ToLower(t.String()), "recorder") { - return false - } - - v := reflect.ValueOf(transport) - // If it's a pointer, get the underlying value (MethodByName works on both) - modeMethod := v.MethodByName("Mode") - if !modeMethod.IsValid() { - return false - } - - results := modeMethod.Call(nil) - if len(results) != 1 { - return false - } - - // Mode is normally an int (or an alias of int). - // In go-vcr v1/v2, ModeReplay is 1. - // In go-vcr v3/v4, ModeReplayOnly is 1 and ModeReplayWithNewEpisodes is 2. - // Support for RecordOnce (3) mode: skip only if cassette is NOT new. - if results[0].Kind() == reflect.Int { - mode := results[0].Int() - if mode == 1 || mode == 2 { - return true - } - - if mode == 3 { - // use reflection to get the unexported cassette field - v := reflect.ValueOf(transport) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - if v.Kind() == reflect.Struct { - cassetteField := v.FieldByName("cassette") - if cassetteField.IsValid() { - cassette := cassetteField - if cassette.Kind() == reflect.Ptr && !cassette.IsNil() { - cassette = cassette.Elem() - } - - if cassette.Kind() == reflect.Struct { - isNewField := cassette.FieldByName("IsNew") - if isNewField.IsValid() && isNewField.Kind() == reflect.Bool { - // If the cassette is NOT new, we are replaying - return !isNewField.Bool() - } - } - } - } - } - } - - return false -} - // SetUserAgent configures the user agent to be included in requests func (c *Client) SetUserAgent(userAgent string) { c.UserAgent = userAgent @@ -546,6 +484,9 @@ func (c *Client) Execute(ctx context.Context, req *Request) (*Response, error) { // Check for failed connections etc and decide if retries are appropriate if r == nil { + if IsVCRReplayMissError(err) { + return false, err + } if req.IsIdempotent() { if !isResourceManagerHost(req) { return extendedRetryPolicy(r, err) @@ -595,12 +536,10 @@ func (c *Client) Execute(ctx context.Context, req *Request) (*Response, error) { resp.Response = r } } - - // If we're running with a VCR transport in REPLAY mode, we mark the response so that the poller knows to skip the delay - if IsVcrReplaying(c.Transport) { - resp.Header.Add("X-Go-Azure-SDK-Skip-Polling-Delay", "true") + // If it's a recorded response, we mark the response so that the poller knows to skip the delay + if c.TransportMode != TransportModeDefault && (IsVCRRecordedResponse(resp.Response) || IsVCRReplaying(c)) { + resp.Header.Set(SkipPollingDelayHeader, "true") } - // Extract OData from response, intentionally ignoring any errors as it's not crucial to extract // valid OData at this point (valid json can still error here, such as any non-object literal) resp.OData, _ = odata.FromResponse(resp.Response) @@ -808,22 +747,7 @@ func (c *Client) retryableClient(ctx context.Context, checkRetry retryablehttp.C if c.Transport != nil { transport = c.Transport } else { - transport = &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - d := &net.Dialer{Resolver: &net.Resolver{}} - return d.DialContext(ctx, network, addr) - }, - TLSClientConfig: &tls.Config{ - MinVersion: tls.VersionTLS12, - }, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - ForceAttemptHTTP2: true, - MaxIdleConnsPerHost: runtime.GOMAXPROCS(0) + 1, - } + transport = GetDefaultHttpTransport() } r.HTTPClient = &http.Client{ diff --git a/sdk/client/client_test.go b/sdk/client/client_test.go index af91981bc5b..1673b2d0397 100644 --- a/sdk/client/client_test.go +++ b/sdk/client/client_test.go @@ -15,6 +15,7 @@ import ( "net/url" "reflect" "testing" + "time" "github.com/hashicorp/go-azure-helpers/lang/pointer" "github.com/hashicorp/go-azure-sdk/sdk/internal/test" @@ -513,260 +514,142 @@ func unmarshalResponse(body io.ReadCloser, unmarshal func(in []byte) error) erro return unmarshal(respBody) } -type roundTripperMock struct { - roundTripFunc func(*http.Request) (*http.Response, error) -} - -func (rt *roundTripperMock) RoundTrip(req *http.Request) (*http.Response, error) { - if rt.roundTripFunc != nil { - return rt.roundTripFunc(req) - } - return nil, fmt.Errorf("roundTripFunc missing from mock") -} - -func TestClient_CustomTransport(t *testing.T) { - ctx := context.TODO() - - c := NewClient("http://localhost", "testService", "v1.0") - c.DisableRetries = true - - hitCount := 0 - mockTransport := &roundTripperMock{ - roundTripFunc: func(req *http.Request) (*http.Response, error) { - hitCount++ - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte(`{"success": true}`))), - Header: map[string][]string{ - "Content-Type": {"application/json"}, - }, - Request: req, - }, nil - }, - } - c.Transport = mockTransport - - reqOpts := RequestOptions{ - ContentType: "application/json", - ExpectedStatusCodes: []int{ - http.StatusOK, - }, - HttpMethod: http.MethodGet, - Path: "/test", - } - - req, err := c.NewRequest(ctx, reqOpts) - if err != nil { - t.Fatalf("NewRequest error: %v", err) - } - - resp, err := req.Execute(ctx) - if err != nil { - t.Fatalf("Execute error: %v", err) - } - - if resp.StatusCode != http.StatusOK { - t.Errorf("expected 200 OK, got %d", resp.StatusCode) - } +func TestClient_VCRCustomTransport(t *testing.T) { - if hitCount != 1 { - t.Errorf("expected transport to be hit 1 time, got %d", hitCount) - } -} - -func TestClient_VCRHeaderInjected(t *testing.T) { - ctx := context.TODO() - - c := NewClient("http://localhost", "testService", "v1.0") - c.DisableRetries = true - - mockTransport := &recorderRecorder{ - mode: 1, // Replay - roundTripFunc: func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte(`{"success": true}`))), - Header: make(http.Header), - Request: req, - }, nil - }, - } - c.Transport = mockTransport - - reqOpts := RequestOptions{ - ContentType: "application/json", - ExpectedStatusCodes: []int{ - http.StatusOK, - }, - HttpMethod: http.MethodGet, - Path: "/test", - } - - req, err := c.NewRequest(ctx, reqOpts) - if err != nil { - t.Fatalf("NewRequest error: %v", err) - } - - resp, err := req.Execute(ctx) - if err != nil { - t.Fatalf("Execute error: %v", err) - } - - if resp.Header.Get("X-Go-Azure-SDK-Skip-Polling-Delay") != "true" { - t.Errorf("expected skip polling delay header to be injected in replay mode, but wasn't") - } -} - -func TestClient_VCRHeaderInjectedInRecordOnceModeReplaying(t *testing.T) { - ctx := context.TODO() - - c := NewClient("http://localhost", "testService", "v1.0") - c.DisableRetries = true - - mockTransport := &recorderRecorder{ - mode: 3, // RecordOnce - cassette: &testCassette{ - IsNew: false, // REPLAYING - }, - roundTripFunc: func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte(`{"success": true}`))), - Header: make(http.Header), - Request: req, - }, nil + testCases := []struct { + name string + transportMode TransportMode + responseHeaders http.Header + expectedSkipDelayHeader string + expectRecordedResponse bool + expectReplayMode bool + error error + }{ + { + name: "custom transport with default mode", + transportMode: TransportModeDefault, + responseHeaders: http.Header{"Content-Type": []string{"application/json"}}, + expectedSkipDelayHeader: "", + expectRecordedResponse: false, + expectReplayMode: false, }, - } - c.Transport = mockTransport - - reqOpts := RequestOptions{ - ContentType: "application/json", - ExpectedStatusCodes: []int{ - http.StatusOK, + { + name: "custom transport with default mode and vcr replay header true", + transportMode: TransportModeDefault, + responseHeaders: http.Header{"Content-Type": []string{"application/json"}, http.CanonicalHeaderKey(VCRReplayHeader): []string{"true"}}, + expectedSkipDelayHeader: "", + expectRecordedResponse: true, + expectReplayMode: false, }, - HttpMethod: http.MethodGet, - Path: "/test", - } - - req, err := c.NewRequest(ctx, reqOpts) - if err != nil { - t.Fatalf("NewRequest error: %v", err) - } - - resp, err := req.Execute(ctx) - if err != nil { - t.Fatalf("Execute error: %v", err) - } - - if resp.Header.Get("X-Go-Azure-SDK-Skip-Polling-Delay") != "true" { - t.Errorf("expected skip polling delay header to be injected in RecordOnce replaying mode, but wasn't") - } -} - -func TestClient_VCRHeaderNotInjectedInRecordOnceModeRecording(t *testing.T) { - ctx := context.TODO() - - c := NewClient("http://localhost", "testService", "v1.0") - c.DisableRetries = true - - mockTransport := &recorderRecorder{ - mode: 3, // RecordOnce - cassette: &testCassette{ - IsNew: true, // RECORDING + { + name: "vcr replay mode should set skip delay header true", + transportMode: TransportModeVCRReplay, + responseHeaders: http.Header{"Content-Type": []string{"application/json"}}, + expectedSkipDelayHeader: "true", + expectRecordedResponse: false, + expectReplayMode: true, }, - roundTripFunc: func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte(`{"success": true}`))), - Header: make(http.Header), - Request: req, - }, nil + { + name: "vcr recorded response should set skip delay header true", + transportMode: TransportModeVCRReplayWithNewEpisodes, + responseHeaders: http.Header{"Content-Type": []string{"application/json"}, http.CanonicalHeaderKey(VCRReplayHeader): []string{"true"}}, + expectedSkipDelayHeader: "true", + expectRecordedResponse: true, + expectReplayMode: false, }, - } - c.Transport = mockTransport - - reqOpts := RequestOptions{ - ContentType: "application/json", - ExpectedStatusCodes: []int{ - http.StatusOK, + { + name: "vcr replay miss error,should skip retries", + transportMode: TransportModeVCRReplay, + expectedSkipDelayHeader: "", + expectRecordedResponse: false, + expectReplayMode: true, + error: fmt.Errorf(VCRInteractionNotFoundErrMsg), }, - HttpMethod: http.MethodGet, - Path: "/test", } - req, err := c.NewRequest(ctx, reqOpts) - if err != nil { - t.Fatalf("NewRequest error: %v", err) - } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + c := NewClient("https://management.azure.com.example", "testService", "v1.0") + + hitCount := 0 + mockTransport := &roundTripperMock{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + hitCount++ + if testCase.error != nil { + return nil, testCase.error + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte(`{"success": true}`))), + Header: testCase.responseHeaders.Clone(), + Request: req, + }, nil + }, + } + c.SetTransport(mockTransport, testCase.transportMode) - resp, err := req.Execute(ctx) - if err != nil { - t.Fatalf("Execute error: %v", err) - } + reqOpts := RequestOptions{ + ContentType: "application/json", + ExpectedStatusCodes: []int{ + http.StatusOK, + }, + HttpMethod: http.MethodGet, + Path: "/test", + } - if resp.Header.Get("X-Go-Azure-SDK-Skip-Polling-Delay") == "true" { - t.Errorf("expected skip polling delay header NOT to be injected in RecordOnce recording mode, but it was") - } -} + req, err := c.NewRequest(ctx, reqOpts) + if err != nil { + t.Fatalf("NewRequest error: %v", err) + } -func TestClient_VCRHeaderNotInjectedInRecordMode(t *testing.T) { - ctx := context.TODO() + resp, err := req.Execute(ctx) - c := NewClient("http://localhost", "testService", "v1.0") - c.DisableRetries = true - - mockTransport := &recorderRecorder{ - mode: 0, // Record - roundTripFunc: func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte(`{"success": true}`))), - Header: make(http.Header), - Request: req, - }, nil - }, - } - c.Transport = mockTransport + if hitCount != 1 { + t.Fatalf("expected transport to be hit 1 time, got %d", hitCount) + } + if c.TransportMode != testCase.transportMode { + t.Fatalf("expected transport mode to be %q, got %q", testCase.transportMode, c.TransportMode) + } + if testCase.expectReplayMode != IsVCRReplaying(c) { + t.Fatalf("expected IsVCRReplaying=%t, got %t", testCase.expectReplayMode, IsVCRReplaying(c)) + } + if testCase.error != nil { + if err == nil { + t.Fatal("expected Execute to fail with a replay miss") + } + if !IsVCRReplayMissError(err) { + t.Fatalf("expected a replay miss error, got: %v", err) + } + return + } + if err != nil { + t.Fatalf("Execute error: %v", err) + } - reqOpts := RequestOptions{ - ContentType: "application/json", - ExpectedStatusCodes: []int{ - http.StatusOK, - }, - HttpMethod: http.MethodGet, - Path: "/test", - } + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200 OK, got %d", resp.StatusCode) + } - req, err := c.NewRequest(ctx, reqOpts) - if err != nil { - t.Fatalf("NewRequest error: %v", err) - } + if testCase.expectedSkipDelayHeader != resp.Header.Get(SkipPollingDelayHeader) { + t.Fatalf("expected %s header to be %q, got %q", SkipPollingDelayHeader, testCase.expectedSkipDelayHeader, resp.Header.Get(SkipPollingDelayHeader)) + } - resp, err := req.Execute(ctx) - if err != nil { - t.Fatalf("Execute error: %v", err) - } + if testCase.expectRecordedResponse != IsVCRRecordedResponse(resp.Response) { + t.Fatalf("expected IsVCRRecordedResponse=%t, got %t", testCase.expectRecordedResponse, IsVCRRecordedResponse(resp.Response)) + } - if resp.Header.Get("X-Go-Azure-SDK-Skip-Polling-Delay") == "true" { - t.Errorf("expected skip polling delay header NOT to be injected in record mode, but it was") + }) } } -type recorderRecorder struct { +type roundTripperMock struct { roundTripFunc func(*http.Request) (*http.Response, error) - mode int - cassette *testCassette -} - -type testCassette struct { - IsNew bool -} - -func (rt *recorderRecorder) Mode() int { - return rt.mode } -func (rt *recorderRecorder) RoundTrip(req *http.Request) (*http.Response, error) { +func (rt *roundTripperMock) RoundTrip(req *http.Request) (*http.Response, error) { if rt.roundTripFunc != nil { return rt.roundTripFunc(req) } diff --git a/sdk/client/dataplane/poller_delete.go b/sdk/client/dataplane/poller_delete.go index a93c7399373..9d685fbb4fe 100644 --- a/sdk/client/dataplane/poller_delete.go +++ b/sdk/client/dataplane/poller_delete.go @@ -71,11 +71,11 @@ func (p deletePoller) Poll(ctx context.Context) (result *pollers.PollResult, err } req, err := p.client.NewRequest(ctx, opts) if err != nil { - return nil, fmt.Errorf("building request: %+v", err) + return nil, fmt.Errorf("building request: %w", err) } resp, err := p.client.Execute(ctx, req) if err != nil { - return nil, fmt.Errorf("executing request: %+v", err) + return nil, fmt.Errorf("executing request: %w", err) } if resp == nil { return nil, pollers.PollingDroppedConnectionError{} @@ -107,10 +107,7 @@ func (p deletePoller) Poll(ctx context.Context) (result *pollers.PollResult, err } func (p deletePoller) SkipDelay() bool { - if p.client != nil { - return client.IsVcrReplaying(p.client.Transport) - } - return false + return p.client != nil && client.IsVCRReplaying(p.client.Client) } var _ client.Options = deleteOptions{} diff --git a/sdk/client/dataplane/poller_helpers_test.go b/sdk/client/dataplane/poller_helpers_test.go index 8fd22f77f1a..05c3ad47dc3 100644 --- a/sdk/client/dataplane/poller_helpers_test.go +++ b/sdk/client/dataplane/poller_helpers_test.go @@ -4,6 +4,7 @@ package dataplane import ( + "fmt" "net/http" "testing" @@ -61,3 +62,11 @@ func dropConnection(t *testing.T, w http.ResponseWriter) { } conn.Close() } + +type errRoundTripper struct { + errMsg string +} + +func (r errRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("%s", r.errMsg) +} diff --git a/sdk/client/dataplane/poller_lro.go b/sdk/client/dataplane/poller_lro.go index 7cf8c4a1aae..8144b9c04a3 100644 --- a/sdk/client/dataplane/poller_lro.go +++ b/sdk/client/dataplane/poller_lro.go @@ -113,7 +113,7 @@ func (p *longRunningOperationPoller) Poll(ctx context.Context) (result *pollers. result.HttpResponse, err = req.Execute(ctx) if err != nil { var e *url.Error - if errors.As(err, &e) { + if errors.As(err, &e) && !client.IsVCRReplayMissError(e) { p.droppedConnectionCount++ if p.droppedConnectionCount < p.maxDroppedConnections { result.Status = pollers.PollingStatusUnknown @@ -215,10 +215,7 @@ func (p *longRunningOperationPoller) Poll(ctx context.Context) (result *pollers. } func (p *longRunningOperationPoller) SkipDelay() bool { - if p.client != nil { - return client.IsVcrReplaying(p.client.Transport) - } - return false + return client.IsVCRReplaying(p.client) } type operationResult struct { diff --git a/sdk/client/dataplane/poller_lro_test.go b/sdk/client/dataplane/poller_lro_test.go index d25ff3eddae..a01e203f00b 100644 --- a/sdk/client/dataplane/poller_lro_test.go +++ b/sdk/client/dataplane/poller_lro_test.go @@ -7,7 +7,9 @@ import ( "context" "net/http" "net/http/httptest" + "strings" "testing" + "time" "github.com/hashicorp/go-azure-sdk/sdk/client" "github.com/hashicorp/go-azure-sdk/sdk/client/pollers" @@ -553,3 +555,70 @@ func TestPollerLRO_InStatus_404ThenInProgressThenSucceeded(t *testing.T) { // sanity-checking helpers.assertCalled(t, 4) } + +func TestPollerLRO_VCRErrorHandling(t *testing.T) { + testCases := []struct { + name string + errMsg string + expectReplayMiss bool + expectResultStatus *pollers.PollingStatus + expectError bool + expectDroppedConnections int + }{ + { + name: "replay miss url error", + errMsg: client.VCRInteractionNotFoundErrMsg, + expectReplayMiss: true, + expectError: true, + expectDroppedConnections: 0, + }, + { + name: "non retryable url error", + errMsg: "unsupported protocol scheme", + expectResultStatus: func() *pollers.PollingStatus { s := pollers.PollingStatusUnknown; return &s }(), + expectDroppedConnections: 1, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + helpers := newLongRunningOperationsEndpoint([]expectedResponse{}) + response := &client.Response{ + Response: helpers.response(), + } + + baseClient := client.NewClient("https://management.azure.com", "Example", "2020-01-01") + baseClient.SetTransport(errRoundTripper{errMsg: testCase.errMsg}, client.TransportModeVCRReplay) + + poller, err := longRunningOperationPollerFromResponse(response, baseClient) + if err != nil { + t.Fatal(err.Error()) + } + + result, err := poller.Poll(ctx) + if testCase.expectError { + if err == nil { + t.Fatal("expected polling to return an error, but got no error") + } + if result != nil { + t.Fatalf("expected no poll result, got: %+v ,error %T (err=%v)", result, err, err) + } + if testCase.expectReplayMiss != client.IsVCRReplayMissError(err) { + t.Fatalf("expected IsVCRReplayMissError=%t, got %t (err=%v)", testCase.expectReplayMiss, client.IsVCRReplayMissError(err), err) + } + if !strings.Contains(err.Error(), testCase.errMsg) { + t.Fatalf("expected error to contain %q, got %v", testCase.errMsg, err) + } + } + if testCase.expectResultStatus != nil && (result == nil || result.Status != *testCase.expectResultStatus) { + t.Fatalf("expected poll result with status %q, got: %+v", *testCase.expectResultStatus, result) + } + if poller.droppedConnectionCount != testCase.expectDroppedConnections { + t.Fatalf("expected droppedConnectionCount to be %d, got %d", testCase.expectDroppedConnections, poller.droppedConnectionCount) + } + }) + } +} diff --git a/sdk/client/dataplane/poller_provisioning_state.go b/sdk/client/dataplane/poller_provisioning_state.go index 096bc52c522..7ae0eeaa0ae 100644 --- a/sdk/client/dataplane/poller_provisioning_state.go +++ b/sdk/client/dataplane/poller_provisioning_state.go @@ -90,7 +90,7 @@ func (p *provisioningStatePoller) Poll(ctx context.Context) (*pollers.PollResult resp, err := p.client.Execute(ctx, req) if err != nil { var e *url.Error - if errors.As(err, &e) { + if errors.As(err, &e) && !client.IsVCRReplayMissError(err) { p.droppedConnectionCount++ if p.droppedConnectionCount < p.maxDroppedConnections { return &pollers.PollResult{ @@ -100,7 +100,7 @@ func (p *provisioningStatePoller) Poll(ctx context.Context) (*pollers.PollResult } } - return nil, fmt.Errorf("executing request: %+v", err) + return nil, fmt.Errorf("executing request: %w", err) } p.droppedConnectionCount = 0 @@ -160,10 +160,7 @@ func (p *provisioningStatePoller) Poll(ctx context.Context) (*pollers.PollResult } func (p *provisioningStatePoller) SkipDelay() bool { - if p.client != nil { - return client.IsVcrReplaying(p.client.Transport) - } - return false + return p.client != nil && client.IsVCRReplaying(p.client.Client) } type provisioningStateResult struct { diff --git a/sdk/client/dataplane/poller_provisioning_state_test.go b/sdk/client/dataplane/poller_provisioning_state_test.go index 846afe0c56a..86f535e9a1f 100644 --- a/sdk/client/dataplane/poller_provisioning_state_test.go +++ b/sdk/client/dataplane/poller_provisioning_state_test.go @@ -7,6 +7,7 @@ import ( "context" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -263,3 +264,72 @@ func TestPollerProvisioningState_InStatus_Poll(t *testing.T) { } } } + +func TestPollerProvisioningState_VCRErrorHandling(t *testing.T) { + testCases := []struct { + name string + errMsg string + expectReplayMiss bool + expectResultStatus *pollers.PollingStatus + expectError bool + expectDroppedConnections int + }{ + { + name: "replay miss url error", + errMsg: client.VCRInteractionNotFoundErrMsg, + expectReplayMiss: true, + expectError: true, + expectDroppedConnections: 0, + }, + { + name: "non retryable url error", + errMsg: "unsupported protocol scheme", + expectResultStatus: func() *pollers.PollingStatus { s := pollers.PollingStatusUnknown; return &s }(), + expectDroppedConnections: 1, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + dataplaneClient := &Client{ + Client: client.NewClient("https://management.azure.com", "Example", "2020-01-01"), + ApiVersion: "2020-01-01", + } + dataplaneClient.SetTransport(errRoundTripper{errMsg: testCase.errMsg}, client.TransportModeVCRReplay) + + poller := provisioningStatePoller{ + apiVersion: "2020-01-01", + client: dataplaneClient, + initialRetryDuration: 10 * time.Millisecond, + originalUri: "/provisioning-state/poll", + resourcePath: "/provisioning-state/poll", + maxDroppedConnections: 3, + } + + result, err := poller.Poll(ctx) + if testCase.expectError { + if err == nil { + t.Fatal("expected polling to return an error, but got no error") + } + if result != nil { + t.Fatalf("expected no poll result, got: %+v ,error %T (err=%v)", result, err, err) + } + if testCase.expectReplayMiss != client.IsVCRReplayMissError(err) { + t.Fatalf("expected IsVCRReplayMissError=%t, got %t (err=%v)", testCase.expectReplayMiss, client.IsVCRReplayMissError(err), err) + } + if !strings.Contains(err.Error(), testCase.errMsg) { + t.Fatalf("expected error to contain %q, got %v", testCase.errMsg, err) + } + } + if testCase.expectResultStatus != nil && (result == nil || result.Status != *testCase.expectResultStatus) { + t.Fatalf("expected poll result with status %q, got: %+v", *testCase.expectResultStatus, result) + } + if poller.droppedConnectionCount != testCase.expectDroppedConnections { + t.Fatalf("expected droppedConnectionCount to be %d, got %d", testCase.expectDroppedConnections, poller.droppedConnectionCount) + } + }) + } +} diff --git a/sdk/client/interface.go b/sdk/client/interface.go index fa3575d9ec1..f67f09b7326 100644 --- a/sdk/client/interface.go +++ b/sdk/client/interface.go @@ -42,8 +42,8 @@ type BaseClient interface { // ClearResponseMiddlewares removes all response middleware functions for the client ClearResponseMiddlewares() - // SetTransport configures the transport to be used by the client - SetTransport(http.RoundTripper) + // SetTransport configures the transport and transport mode to be used by the client + SetTransport(http.RoundTripper, TransportMode) } // RequestRetryFunc is a function that determines whether an HTTP request has failed due to eventual consistency and should be retried diff --git a/sdk/client/pollers/interface.go b/sdk/client/pollers/interface.go index cc476408ebb..9d2349d9196 100644 --- a/sdk/client/pollers/interface.go +++ b/sdk/client/pollers/interface.go @@ -27,6 +27,7 @@ type PollerType interface { Poll(ctx context.Context) (*PollResult, error) } +// delaySkipper is consulted only when there is no HTTP response available to inspect. type delaySkipper interface { SkipDelay() bool } diff --git a/sdk/client/pollers/poller.go b/sdk/client/pollers/poller.go index d3ce9c6f100..88aa00edc3a 100644 --- a/sdk/client/pollers/poller.go +++ b/sdk/client/pollers/poller.go @@ -7,7 +7,7 @@ import ( "context" "errors" "fmt" - "os" + "strings" "sync" "time" @@ -172,7 +172,7 @@ func (p *Poller) PollUntilDone(ctx context.Context) error { } if p.latestError != nil { - if !pointer.From(p.retryOnError) { + if !pointer.From(p.retryOnError) || client.IsVCRReplayMissErrorDeprecated(p.latestError) { break } } @@ -255,19 +255,12 @@ func (p *Poller) skipPollingDelay(ctx context.Context) bool { return true } - if os.Getenv("GO_AZURE_SDK_SKIP_POLLING_DELAY") == "true" { - return true + if p.latestResponse != nil && p.latestResponse.HttpResponse != nil && p.latestResponse.HttpResponse.Response != nil { + return strings.EqualFold(p.latestResponse.HttpResponse.Header.Get(client.SkipPollingDelayHeader), "true") } if skipper, ok := p.poller.(delaySkipper); ok && skipper.SkipDelay() { return true } - - if p.latestResponse != nil && p.latestResponse.HttpResponse != nil && p.latestResponse.HttpResponse.Response != nil && p.latestResponse.HttpResponse.Header != nil { - if p.latestResponse.HttpResponse.Header.Get("X-Go-Azure-SDK-Skip-Polling-Delay") == "true" { - return true - } - } - return false } diff --git a/sdk/client/pollers/poller_test.go b/sdk/client/pollers/poller_test.go index 5319571440b..15abbee82a3 100644 --- a/sdk/client/pollers/poller_test.go +++ b/sdk/client/pollers/poller_test.go @@ -7,7 +7,7 @@ import ( "context" "fmt" "net/http" - "os" + "net/url" "testing" "time" @@ -160,9 +160,17 @@ func TestPoller_SkipsDelayWhenContextFlagSet(t *testing.T) { } } -func TestPoller_SkipsDelayWhenEnvVarSet(t *testing.T) { +func TestPoller_SkipsDelayWhenHeaderSet(t *testing.T) { + expectedResponse := &client.Response{ + Response: &http.Response{ + Header: make(http.Header), + }, + } + expectedResponse.Header.Set(client.SkipPollingDelayHeader, "true") + pollerType := fakePollerWithResults([]pollResult{ pollers.PollResult{ + HttpResponse: expectedResponse, PollInterval: 1 * time.Hour, Status: pollers.PollingStatusInProgress, }, @@ -172,9 +180,6 @@ func TestPoller_SkipsDelayWhenEnvVarSet(t *testing.T) { }) poller := pollers.NewPoller(pollerType, 10*time.Millisecond, pollers.DefaultNumberOfDroppedConnectionsToAllow) - os.Setenv("GO_AZURE_SDK_SKIP_POLLING_DELAY", "true") - defer os.Unsetenv("GO_AZURE_SDK_SKIP_POLLING_DELAY") - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() @@ -189,27 +194,65 @@ func TestPoller_SkipsDelayWhenEnvVarSet(t *testing.T) { } } -func TestPoller_SkipsDelayWhenHeaderSet(t *testing.T) { +func TestPoller_RetryOnError_DoesNotRetryWrappedVCRReplayMiss(t *testing.T) { + pollerType := fakePollerWithResults([]pollResult{ + errorResult{ + Error: fmt.Errorf("executing request: %+v", &url.Error{ + Op: http.MethodGet, + URL: "https://example.test/operations/1", + Err: fmt.Errorf(client.VCRInteractionNotFoundErrMsg), + }), + }, + pollers.PollResult{ + Status: pollers.PollingStatusSucceeded, + }, + }) + pollerType.skipDelay = true + poller := pollers.NewRetryOnErrorPoller(pollerType, 10*time.Millisecond, pollers.DefaultNumberOfDroppedConnectionsToAllow, true) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err := poller.PollUntilDone(ctx) + if err == nil { + t.Fatal("expected replay miss to fail polling immediately") + } + if !client.IsVCRReplayMissErrorDeprecated(err) { + t.Fatalf("expected replay miss error, got: %T,%v", err, err) + } + if pollerType.count != 1 { + t.Fatalf("expected the fakePoller to be called 1 time but got %d", pollerType.count) + } + if poller.LatestStatus() != pollers.PollingStatusUnknown { + t.Fatalf("expected LatestStatus to be Unknown but got %q", string(poller.LatestStatus())) + } + if poller.LatestResponse() != nil { + t.Fatalf("expected LatestResponse to be nil but got: %+v", poller.LatestResponse()) + } +} + +func TestPoller_DoesNotSkipDelayWhenHTTPResponsePresentWithoutHeader(t *testing.T) { expectedResponse := &client.Response{ Response: &http.Response{ Header: make(http.Header), }, } - expectedResponse.Header.Set("X-Go-Azure-SDK-Skip-Polling-Delay", "true") + pollInterval := 80 * time.Millisecond pollerType := fakePollerWithResults([]pollResult{ pollers.PollResult{ HttpResponse: expectedResponse, - PollInterval: 1 * time.Hour, + PollInterval: pollInterval, Status: pollers.PollingStatusInProgress, }, pollers.PollResult{ Status: pollers.PollingStatusSucceeded, }, }) + pollerType.skipDelay = true poller := pollers.NewPoller(pollerType, 10*time.Millisecond, pollers.DefaultNumberOfDroppedConnectionsToAllow) - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() start := time.Now() @@ -218,8 +261,8 @@ func TestPoller_SkipsDelayWhenHeaderSet(t *testing.T) { } duration := time.Since(start) - if duration > 100*time.Millisecond { - t.Fatalf("expected polling to finish within 100ms but took %v", duration) + if duration < pollInterval { + t.Fatalf("expected polling to wait for the poll interval when no skip header is present, but it finished in %v", duration) } } diff --git a/sdk/client/resourcemanager/poller_delete.go b/sdk/client/resourcemanager/poller_delete.go index 0467bd020ea..c34f9a59ef8 100644 --- a/sdk/client/resourcemanager/poller_delete.go +++ b/sdk/client/resourcemanager/poller_delete.go @@ -107,10 +107,7 @@ func (p deletePoller) Poll(ctx context.Context) (result *pollers.PollResult, err } func (p deletePoller) SkipDelay() bool { - if p.client != nil { - return client.IsVcrReplaying(p.client.Transport) - } - return false + return p.client != nil && client.IsVCRReplaying(p.client.Client) } var _ client.Options = deleteOptions{} diff --git a/sdk/client/resourcemanager/poller_helpers_test.go b/sdk/client/resourcemanager/poller_helpers_test.go index e7bdbf1b00e..e45238b1a68 100644 --- a/sdk/client/resourcemanager/poller_helpers_test.go +++ b/sdk/client/resourcemanager/poller_helpers_test.go @@ -4,6 +4,7 @@ package resourcemanager import ( + "fmt" "net/http" "testing" @@ -61,3 +62,11 @@ func dropConnection(t *testing.T, w http.ResponseWriter) { } conn.Close() } + +type errRoundTripper struct { + errMsg string +} + +func (r errRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("%s", r.errMsg) +} diff --git a/sdk/client/resourcemanager/poller_lro.go b/sdk/client/resourcemanager/poller_lro.go index 5a43171aadb..aae972b2d28 100644 --- a/sdk/client/resourcemanager/poller_lro.go +++ b/sdk/client/resourcemanager/poller_lro.go @@ -113,7 +113,7 @@ func (p *longRunningOperationPoller) Poll(ctx context.Context) (result *pollers. result.HttpResponse, err = req.Execute(ctx) if err != nil { var e *url.Error - if errors.As(err, &e) { + if errors.As(err, &e) && !client.IsVCRReplayMissError(e) { p.droppedConnectionCount++ if p.droppedConnectionCount < p.maxDroppedConnections { result.Status = pollers.PollingStatusUnknown @@ -215,10 +215,7 @@ func (p *longRunningOperationPoller) Poll(ctx context.Context) (result *pollers. } func (p *longRunningOperationPoller) SkipDelay() bool { - if p.client != nil { - return client.IsVcrReplaying(p.client.Transport) - } - return false + return client.IsVCRReplaying(p.client) } type operationResult struct { diff --git a/sdk/client/resourcemanager/poller_lro_test.go b/sdk/client/resourcemanager/poller_lro_test.go index aaec8d6162b..d00dd5ae00a 100644 --- a/sdk/client/resourcemanager/poller_lro_test.go +++ b/sdk/client/resourcemanager/poller_lro_test.go @@ -7,7 +7,9 @@ import ( "context" "net/http" "net/http/httptest" + "strings" "testing" + "time" "github.com/hashicorp/go-azure-sdk/sdk/client" "github.com/hashicorp/go-azure-sdk/sdk/client/pollers" @@ -553,3 +555,71 @@ func TestPollerLRO_InStatus_404ThenInProgressThenSucceeded(t *testing.T) { // sanity-checking helpers.assertCalled(t, 4) } + +func TestPollerLRO_VCRErrorHandling(t *testing.T) { + + testCases := []struct { + name string + errMsg string + expectReplayMiss bool + expectResultStatus *pollers.PollingStatus + expectError bool + expectDroppedConnections int + }{ + { + name: "replay miss url error", + errMsg: client.VCRInteractionNotFoundErrMsg, + expectReplayMiss: true, + expectError: true, + expectDroppedConnections: 0, + }, + { + name: "non retryable url error", + errMsg: "unsupported protocol scheme", + expectResultStatus: func() *pollers.PollingStatus { s := pollers.PollingStatusUnknown; return &s }(), + expectDroppedConnections: 1, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + helpers := newLongRunningOperationsEndpoint([]expectedResponse{}) + response := &client.Response{ + Response: helpers.response(), + } + + baseClient := client.NewClient("https://management.azure.com", "Example", "2020-01-01") + baseClient.SetTransport(errRoundTripper{errMsg: testCase.errMsg}, client.TransportModeVCRReplay) + + poller, err := longRunningOperationPollerFromResponse(response, baseClient) + if err != nil { + t.Fatal(err.Error()) + } + + result, err := poller.Poll(ctx) + if testCase.expectError { + if err == nil { + t.Fatal("expected polling to return an error, but got no error") + } + if result != nil { + t.Fatalf("expected no poll result, got: %+v ,error %T (err=%v)", result, err, err) + } + if testCase.expectReplayMiss != client.IsVCRReplayMissError(err) { + t.Fatalf("expected IsVCRReplayMissError=%t, got %t (err=%v)", testCase.expectReplayMiss, client.IsVCRReplayMissError(err), err) + } + if !strings.Contains(err.Error(), testCase.errMsg) { + t.Fatalf("expected error to contain %q, got %v", testCase.errMsg, err) + } + } + if testCase.expectResultStatus != nil && (result == nil || result.Status != *testCase.expectResultStatus) { + t.Fatalf("expected poll result with status %q, got: %+v", *testCase.expectResultStatus, result) + } + if poller.droppedConnectionCount != testCase.expectDroppedConnections { + t.Fatalf("expected droppedConnectionCount to be %d, got %d", testCase.expectDroppedConnections, poller.droppedConnectionCount) + } + + }) + } +} diff --git a/sdk/client/resourcemanager/poller_provisioning_state.go b/sdk/client/resourcemanager/poller_provisioning_state.go index cc581c28fd8..6c6d163c101 100644 --- a/sdk/client/resourcemanager/poller_provisioning_state.go +++ b/sdk/client/resourcemanager/poller_provisioning_state.go @@ -90,7 +90,7 @@ func (p *provisioningStatePoller) Poll(ctx context.Context) (*pollers.PollResult resp, err := p.client.Execute(ctx, req) if err != nil { var e *url.Error - if errors.As(err, &e) { + if errors.As(err, &e) && !client.IsVCRReplayMissError(e) { p.droppedConnectionCount++ if p.droppedConnectionCount < p.maxDroppedConnections { return &pollers.PollResult{ @@ -160,10 +160,7 @@ func (p *provisioningStatePoller) Poll(ctx context.Context) (*pollers.PollResult } func (p *provisioningStatePoller) SkipDelay() bool { - if p.client != nil { - return client.IsVcrReplaying(p.client.Transport) - } - return false + return p.client != nil && client.IsVCRReplaying(p.client.Client) } type provisioningStateResult struct { diff --git a/sdk/client/resourcemanager/poller_provisioning_state_test.go b/sdk/client/resourcemanager/poller_provisioning_state_test.go index a260b13a249..ca3d99877bc 100644 --- a/sdk/client/resourcemanager/poller_provisioning_state_test.go +++ b/sdk/client/resourcemanager/poller_provisioning_state_test.go @@ -7,6 +7,7 @@ import ( "context" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -263,3 +264,74 @@ func TestPollerProvisioningState_InStatus_Poll(t *testing.T) { } } } + +func TestPollerProvisioningState_VCRErrorHandling(t *testing.T) { + + testCases := []struct { + name string + errMsg string + expectReplayMiss bool + expectResultStatus *pollers.PollingStatus + expectError bool + expectDroppedConnections int + }{ + { + name: "replay miss url error", + errMsg: client.VCRInteractionNotFoundErrMsg, + expectReplayMiss: true, + expectError: true, + expectDroppedConnections: 0, + }, + { + name: "non retryable url error", + errMsg: "unsupported protocol scheme", + expectResultStatus: func() *pollers.PollingStatus { s := pollers.PollingStatusUnknown; return &s }(), + expectDroppedConnections: 1, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + resourceManagerClient := &Client{ + Client: client.NewClient("https://management.azure.com", "Example", "2020-01-01"), + apiVersion: "2020-01-01", + } + resourceManagerClient.SetTransport(errRoundTripper{errMsg: testCase.errMsg}, client.TransportModeVCRReplay) + + poller := provisioningStatePoller{ + apiVersion: "2020-01-01", + client: resourceManagerClient, + initialRetryDuration: 10 * time.Millisecond, + originalUri: "/provisioning-state/poll", + resourcePath: "/provisioning-state/poll", + maxDroppedConnections: 3, + } + + result, err := poller.Poll(ctx) + if testCase.expectError { + if err == nil { + t.Fatal("expected polling to return an error, but got no error") + } + if result != nil { + t.Fatalf("expected no poll result, got: %+v ,error %T (err=%v)", result, err, err) + } + if testCase.expectReplayMiss != client.IsVCRReplayMissErrorDeprecated(err) { + t.Fatalf("expected IsVCRReplayMissError=%t, got %t (err=%v)", testCase.expectReplayMiss, client.IsVCRReplayMissErrorDeprecated(err), err) + } + if !strings.Contains(err.Error(), testCase.errMsg) { + t.Fatalf("expected error to contain %q, got %v", testCase.errMsg, err) + } + } + if testCase.expectResultStatus != nil && (result == nil || result.Status != *testCase.expectResultStatus) { + t.Fatalf("expected poll result with status %q, got: %+v", *testCase.expectResultStatus, result) + } + if poller.droppedConnectionCount != testCase.expectDroppedConnections { + t.Fatalf("expected droppedConnectionCount to be %d, got %d", testCase.expectDroppedConnections, poller.droppedConnectionCount) + } + + }) + } +} diff --git a/sdk/client/transport.go b/sdk/client/transport.go new file mode 100644 index 00000000000..a144a915f17 --- /dev/null +++ b/sdk/client/transport.go @@ -0,0 +1,50 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package client + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "runtime" + "time" +) + +// TransportMode describes how the configured transport is being used. +type TransportMode string + +const ( + // TransportModeDefault indicates that the transport is being used in the default mode, without recording or replaying traffic. + TransportModeDefault TransportMode = "" + + // TransportModeVCRRecord indicates that the transport is recording live traffic. + TransportModeVCRRecord TransportMode = "vcr_record" + + // TransportModeVCRReplay indicates that the transport is replaying previously recorded traffic. + TransportModeVCRReplay TransportMode = "vcr_replay" + + // TransportModeVCRReplayWithNewEpisodes indicates that the transport is replaying previously recorded traffic, but will record new interactions if a replay miss occurs. + TransportModeVCRReplayWithNewEpisodes TransportMode = "vcr_replay_with_new_episodes" +) + +// GetDefaultHttpTransport returns a new default transport configured for SDK HTTP traffic. +func GetDefaultHttpTransport() http.RoundTripper { + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + d := &net.Dialer{Resolver: &net.Resolver{}} + return d.DialContext(ctx, network, addr) + }, + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ForceAttemptHTTP2: true, + MaxIdleConnsPerHost: runtime.GOMAXPROCS(0) + 1, + } +} diff --git a/sdk/client/vcr_helper.go b/sdk/client/vcr_helper.go new file mode 100644 index 00000000000..a31a978817f --- /dev/null +++ b/sdk/client/vcr_helper.go @@ -0,0 +1,45 @@ +package client + +import ( + "errors" + "net/http" + "net/url" + "strings" +) + +const ( + // VCRReplayHeader is set by VCR-aware transports on replayed HTTP responses. + VCRReplayHeader = "X-Go-Azure-SDK-VCR-Replay" + + SkipPollingDelayHeader = "X-Go-Azure-SDK-Skip-Polling-Delay" + + VCRInteractionNotFoundErrMsg = "requested interaction not found" +) + +// The VCR recorder sets the X-Go-Azure-SDK-VCR-Replay header to "true" when a response is returned from a cassette. +// In ModeReplayWithNewEpisodes, VCR may still make live HTTP requests, so we check for the presence of this header +// instead of relying only on the VCR mode. +func IsVCRRecordedResponse(resp *http.Response) bool { + return resp != nil && resp.Header != nil && + strings.EqualFold(resp.Header.Get(VCRReplayHeader), "true") +} + +// IsVCRReplayMissError returns true when the error indicates that a replayed cassette has no matching interaction. +func IsVCRReplayMissError(err error) bool { + var urlErr *url.Error + if err == nil || !errors.As(err, &urlErr) { + return false + } + return strings.Contains(strings.ToLower(urlErr.Error()), VCRInteractionNotFoundErrMsg) +} + +// Deprecated: Use IsVCRReplayMissError instead. This is a temporary fallback, as pollers wrap errors using %+v, +// which loses the original error type information. Once pollers are updated to use %w (which preserves type information), +// this function will be removed. +func IsVCRReplayMissErrorDeprecated(err error) bool { + return err != nil && strings.Contains(strings.ToLower(err.Error()), VCRInteractionNotFoundErrMsg) +} + +func IsVCRReplaying(c *Client) bool { + return c != nil && c.TransportMode == TransportModeVCRReplay +} diff --git a/sdk/client/vcr_helper_test.go b/sdk/client/vcr_helper_test.go new file mode 100644 index 00000000000..fe6767515d4 --- /dev/null +++ b/sdk/client/vcr_helper_test.go @@ -0,0 +1,201 @@ +package client + +import ( + "errors" + "fmt" + "net/http" + "net/url" + "testing" +) + +func TestIsVCRRecordedResponse(t *testing.T) { + testCases := []struct { + name string + response *http.Response + expected bool + }{ + { + name: "nil response", + response: nil, + expected: false, + }, + { + name: "nil headers", + response: &http.Response{ + Header: nil, + }, + expected: false, + }, + { + name: "header false", + response: &http.Response{ + Header: http.Header{ + "X-Azure-SDK-VCR-Test": []string{"true"}, + }, + }, + expected: false, + }, + { + name: "header true", + response: &http.Response{ + Header: http.Header{ + http.CanonicalHeaderKey(VCRReplayHeader): []string{"true"}, + }, + }, + expected: true, + }, + { + name: "header true mixed case", + response: &http.Response{ + Header: http.Header{ + http.CanonicalHeaderKey(VCRReplayHeader): []string{"TrUe"}, + }, + }, + expected: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + actual := IsVCRRecordedResponse(testCase.response) + if actual != testCase.expected { + t.Fatalf("expected %t, got %t", testCase.expected, actual) + } + }) + } +} + +func TestIsVCRReplayMissError(t *testing.T) { + testCases := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "plain error", + err: errors.New(VCRInteractionNotFoundErrMsg), + expected: false, + }, + { + name: "url error replay miss", + err: &url.Error{ + Op: http.MethodGet, + URL: "https://example.test/resource", + Err: errors.New(VCRInteractionNotFoundErrMsg), + }, + expected: true, + }, + { + name: "wrapped url error replay miss", + err: fmt.Errorf("wrapped: %w", &url.Error{ + Op: http.MethodGet, + URL: "https://example.test/resource", + Err: errors.New("Requested Interaction Not Found"), + }), + expected: true, + }, + { + name: "url error different message", + err: &url.Error{ + Op: http.MethodGet, + URL: "https://example.test/resource", + Err: errors.New("dial tcp timeout"), + }, + expected: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + actual := IsVCRReplayMissError(testCase.err) + if actual != testCase.expected { + t.Fatalf("expected %t, got %t (err=%v)", testCase.expected, actual, testCase.err) + } + }) + } +} + +func TestIsVCRReplayMissErrorDeprecated(t *testing.T) { + testCases := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "wrapped replay miss message", + err: fmt.Errorf("polling failed: %+v", errors.New(VCRInteractionNotFoundErrMsg)), + expected: true, + }, + { + name: "url error replay miss", + err: &url.Error{ + Op: http.MethodGet, + URL: "https://example.test/resource", + Err: errors.New(VCRInteractionNotFoundErrMsg), + }, + expected: true, + }, + { + name: "different error", + err: errors.New("unsupported protocol scheme"), + expected: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + actual := IsVCRReplayMissErrorDeprecated(testCase.err) + if actual != testCase.expected { + t.Fatalf("expected %t, got %t (err=%v)", testCase.expected, actual, testCase.err) + } + }) + } +} + +func TestIsVCRReplaying(t *testing.T) { + testCases := []struct { + name string + client *Client + expected bool + }{ + + { + name: "default transport mode", + client: NewClient("https://localhost", "Example", "2020-01-01"), + expected: false, + }, + { + name: "record mode", + client: &Client{ + TransportMode: TransportModeVCRRecord, + }, + expected: false, + }, + { + name: "replay mode", + client: &Client{ + TransportMode: TransportModeVCRReplay, + }, + expected: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + actual := IsVCRReplaying(testCase.client) + if actual != testCase.expected { + t.Fatalf("expected %t, got %t", testCase.expected, actual) + } + }) + } +}