Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 14 additions & 90 deletions sdk/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,9 @@ import (
"io"
"log"
"math"
"net"
"net/http"
"net/url"
"reflect"
"regexp"
"runtime"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -325,6 +322,10 @@ type Client struct {
// Transport allows overriding the http.RoundTripper used by the client.
// When nil, a default transport will be used.
Transport http.RoundTripper

// TransportMode indicates the mode of the transport, this is being used to set vcr transport mode,
// but can be used by custom transports to modify their behavior based on the mode.
TransportMode TransportMode
}

// NewClient returns a new Client configured with sensible defaults
Expand All @@ -339,80 +340,17 @@ func NewClient(baseUri string, serviceName, apiVersion string) *Client {
}
}

// SetTransport configures the transport to be used by the client
func (c *Client) SetTransport(transport http.RoundTripper) {
// SetTransport configures the transport and transport mode to be used by the client.
func (c *Client) SetTransport(transport http.RoundTripper, mode TransportMode) {
c.Transport = transport
c.TransportMode = mode
}

// SetAuthorizer configures the request authorizer for the client
func (c *Client) SetAuthorizer(authorizer auth.Authorizer) {
c.Authorizer = authorizer
}

// IsVcrReplaying returns true if the provided transport appears to be a VCR recorder in replay mode.
func IsVcrReplaying(transport http.RoundTripper) bool {
if transport == nil {
return false
}

// We use reflection to avoid a hard dependency on the go-vcr package in the SDK.
// We check for "recorder" in the type name and then attempt to call the Mode() method.
t := reflect.TypeOf(transport)
if !strings.Contains(strings.ToLower(t.String()), "recorder") {
return false
}

v := reflect.ValueOf(transport)
// If it's a pointer, get the underlying value (MethodByName works on both)
modeMethod := v.MethodByName("Mode")
if !modeMethod.IsValid() {
return false
}

results := modeMethod.Call(nil)
if len(results) != 1 {
return false
}

// Mode is normally an int (or an alias of int).
// In go-vcr v1/v2, ModeReplay is 1.
// In go-vcr v3/v4, ModeReplayOnly is 1 and ModeReplayWithNewEpisodes is 2.
// Support for RecordOnce (3) mode: skip only if cassette is NOT new.
if results[0].Kind() == reflect.Int {
mode := results[0].Int()
if mode == 1 || mode == 2 {
return true
}

if mode == 3 {
// use reflection to get the unexported cassette field
v := reflect.ValueOf(transport)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() == reflect.Struct {
cassetteField := v.FieldByName("cassette")
if cassetteField.IsValid() {
cassette := cassetteField
if cassette.Kind() == reflect.Ptr && !cassette.IsNil() {
cassette = cassette.Elem()
}

if cassette.Kind() == reflect.Struct {
isNewField := cassette.FieldByName("IsNew")
if isNewField.IsValid() && isNewField.Kind() == reflect.Bool {
// If the cassette is NOT new, we are replaying
return !isNewField.Bool()
}
}
}
}
}
}

return false
}

// SetUserAgent configures the user agent to be included in requests
func (c *Client) SetUserAgent(userAgent string) {
c.UserAgent = userAgent
Expand Down Expand Up @@ -546,6 +484,9 @@ func (c *Client) Execute(ctx context.Context, req *Request) (*Response, error) {

// Check for failed connections etc and decide if retries are appropriate
if r == nil {
if IsVCRReplayMissError(err) {
return false, err
}
if req.IsIdempotent() {
if !isResourceManagerHost(req) {
return extendedRetryPolicy(r, err)
Expand Down Expand Up @@ -595,12 +536,10 @@ func (c *Client) Execute(ctx context.Context, req *Request) (*Response, error) {
resp.Response = r
}
}

// If we're running with a VCR transport in REPLAY mode, we mark the response so that the poller knows to skip the delay
if IsVcrReplaying(c.Transport) {
resp.Header.Add("X-Go-Azure-SDK-Skip-Polling-Delay", "true")
// If it's a recorded response, we mark the response so that the poller knows to skip the delay
if c.TransportMode != TransportModeDefault && (IsVCRRecordedResponse(resp.Response) || IsVCRReplaying(c)) {
resp.Header.Set(SkipPollingDelayHeader, "true")
}

// Extract OData from response, intentionally ignoring any errors as it's not crucial to extract
// valid OData at this point (valid json can still error here, such as any non-object literal)
resp.OData, _ = odata.FromResponse(resp.Response)
Expand Down Expand Up @@ -808,22 +747,7 @@ func (c *Client) retryableClient(ctx context.Context, checkRetry retryablehttp.C
if c.Transport != nil {
transport = c.Transport
} else {
transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
d := &net.Dialer{Resolver: &net.Resolver{}}
return d.DialContext(ctx, network, addr)
},
TLSClientConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
},
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ForceAttemptHTTP2: true,
MaxIdleConnsPerHost: runtime.GOMAXPROCS(0) + 1,
}
transport = GetDefaultHttpTransport()
}

r.HTTPClient = &http.Client{
Expand Down
Loading
Loading