Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 6 additions & 70 deletions sdk/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"net"
"net/http"
"net/url"
"reflect"
"regexp"
"runtime"
"strconv"
Expand All @@ -30,6 +29,12 @@ import (
"github.com/hashicorp/go-retryablehttp"
)

const (
// SkipPollingDelayHeader is the HTTP header used to instruct the poller to skip its standard delay wait time.
// This is typically injected via a ResponseMiddleware when running tests in a VCR replay mode.
SkipPollingDelayHeader = "X-Go-Azure-Sdk-Skip-Polling-Delay"
)

// RetryOn404ConsistencyFailureFunc can be used to retry a request when a 404 response is received
func RetryOn404ConsistencyFailureFunc(resp *http.Response, _ *odata.OData) (bool, error) {
return resp != nil && resp.StatusCode == http.StatusNotFound, nil
Expand Down Expand Up @@ -349,70 +354,6 @@ func (c *Client) SetAuthorizer(authorizer auth.Authorizer) {
c.Authorizer = authorizer
}

// IsVcrReplaying returns true if the provided transport appears to be a VCR recorder in replay mode.
func IsVcrReplaying(transport http.RoundTripper) bool {
if transport == nil {
return false
}

// We use reflection to avoid a hard dependency on the go-vcr package in the SDK.
// We check for "recorder" in the type name and then attempt to call the Mode() method.
t := reflect.TypeOf(transport)
if !strings.Contains(strings.ToLower(t.String()), "recorder") {
return false
}

v := reflect.ValueOf(transport)
// If it's a pointer, get the underlying value (MethodByName works on both)
modeMethod := v.MethodByName("Mode")
if !modeMethod.IsValid() {
return false
}

results := modeMethod.Call(nil)
if len(results) != 1 {
return false
}

// Mode is normally an int (or an alias of int).
// In go-vcr v1/v2, ModeReplay is 1.
// In go-vcr v3/v4, ModeReplayOnly is 1 and ModeReplayWithNewEpisodes is 2.
// Support for RecordOnce (3) mode: skip only if cassette is NOT new.
if results[0].Kind() == reflect.Int {
mode := results[0].Int()
if mode == 1 || mode == 2 {
return true
}

if mode == 3 {
// use reflection to get the unexported cassette field
v := reflect.ValueOf(transport)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() == reflect.Struct {
cassetteField := v.FieldByName("cassette")
if cassetteField.IsValid() {
cassette := cassetteField
if cassette.Kind() == reflect.Ptr && !cassette.IsNil() {
cassette = cassette.Elem()
}

if cassette.Kind() == reflect.Struct {
isNewField := cassette.FieldByName("IsNew")
if isNewField.IsValid() && isNewField.Kind() == reflect.Bool {
// If the cassette is NOT new, we are replaying
return !isNewField.Bool()
}
}
}
}
}
}

return false
}

// SetUserAgent configures the user agent to be included in requests
func (c *Client) SetUserAgent(userAgent string) {
c.UserAgent = userAgent
Expand Down Expand Up @@ -596,11 +537,6 @@ func (c *Client) Execute(ctx context.Context, req *Request) (*Response, error) {
}
}

// If we're running with a VCR transport in REPLAY mode, we mark the response so that the poller knows to skip the delay
if IsVcrReplaying(c.Transport) {
resp.Header.Add("X-Go-Azure-SDK-Skip-Polling-Delay", "true")
}

// Extract OData from response, intentionally ignoring any errors as it's not crucial to extract
// valid OData at this point (valid json can still error here, such as any non-object literal)
resp.OData, _ = odata.FromResponse(resp.Response)
Expand Down
199 changes: 0 additions & 199 deletions sdk/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,202 +573,3 @@ func TestClient_CustomTransport(t *testing.T) {
t.Errorf("expected transport to be hit 1 time, got %d", hitCount)
}
}

func TestClient_VCRHeaderInjected(t *testing.T) {
ctx := context.TODO()

c := NewClient("http://localhost", "testService", "v1.0")
c.DisableRetries = true

mockTransport := &recorderRecorder{
mode: 1, // Replay
roundTripFunc: func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader([]byte(`{"success": true}`))),
Header: make(http.Header),
Request: req,
}, nil
},
}
c.Transport = mockTransport

reqOpts := RequestOptions{
ContentType: "application/json",
ExpectedStatusCodes: []int{
http.StatusOK,
},
HttpMethod: http.MethodGet,
Path: "/test",
}

req, err := c.NewRequest(ctx, reqOpts)
if err != nil {
t.Fatalf("NewRequest error: %v", err)
}

resp, err := req.Execute(ctx)
if err != nil {
t.Fatalf("Execute error: %v", err)
}

if resp.Header.Get("X-Go-Azure-SDK-Skip-Polling-Delay") != "true" {
t.Errorf("expected skip polling delay header to be injected in replay mode, but wasn't")
}
}

func TestClient_VCRHeaderInjectedInRecordOnceModeReplaying(t *testing.T) {
ctx := context.TODO()

c := NewClient("http://localhost", "testService", "v1.0")
c.DisableRetries = true

mockTransport := &recorderRecorder{
mode: 3, // RecordOnce
cassette: &testCassette{
IsNew: false, // REPLAYING
},
roundTripFunc: func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader([]byte(`{"success": true}`))),
Header: make(http.Header),
Request: req,
}, nil
},
}
c.Transport = mockTransport

reqOpts := RequestOptions{
ContentType: "application/json",
ExpectedStatusCodes: []int{
http.StatusOK,
},
HttpMethod: http.MethodGet,
Path: "/test",
}

req, err := c.NewRequest(ctx, reqOpts)
if err != nil {
t.Fatalf("NewRequest error: %v", err)
}

resp, err := req.Execute(ctx)
if err != nil {
t.Fatalf("Execute error: %v", err)
}

if resp.Header.Get("X-Go-Azure-SDK-Skip-Polling-Delay") != "true" {
t.Errorf("expected skip polling delay header to be injected in RecordOnce replaying mode, but wasn't")
}
}

func TestClient_VCRHeaderNotInjectedInRecordOnceModeRecording(t *testing.T) {
ctx := context.TODO()

c := NewClient("http://localhost", "testService", "v1.0")
c.DisableRetries = true

mockTransport := &recorderRecorder{
mode: 3, // RecordOnce
cassette: &testCassette{
IsNew: true, // RECORDING
},
roundTripFunc: func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader([]byte(`{"success": true}`))),
Header: make(http.Header),
Request: req,
}, nil
},
}
c.Transport = mockTransport

reqOpts := RequestOptions{
ContentType: "application/json",
ExpectedStatusCodes: []int{
http.StatusOK,
},
HttpMethod: http.MethodGet,
Path: "/test",
}

req, err := c.NewRequest(ctx, reqOpts)
if err != nil {
t.Fatalf("NewRequest error: %v", err)
}

resp, err := req.Execute(ctx)
if err != nil {
t.Fatalf("Execute error: %v", err)
}

if resp.Header.Get("X-Go-Azure-SDK-Skip-Polling-Delay") == "true" {
t.Errorf("expected skip polling delay header NOT to be injected in RecordOnce recording mode, but it was")
}
}

func TestClient_VCRHeaderNotInjectedInRecordMode(t *testing.T) {
ctx := context.TODO()

c := NewClient("http://localhost", "testService", "v1.0")
c.DisableRetries = true

mockTransport := &recorderRecorder{
mode: 0, // Record
roundTripFunc: func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader([]byte(`{"success": true}`))),
Header: make(http.Header),
Request: req,
}, nil
},
}
c.Transport = mockTransport

reqOpts := RequestOptions{
ContentType: "application/json",
ExpectedStatusCodes: []int{
http.StatusOK,
},
HttpMethod: http.MethodGet,
Path: "/test",
}

req, err := c.NewRequest(ctx, reqOpts)
if err != nil {
t.Fatalf("NewRequest error: %v", err)
}

resp, err := req.Execute(ctx)
if err != nil {
t.Fatalf("Execute error: %v", err)
}

if resp.Header.Get("X-Go-Azure-SDK-Skip-Polling-Delay") == "true" {
t.Errorf("expected skip polling delay header NOT to be injected in record mode, but it was")
}
}

type recorderRecorder struct {
roundTripFunc func(*http.Request) (*http.Response, error)
mode int
cassette *testCassette
}

type testCassette struct {
IsNew bool
}

func (rt *recorderRecorder) Mode() int {
return rt.mode
}

func (rt *recorderRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
if rt.roundTripFunc != nil {
return rt.roundTripFunc(req)
}
return nil, fmt.Errorf("roundTripFunc missing from mock")
}
15 changes: 6 additions & 9 deletions sdk/client/dataplane/poller_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type deletePoller struct {
resourcePath string
}

func deletePollerFromResponse(response *client.Response, client *Client, pollingInterval time.Duration) (*deletePoller, error) {
func deletePollerFromResponse(response *client.Response, c *Client, pollingInterval time.Duration) (*deletePoller, error) {
// if we've gotten to this point then we're polling against a Resource Manager resource/operation of some kind
// we next need to determine if the current URI is a Resource Manager resource, or if we should be polling on the
// resource (e.g. `/my/resource`) rather than an operation on the resource (e.g. `/my/resource/start`)
Expand All @@ -47,9 +47,13 @@ func deletePollerFromResponse(response *client.Response, client *Client, polling
return nil, fmt.Errorf("determining Resource Manager Resource Path from %q: %+v", originalUri, err)
}

if s, ok := response.Header[client.SkipPollingDelayHeader]; ok && s[0] == "true" {
pollingInterval = 0
}

return &deletePoller{
apiVersion: apiVersion,
client: client,
client: c,
initialRetryDuration: pollingInterval,
originalUri: originalUri,
resourcePath: *resourcePath,
Expand Down Expand Up @@ -106,13 +110,6 @@ func (p deletePoller) Poll(ctx context.Context) (result *pollers.PollResult, err
return
}

func (p deletePoller) SkipDelay() bool {
if p.client != nil {
return client.IsVcrReplaying(p.client.Transport)
}
return false
}

var _ client.Options = deleteOptions{}

type deleteOptions struct {
Expand Down
15 changes: 6 additions & 9 deletions sdk/client/dataplane/poller_lro.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ func pollingUriForLongRunningOperation(resp *client.Response) string {
return pollingUrl
}

func longRunningOperationPollerFromResponse(resp *client.Response, client *client.Client) (*longRunningOperationPoller, error) {
func longRunningOperationPollerFromResponse(resp *client.Response, c *client.Client) (*longRunningOperationPoller, error) {
poller := longRunningOperationPoller{
client: client,
client: c,
initialRetryDuration: 10 * time.Second,
maxDroppedConnections: 3,
}
Expand Down Expand Up @@ -72,6 +72,10 @@ func longRunningOperationPollerFromResponse(resp *client.Response, client *clien
}
}

if s, ok := resp.Header[client.SkipPollingDelayHeader]; ok && s[0] == "true" {
poller.initialRetryDuration = 0
}

return &poller, nil
}

Expand Down Expand Up @@ -214,13 +218,6 @@ func (p *longRunningOperationPoller) Poll(ctx context.Context) (result *pollers.
return result, nil
}

func (p *longRunningOperationPoller) SkipDelay() bool {
if p.client != nil {
return client.IsVcrReplaying(p.client.Transport)
}
return false
}

type operationResult struct {
Name *string `json:"name"`
// Some APIs (such as CosmosDbPostgreSQLCluster) return a DateTime value that doesn't match RFC3339
Expand Down
Loading
Loading