From b93238eeebe6bf46bbcc2de156ec2ed85275e5b7 Mon Sep 17 00:00:00 2001 From: Dave Mihalcik Date: Mon, 8 Jun 2026 11:22:51 -0400 Subject: [PATCH 1/3] feat(sdk): add DPoP client support with HTTP RoundTripper (DSPX-3397) Implements RFC 9449 DPoP (Demonstrating Proof-of-Possession) for the Go SDK: - Add DPoPTransport as an http.RoundTripper that wraps any transport - Generate ES256/RS256 proofs with jti, htm, htu, iat claims for all requests - Add ath claim (access token hash) for resource endpoint calls - Handle server-issued DPoP-Nonce challenges with automatic retry - Cache nonces per-origin and refresh from successful responses - Normalize URIs per RFC 9449 (lowercase scheme/host, strip default ports) - Integrate into SDK's HTTP client construction via NewDPoPHTTPClient - Add SupportedFeatures() function for xtest feature detection All requests through the SDK now include DPoP proofs when credentials are configured. Token endpoint requests omit the ath claim; resource requests include both Authorization: DPoP header and the DPoP proof header. Co-Authored-By: Claude Sonnet 4.5 --- sdk/auth/dpop_transport.go | 275 ++++++++++++++++++++++++ sdk/auth/dpop_transport_test.go | 357 ++++++++++++++++++++++++++++++++ sdk/sdk.go | 34 ++- sdk/version.go | 9 + 4 files changed, 673 insertions(+), 2 deletions(-) create mode 100644 sdk/auth/dpop_transport.go create mode 100644 sdk/auth/dpop_transport_test.go diff --git a/sdk/auth/dpop_transport.go b/sdk/auth/dpop_transport.go new file mode 100644 index 0000000000..06f9def96f --- /dev/null +++ b/sdk/auth/dpop_transport.go @@ -0,0 +1,275 @@ +package auth + +import ( + "crypto/sha256" + "encoding/base64" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + + "time" + + "github.com/google/uuid" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jws" + "github.com/lestrrat-go/jwx/v2/jwt" +) + +// DPoPTransport wraps an http.RoundTripper to add DPoP (RFC 9449) proof tokens +// to HTTP requests. It generates proofs for both token endpoint calls and +// resource endpoint calls, handling server-issued nonces with automatic retry. +type DPoPTransport struct { + // Base is the underlying transport. If nil, http.DefaultTransport is used. + Base http.RoundTripper + + // DPoPKey is the private key used to sign DPoP proofs. + DPoPKey jwk.Key + + // TokenSource provides access tokens for resource requests. + // When the token is DPoP-bound (token_type=DPoP), the transport + // sets Authorization: DPoP and includes the ath claim. + TokenSource AccessTokenSource + + // TokenEndpoint is the OAuth token endpoint URL. + // Requests to this endpoint are treated as token requests + // and do not include the ath claim. + TokenEndpoint string + + nonceMu sync.RWMutex + // nonceCache stores server-issued nonces by origin (scheme://host:port) + nonceCache map[string]string +} + +// RoundTrip implements http.RoundTripper, adding DPoP proofs to requests. +func (t *DPoPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if t.Base == nil { + t.Base = http.DefaultTransport + } + + if t.nonceCache == nil { + t.nonceMu.Lock() + if t.nonceCache == nil { + t.nonceCache = make(map[string]string) + } + t.nonceMu.Unlock() + } + + // Clone request to avoid modifying the original + req2 := cloneRequest(req) + + // Determine if this is a token endpoint request + isTokenRequest := t.isTokenEndpointRequest(req2.URL) + + // Get cached nonce for this origin + origin := getOrigin(req2.URL) + nonce := t.getCachedNonce(origin) + + // Generate and add DPoP proof + if err := t.addDPoPProof(req2, nonce, isTokenRequest); err != nil { + return nil, fmt.Errorf("failed to add DPoP proof: %w", err) + } + + // Make the request + resp, err := t.Base.RoundTrip(req2) + if err != nil { + return resp, err + } + + // Handle DPoP-Nonce challenge (RFC 9449 §8) + if resp.StatusCode == http.StatusUnauthorized { + if newNonce := resp.Header.Get("DPoP-Nonce"); newNonce != "" { + // Check if this was a retry with a nonce already + if nonce != "" { + // Already tried with a nonce, don't retry again + return resp, nil + } + + // Cache the new nonce + t.setCachedNonce(origin, newNonce) + + // Close the failed response body + resp.Body.Close() + + // Clone the original request again for retry + req3 := cloneRequest(req) + + // Regenerate proof with nonce + if err := t.addDPoPProof(req3, newNonce, isTokenRequest); err != nil { + return nil, fmt.Errorf("failed to add DPoP proof with nonce: %w", err) + } + + // Retry the request + return t.Base.RoundTrip(req3) + } + } + + // Update cached nonce from successful responses + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + if newNonce := resp.Header.Get("DPoP-Nonce"); newNonce != "" { + t.setCachedNonce(origin, newNonce) + } + } + + return resp, nil +} + +// addDPoPProof generates and adds DPoP proof to the request headers. +func (t *DPoPTransport) addDPoPProof(req *http.Request, nonce string, isTokenRequest bool) error { + // Normalize the htu (RFC 9449 HTTP URI Normalization) + htu := normalizeURI(req.URL) + + // Build base proof claims + builder := jwt.NewBuilder(). + Claim("jti", uuid.NewString()). + Claim("htm", req.Method). + Claim("htu", htu). + IssuedAt(time.Now()) + + // Add nonce if provided + if nonce != "" { + builder = builder.Claim("nonce", nonce) + } + + // For resource requests (not token endpoint), add ath claim + var accessToken string + if !isTokenRequest && t.TokenSource != nil { + at, err := t.TokenSource.AccessToken(req.Context(), nil) + if err != nil { + return fmt.Errorf("failed to get access token: %w", err) + } + accessToken = string(at) + + // Calculate ath = base64url(SHA-256(access_token)) + h := sha256.New() + h.Write([]byte(accessToken)) + ath := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + builder = builder.Claim("ath", ath) + } + + // Build the token + token, err := builder.Build() + if err != nil { + return fmt.Errorf("failed to build DPoP token: %w", err) + } + + // Get public key for jwk header + publicKey, err := t.DPoPKey.PublicKey() + if err != nil { + return fmt.Errorf("failed to get public key: %w", err) + } + + // Create headers + headers := jws.NewHeaders() + if err := headers.Set(jws.JWKKey, publicKey); err != nil { + return fmt.Errorf("failed to set jwk header: %w", err) + } + if err := headers.Set(jws.TypeKey, "dpop+jwt"); err != nil { + return fmt.Errorf("failed to set typ header: %w", err) + } + if err := headers.Set(jws.AlgorithmKey, t.DPoPKey.Algorithm()); err != nil { + return fmt.Errorf("failed to set alg header: %w", err) + } + + // Sign the token + signedToken, err := jwt.Sign(token, jwt.WithKey(t.DPoPKey.Algorithm(), t.DPoPKey, jws.WithProtectedHeaders(headers))) + if err != nil { + return fmt.Errorf("failed to sign DPoP token: %w", err) + } + + // Add DPoP header + req.Header.Set("DPoP", string(signedToken)) + + // For resource requests, set Authorization header + if !isTokenRequest && accessToken != "" { + req.Header.Set("Authorization", "DPoP "+accessToken) + } + + return nil +} + +// isTokenEndpointRequest checks if the URL matches the configured token endpoint. +func (t *DPoPTransport) isTokenEndpointRequest(u *url.URL) bool { + if t.TokenEndpoint == "" { + return false + } + tokenURL, err := url.Parse(t.TokenEndpoint) + if err != nil { + return false + } + return u.Scheme == tokenURL.Scheme && + u.Host == tokenURL.Host && + u.Path == tokenURL.Path +} + +// normalizeURI normalizes the URI per RFC 9449 HTTP URI Normalization: +// - Lowercase scheme and host +// - Remove default ports (80 for http, 443 for https) +// - Strip query and fragment +func normalizeURI(u *url.URL) string { + scheme := strings.ToLower(u.Scheme) + host := strings.ToLower(u.Host) + + // Remove default ports + if (scheme == "http" && strings.HasSuffix(host, ":80")) || + (scheme == "https" && strings.HasSuffix(host, ":443")) { + host = host[:strings.LastIndex(host, ":")] + } + + return fmt.Sprintf("%s://%s%s", scheme, host, u.Path) +} + +// getOrigin returns the origin (scheme://host:port) from a URL. +func getOrigin(u *url.URL) string { + return fmt.Sprintf("%s://%s", u.Scheme, u.Host) +} + +// getCachedNonce retrieves the cached nonce for an origin. +func (t *DPoPTransport) getCachedNonce(origin string) string { + t.nonceMu.RLock() + defer t.nonceMu.RUnlock() + return t.nonceCache[origin] +} + +// setCachedNonce stores a nonce for an origin. +func (t *DPoPTransport) setCachedNonce(origin, nonce string) { + t.nonceMu.Lock() + defer t.nonceMu.Unlock() + t.nonceCache[origin] = nonce +} + +// cloneRequest creates a shallow clone of the request. +func cloneRequest(req *http.Request) *http.Request { + req2 := req.Clone(req.Context()) + // Clone headers to avoid modifying the original + req2.Header = req.Header.Clone() + return req2 +} + +// NewDPoPHTTPClient creates a new HTTP client with DPoP transport wrapping. +// The client will automatically add DPoP proofs to all requests. +func NewDPoPHTTPClient(baseClient *http.Client, dpopKey jwk.Key, tokenSource AccessTokenSource, tokenEndpoint string) *http.Client { + if baseClient == nil { + baseClient = http.DefaultClient + } + + transport := baseClient.Transport + if transport == nil { + transport = http.DefaultTransport + } + + dpopTransport := &DPoPTransport{ + Base: transport, + DPoPKey: dpopKey, + TokenSource: tokenSource, + TokenEndpoint: tokenEndpoint, + } + + return &http.Client{ + Transport: dpopTransport, + CheckRedirect: baseClient.CheckRedirect, + Jar: baseClient.Jar, + Timeout: baseClient.Timeout, + } +} diff --git a/sdk/auth/dpop_transport_test.go b/sdk/auth/dpop_transport_test.go new file mode 100644 index 0000000000..72613196e1 --- /dev/null +++ b/sdk/auth/dpop_transport_test.go @@ -0,0 +1,357 @@ +package auth + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" +) + +// mockTokenSource implements AccessTokenSource for testing +type mockTokenSource struct { + token string +} + +func (m *mockTokenSource) AccessToken(_ context.Context, _ *http.Client) (AccessToken, error) { + return AccessToken(m.token), nil +} + +func (m *mockTokenSource) MakeToken(_ func(jwk.Key) ([]byte, error)) ([]byte, error) { + // Not used in transport tests + return nil, nil +} + +func generateTestKey(t *testing.T) jwk.Key { + t.Helper() + rawKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate RSA key: %v", err) + } + + key, err := jwk.FromRaw(rawKey) + if err != nil { + t.Fatalf("failed to create JWK: %v", err) + } + + if err := key.Set(jwk.AlgorithmKey, jwa.RS256); err != nil { + t.Fatalf("failed to set algorithm: %v", err) + } + + return key +} + +func parseDPoPProof(t *testing.T, proofStr string, key jwk.Key) jwt.Token { + t.Helper() + + token, err := jwt.Parse([]byte(proofStr), jwt.WithKey(jwa.RS256, key)) + if err != nil { + t.Fatalf("failed to parse DPoP proof: %v", err) + } + + return token +} + +func TestDPoPTransport_AddsProofToRequests(t *testing.T) { + key := generateTestKey(t) + ts := &mockTokenSource{token: "test-access-token"} + + called := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + + // Verify DPoP header exists + dpopHeader := r.Header.Get("DPoP") + if dpopHeader == "" { + t.Error("DPoP header not present") + return + } + + // Verify Authorization header + authHeader := r.Header.Get("Authorization") + if !strings.HasPrefix(authHeader, "DPoP ") { + t.Errorf("Authorization header = %q, want prefix 'DPoP '", authHeader) + } + + // Parse and verify the proof + publicKey, err := key.PublicKey() + if err != nil { + t.Fatalf("failed to get public key: %v", err) + } + + token := parseDPoPProof(t, dpopHeader, publicKey) + + // Check htm claim + if htm, ok := token.Get("htm"); !ok || htm != "GET" { + t.Errorf("htm claim = %v, want 'GET'", htm) + } + + // Check htu claim (should be normalized) + htu, ok := token.Get("htu") + if !ok { + t.Error("htu claim missing") + } else if htuStr, ok := htu.(string); !ok { + t.Errorf("htu claim not a string: %v", htu) + } else if htuStr == "" { + t.Error("htu claim is empty") + } + + // Check ath claim (access token hash) + if ath, ok := token.Get("ath"); !ok { + t.Error("ath claim missing") + } else { + expectedHash := sha256.Sum256([]byte("test-access-token")) + expectedATH := base64.RawURLEncoding.EncodeToString(expectedHash[:]) + if ath != expectedATH { + t.Errorf("ath claim = %v, want %v", ath, expectedATH) + } + } + + // Check jti claim + if jti, ok := token.Get("jti"); !ok || jti == "" { + t.Error("jti claim missing or empty") + } + + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + transport := &DPoPTransport{ + Base: http.DefaultTransport, + DPoPKey: key, + TokenSource: ts, + } + + client := &http.Client{Transport: transport} + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if !called { + t.Error("server handler was not called") + } +} + +func TestDPoPTransport_NonceRetry(t *testing.T) { + key := generateTestKey(t) + ts := &mockTokenSource{token: "test-token"} + + callCount := 0 + nonce := "test-nonce-12345" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + + dpopHeader := r.Header.Get("DPoP") + if dpopHeader == "" { + t.Error("DPoP header not present") + w.WriteHeader(http.StatusBadRequest) + return + } + + publicKey, err := key.PublicKey() + if err != nil { + t.Fatalf("failed to get public key: %v", err) + } + + token := parseDPoPProof(t, dpopHeader, publicKey) + + if callCount == 1 { + // First request should not have nonce + if _, ok := token.Get("nonce"); ok { + t.Error("first request should not have nonce claim") + } + + // Send 401 with nonce challenge + w.Header().Set("DPoP-Nonce", nonce) + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Second request should have the nonce + if nonceVal, ok := token.Get("nonce"); !ok { + t.Error("second request missing nonce claim") + } else if nonceVal != nonce { + t.Errorf("nonce claim = %v, want %v", nonceVal, nonce) + } + + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + transport := &DPoPTransport{ + Base: http.DefaultTransport, + DPoPKey: key, + TokenSource: ts, + } + + client := &http.Client{Transport: transport} + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if callCount != 2 { + t.Errorf("expected 2 calls (initial + retry), got %d", callCount) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("final status = %d, want %d", resp.StatusCode, http.StatusOK) + } +} + +func TestDPoPTransport_URINormalization(t *testing.T) { + tests := []struct { + name string + url string + expected string + }{ + { + name: "https default port", + url: "https://example.com:443/path", + expected: "https://example.com/path", + }, + { + name: "http default port", + url: "http://example.com:80/path", + expected: "http://example.com/path", + }, + { + name: "https non-default port", + url: "https://example.com:8443/path", + expected: "https://example.com:8443/path", + }, + { + name: "uppercase scheme and host", + url: "HTTPS://EXAMPLE.COM/Path", + expected: "https://example.com/Path", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key := generateTestKey(t) + ts := &mockTokenSource{token: "test-token"} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dpopHeader := r.Header.Get("DPoP") + publicKey, err := key.PublicKey() + if err != nil { + t.Fatalf("failed to get public key: %v", err) + } + + token := parseDPoPProof(t, dpopHeader, publicKey) + + htu, ok := token.Get("htu") + if !ok { + t.Fatal("htu claim missing") + } + + // The htu should have normalized the URL + htuStr := htu.(string) + if !strings.Contains(htuStr, "/path") { + t.Errorf("htu = %s, want to contain normalized path", htuStr) + } + + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + transport := &DPoPTransport{ + Base: http.DefaultTransport, + DPoPKey: key, + TokenSource: ts, + } + + client := &http.Client{Transport: transport} + + // Use the server URL but replace path + testURL := server.URL + "/path" + req, err := http.NewRequest(http.MethodGet, testURL, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + if err == nil { + resp.Body.Close() + } + if err != nil { + t.Fatalf("request failed: %v", err) + } + }) + } +} + +func TestDPoPTransport_TokenEndpointNoATH(t *testing.T) { + key := generateTestKey(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dpopHeader := r.Header.Get("DPoP") + if dpopHeader == "" { + t.Error("DPoP header not present") + w.WriteHeader(http.StatusBadRequest) + return + } + + publicKey, err := key.PublicKey() + if err != nil { + t.Fatalf("failed to get public key: %v", err) + } + + token := parseDPoPProof(t, dpopHeader, publicKey) + + // Token endpoint requests should NOT have ath claim + if _, ok := token.Get("ath"); ok { + t.Error("token endpoint request should not have ath claim") + } + + // Should not have Authorization header for token endpoint + if auth := r.Header.Get("Authorization"); auth != "" { + t.Errorf("token endpoint should not have Authorization header, got %q", auth) + } + + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + transport := &DPoPTransport{ + Base: http.DefaultTransport, + DPoPKey: key, + TokenSource: &mockTokenSource{token: "test-token"}, + TokenEndpoint: server.URL, + } + + client := &http.Client{Transport: transport} + req, err := http.NewRequest(http.MethodPost, server.URL, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() +} diff --git a/sdk/sdk.go b/sdk/sdk.go index fa809e6e12..28eaf86c2c 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -16,6 +16,7 @@ import ( "sync" "connectrpc.com/connect" + "github.com/lestrrat-go/jwx/v2/jwk" "github.com/opentdf/platform/lib/ocrypto" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/wellknownconfiguration" @@ -199,8 +200,19 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { if err != nil { return nil, err } + + // Wrap HTTP client with DPoP transport for resource requests + httpClient := cfg.httpClient + if accessTokenSource != nil && cfg.dpopKey != nil { + dpopKey, err := getDPoPJWK(cfg.dpopKey) + if err != nil { + return nil, fmt.Errorf("failed to create DPoP JWK: %w", err) + } + httpClient = auth.NewDPoPHTTPClient(cfg.httpClient, dpopKey, accessTokenSource, cfg.tokenEndpoint) + } + if accessTokenSource != nil { - interceptor := auth.NewTokenAddingInterceptorWithClient(accessTokenSource, cfg.httpClient) + interceptor := auth.NewTokenAddingInterceptorWithClient(accessTokenSource, httpClient) uci = append(uci, interceptor.AddCredentialsConnect()) } @@ -208,7 +220,7 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { if cfg.coreConn != nil { platformConn = cfg.coreConn } else { - platformConn = &ConnectRPCConnection{Endpoint: platformEndpoint, Client: cfg.httpClient, Options: append(cfg.extraClientOptions, connect.WithInterceptors(uci...))} + platformConn = &ConnectRPCConnection{Endpoint: platformEndpoint, Client: httpClient, Options: append(cfg.extraClientOptions, connect.WithInterceptors(uci...))} } if cfg.entityResolutionConn != nil { @@ -248,6 +260,24 @@ func IsPlatformEndpointMalformed(e string) bool { return false } +func getDPoPJWK(dpopKey *ocrypto.RsaKeyPair) (jwk.Key, error) { + dpopPrivateKeyPEM, err := dpopKey.PrivateKeyInPemFormat() + if err != nil { + return nil, fmt.Errorf("error getting dpop private key: %w", err) + } + + key, err := jwk.ParseKey([]byte(dpopPrivateKeyPEM), jwk.WithPEM(true)) + if err != nil { + return nil, fmt.Errorf("error creating JWK: %w", err) + } + + if err := key.Set(jwk.AlgorithmKey, "RS256"); err != nil { + return nil, fmt.Errorf("error setting key algorithm: %w", err) + } + + return key, nil +} + func buildIDPTokenSource(c *config) (auth.AccessTokenSource, error) { if c.customAccessTokenSource != nil { return c.customAccessTokenSource, nil diff --git a/sdk/version.go b/sdk/version.go index 1575ae1d4c..f22769fd2d 100644 --- a/sdk/version.go +++ b/sdk/version.go @@ -9,3 +9,12 @@ const ( // The three-part semantic version number of this SDK Version = "0.21.0" // x-release-please-version ) + +// SupportedFeatures returns a list of optional features supported by this SDK build. +// Used by xtest integration harness for feature detection. +func SupportedFeatures() []string { + return []string{ + "dpop", // RFC 9449 DPoP (Demonstrating Proof-of-Possession) + "connectrpc", // Connect RPC protocol support + } +} From b2d79cb2791c4879abc5d82c63847f727a4ec262 Mon Sep 17 00:00:00 2001 From: Dave Mihalcik Date: Tue, 9 Jun 2026 13:55:58 -0400 Subject: [PATCH 2/3] feat(otdfctl): add --dpop and --dpop-key flags to encrypt/decrypt (DSPX-3397) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Exposes DPoP algorithm/key selection via CLI flags on `otdfctl encrypt` and `otdfctl decrypt`, supporting ES256 (default), ES384, ES512, RS256, RS384, and RS512. Bare `--dpop` defaults to ES256 per RFC 9449 §4.2. `--dpop-key ` loads a PEM private key (algorithm inferred from key type). Both flags can be combined to override the inferred algorithm. SDK changes: - Add sdk/dpop_key.go: generateDPoPKeyForAlg, loadDPoPKeyFromPEM, resolveDPoPKey helpers - Add WithDPoPAlgorithm, WithDPoPKeyPEM, WithDPoPJWK SDK options - Thread custom JWK through buildIDPTokenSource and DPoPTransport setup; falls back to auto-generated RSA when no custom key is configured - Add JWK-accepting token source constructors for all four source types otdfctl changes: - Register --dpop (NoOptDefVal="ES256") and --dpop-key flags on encrypt/decrypt; update man docs accordingly - handlers.WithExtraSDKOpts appends (not replaces) SDK options - common.NewHandler accepts variadic extraSDKOpts (backward compatible) Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Dave Mihalcik --- otdfctl/cmd/common/common.go | 8 +- otdfctl/cmd/tdf/decrypt.go | 14 ++- otdfctl/cmd/tdf/dpop.go | 30 +++++ otdfctl/cmd/tdf/encrypt.go | 14 ++- otdfctl/docs/man/decrypt/_index.md | 10 ++ otdfctl/docs/man/encrypt/_index.md | 10 ++ otdfctl/pkg/handlers/sdk.go | 18 ++- sdk/auth/dpop_transport.go | 1 - sdk/dpop_key.go | 164 +++++++++++++++++++++++++ sdk/dpop_key_test.go | 147 ++++++++++++++++++++++ sdk/idp_access_token_source.go | 20 +++ sdk/idp_cert_exchange.go | 18 +++ sdk/idp_oauth_access_token_source.go | 9 ++ sdk/idp_token_exchange_token_source.go | 20 +++ sdk/options.go | 31 +++++ sdk/sdk.go | 47 +++++-- sdk/version.go | 2 +- 17 files changed, 545 insertions(+), 18 deletions(-) create mode 100644 otdfctl/cmd/tdf/dpop.go create mode 100644 sdk/dpop_key.go create mode 100644 sdk/dpop_key_test.go diff --git a/otdfctl/cmd/common/common.go b/otdfctl/cmd/common/common.go index e7515a5a20..2e1611a1dd 100644 --- a/otdfctl/cmd/common/common.go +++ b/otdfctl/cmd/common/common.go @@ -86,7 +86,7 @@ func InitProfile(c *cli.Cli) *profiles.OtdfctlProfileStore { // TODO make this a preRun hook // //nolint:nestif // separate refactor [https://github.com/opentdf/otdfctl/issues/383] -func NewHandler(c *cli.Cli) handlers.Handler { +func NewHandler(c *cli.Cli, extraSDKOpts ...sdk.Option) handlers.Handler { // if global flags are set then validate and create a temporary profile in memory var cp *profiles.OtdfctlProfileStore @@ -209,7 +209,11 @@ func NewHandler(c *cli.Cli) handlers.Handler { cli.ExitWithError("Failed to get access token.", err) } - h, err := handlers.New(handlers.WithProfile(cp)) + handlerFuncs := []handlers.HandlerOptsFunc{handlers.WithProfile(cp)} + if len(extraSDKOpts) > 0 { + handlerFuncs = append(handlerFuncs, handlers.WithExtraSDKOpts(extraSDKOpts...)) + } + h, err := handlers.New(handlerFuncs...) if err != nil { cli.ExitWithError("Unexpected error", err) } diff --git a/otdfctl/cmd/tdf/decrypt.go b/otdfctl/cmd/tdf/decrypt.go index f6cd7beded..54b01bb24c 100644 --- a/otdfctl/cmd/tdf/decrypt.go +++ b/otdfctl/cmd/tdf/decrypt.go @@ -23,7 +23,7 @@ var ( func decryptRun(cmd *cobra.Command, args []string) { c := cli.New(cmd, args, cli.WithPrintJSON()) - h := common.NewHandler(c) + h := common.NewHandler(c, dpopSDKOpts(c)...) defer h.Close() output := c.Flags.GetOptionalString("out") @@ -133,6 +133,18 @@ func InitDecryptCommand() { nil, decryptDoc.GetDocFlag("kas-allowlist").Description, ) + decryptDoc.Flags().String( + decryptDoc.GetDocFlag("dpop").Name, + decryptDoc.GetDocFlag("dpop").Default, + decryptDoc.GetDocFlag("dpop").Description, + ) + // NoOptDefVal enables bare --dpop (without =value) to default to ES256. + decryptDoc.Flags().Lookup(decryptDoc.GetDocFlag("dpop").Name).NoOptDefVal = "ES256" + decryptDoc.Flags().String( + decryptDoc.GetDocFlag("dpop-key").Name, + decryptDoc.GetDocFlag("dpop-key").Default, + decryptDoc.GetDocFlag("dpop-key").Description, + ) decryptDoc.GroupID = TDF } diff --git a/otdfctl/cmd/tdf/dpop.go b/otdfctl/cmd/tdf/dpop.go new file mode 100644 index 0000000000..dc6880c5c9 --- /dev/null +++ b/otdfctl/cmd/tdf/dpop.go @@ -0,0 +1,30 @@ +package tdf + +import ( + "os" + + "github.com/opentdf/platform/otdfctl/pkg/cli" + "github.com/opentdf/platform/sdk" +) + +// dpopSDKOpts reads --dpop and --dpop-key flags from the CLI and returns the +// corresponding SDK options. Returns an empty slice when DPoP is not configured. +func dpopSDKOpts(c *cli.Cli) []sdk.Option { + dpopAlg := c.Flags.GetOptionalString("dpop") + dpopKeyPath := c.Flags.GetOptionalString("dpop-key") + + var opts []sdk.Option + if dpopKeyPath != "" { + pemBytes, err := os.ReadFile(dpopKeyPath) + if err != nil { + cli.ExitWithError("Failed to read DPoP key file", err) + } + opts = append(opts, sdk.WithDPoPKeyPEM(pemBytes)) + if dpopAlg != "" { + opts = append(opts, sdk.WithDPoPAlgorithm(dpopAlg)) + } + } else if dpopAlg != "" { + opts = append(opts, sdk.WithDPoPAlgorithm(dpopAlg)) + } + return opts +} diff --git a/otdfctl/cmd/tdf/encrypt.go b/otdfctl/cmd/tdf/encrypt.go index 3935825f28..744b597e33 100644 --- a/otdfctl/cmd/tdf/encrypt.go +++ b/otdfctl/cmd/tdf/encrypt.go @@ -26,7 +26,7 @@ var ( func encryptRun(cmd *cobra.Command, args []string) { c := cli.New(cmd, args, cli.WithPrintJSON()) - h := common.NewHandler(c) + h := common.NewHandler(c, dpopSDKOpts(c)...) defer h.Close() var filePath string @@ -191,5 +191,17 @@ func InitEncryptCommand() { encryptDoc.GetDocFlag("target-mode").Default, encryptDoc.GetDocFlag("target-mode").Description, ) + encryptDoc.Flags().String( + encryptDoc.GetDocFlag("dpop").Name, + encryptDoc.GetDocFlag("dpop").Default, + encryptDoc.GetDocFlag("dpop").Description, + ) + // NoOptDefVal enables bare --dpop (without =value) to default to ES256. + encryptDoc.Flags().Lookup(encryptDoc.GetDocFlag("dpop").Name).NoOptDefVal = "ES256" + encryptDoc.Flags().String( + encryptDoc.GetDocFlag("dpop-key").Name, + encryptDoc.GetDocFlag("dpop-key").Default, + encryptDoc.GetDocFlag("dpop-key").Description, + ) encryptDoc.GroupID = TDF } diff --git a/otdfctl/docs/man/decrypt/_index.md b/otdfctl/docs/man/decrypt/_index.md index 2b7c7b6f22..096f115805 100644 --- a/otdfctl/docs/man/decrypt/_index.md +++ b/otdfctl/docs/man/decrypt/_index.md @@ -27,6 +27,16 @@ command: EXPERIMENTAL: path to JSON file of keys to verify signed assertions. See examples for more information. - name: kas-allowlist description: A custom allowlist of comma-separated KAS Urls, e.g. `https://example.com/kas,http://localhost:8080`. If none specified, the platform will use the list of KASes in the KAS registry. To ignore the allowlist, use a quoted wildcard e.g. `--kas-allowlist '*'` **WARNING:** Bypassing the allowlist may expose you to potential security risks, as untrusted KAS URLs could be used. + - name: dpop + description: > + Enable DPoP (RFC 9449) sender-constrained tokens. Use bare --dpop for ES256 (default), or + --dpop= for a specific algorithm. Allowed algorithms: ES256, ES384, ES512, RS256, RS384, RS512. + An ephemeral key is generated per session. Combines with --dpop-key to override inferred algorithm. + - name: dpop-key + description: > + Path to a PEM-encoded private key for DPoP. Enables DPoP without requiring --dpop. + Algorithm is inferred from the key type (EC → ES256/384/512, RSA → RS256). + Use --dpop= to override the inferred algorithm. --- Decrypt a Trusted Data Format (TDF) file and output the contents to stdout or a file in the current working directory. diff --git a/otdfctl/docs/man/encrypt/_index.md b/otdfctl/docs/man/encrypt/_index.md index 36fa33647e..8976c7702e 100644 --- a/otdfctl/docs/man/encrypt/_index.md +++ b/otdfctl/docs/man/encrypt/_index.md @@ -37,6 +37,16 @@ command: - name: with-assertions description: > EXPERIMENTAL: JSON string or path to a JSON file of assertions to bind metadata to the TDF. See examples for more information. WARNING: Providing keys in a JSON string is strongly discouraged. If including sensitive keys, instead provide a path to a JSON file containing that information. + - name: dpop + description: > + Enable DPoP (RFC 9449) sender-constrained tokens. Use bare --dpop for ES256 (default), or + --dpop= for a specific algorithm. Allowed algorithms: ES256, ES384, ES512, RS256, RS384, RS512. + An ephemeral key is generated per session. Combines with --dpop-key to override inferred algorithm. + - name: dpop-key + description: > + Path to a PEM-encoded private key for DPoP. Enables DPoP without requiring --dpop. + Algorithm is inferred from the key type (EC → ES256/384/512, RSA → RS256). + Use --dpop= to override the inferred algorithm. --- Build a Trusted Data Format (TDF) with encrypted content from a specified file or input from stdin utilizing OpenTDF platform. diff --git a/otdfctl/pkg/handlers/sdk.go b/otdfctl/pkg/handlers/sdk.go index 2c1e0f7871..032d66b88c 100644 --- a/otdfctl/pkg/handlers/sdk.go +++ b/otdfctl/pkg/handlers/sdk.go @@ -31,9 +31,9 @@ type handlerOpts struct { sdkOpts []sdk.Option } -type handlerOptsFunc func(handlerOpts) handlerOpts +type HandlerOptsFunc func(handlerOpts) handlerOpts -func WithEndpoint(endpoint string, tlsNoVerify bool) handlerOptsFunc { +func WithEndpoint(endpoint string, tlsNoVerify bool) HandlerOptsFunc { return func(c handlerOpts) handlerOpts { c.endpoint = endpoint c.TLSNoVerify = tlsNoVerify @@ -41,7 +41,7 @@ func WithEndpoint(endpoint string, tlsNoVerify bool) handlerOptsFunc { } } -func WithProfile(profile *profiles.OtdfctlProfileStore) handlerOptsFunc { +func WithProfile(profile *profiles.OtdfctlProfileStore) HandlerOptsFunc { return func(c handlerOpts) handlerOpts { c.profile = profile c.endpoint = profile.GetEndpoint() @@ -58,15 +58,23 @@ func WithProfile(profile *profiles.OtdfctlProfileStore) handlerOptsFunc { } } -func WithSDKOpts(opts ...sdk.Option) handlerOptsFunc { +func WithSDKOpts(opts ...sdk.Option) HandlerOptsFunc { return func(c handlerOpts) handlerOpts { c.sdkOpts = opts return c } } +// WithExtraSDKOpts appends additional SDK options without replacing those set by WithProfile. +func WithExtraSDKOpts(opts ...sdk.Option) HandlerOptsFunc { + return func(c handlerOpts) handlerOpts { + c.sdkOpts = append(c.sdkOpts, opts...) + return c + } +} + // Creates a new handler wrapping the SDK, which is authenticated through the cached client-credentials flow tokens -func New(opts ...handlerOptsFunc) (Handler, error) { +func New(opts ...HandlerOptsFunc) (Handler, error) { var o handlerOpts for _, f := range opts { o = f(o) diff --git a/sdk/auth/dpop_transport.go b/sdk/auth/dpop_transport.go index 06f9def96f..0d9dd1a19a 100644 --- a/sdk/auth/dpop_transport.go +++ b/sdk/auth/dpop_transport.go @@ -8,7 +8,6 @@ import ( "net/url" "strings" "sync" - "time" "github.com/google/uuid" diff --git a/sdk/dpop_key.go b/sdk/dpop_key.go new file mode 100644 index 0000000000..9cee49fc97 --- /dev/null +++ b/sdk/dpop_key.go @@ -0,0 +1,164 @@ +package sdk + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "errors" + "fmt" + + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" +) + +// Supported DPoP algorithm identifiers (RFC 9449 §4.2). +const ( + dpopAlgES256 = "ES256" + dpopAlgES384 = "ES384" + dpopAlgES512 = "ES512" + dpopAlgRS256 = "RS256" + dpopAlgRS384 = "RS384" + dpopAlgRS512 = "RS512" +) + +const dpopAllowedAlgs = dpopAlgES256 + ", " + dpopAlgES384 + ", " + dpopAlgES512 + ", " + + dpopAlgRS256 + ", " + dpopAlgRS384 + ", " + dpopAlgRS512 + +// generateDPoPKeyForAlg generates an ephemeral DPoP private key for the given algorithm. +// Supported algorithms: ES256, ES384, ES512, RS256, RS384, RS512. +func generateDPoPKeyForAlg(alg string) (jwk.Key, error) { + switch alg { + case dpopAlgES256: + return generateECDSAKey(elliptic.P256(), jwa.ES256) + case dpopAlgES384: + return generateECDSAKey(elliptic.P384(), jwa.ES384) + case dpopAlgES512: + return generateECDSAKey(elliptic.P521(), jwa.ES512) + case dpopAlgRS256: + return generateRSAKey(jwa.RS256) + case dpopAlgRS384: + return generateRSAKey(jwa.RS384) + case dpopAlgRS512: + return generateRSAKey(jwa.RS512) + default: + return nil, fmt.Errorf("unsupported DPoP algorithm %q; allowed: %s", alg, dpopAllowedAlgs) + } +} + +func generateECDSAKey(curve elliptic.Curve, alg jwa.SignatureAlgorithm) (jwk.Key, error) { + rawKey, err := ecdsa.GenerateKey(curve, rand.Reader) + if err != nil { + return nil, fmt.Errorf("failed to generate ECDSA key: %w", err) + } + key, err := jwk.FromRaw(rawKey) + if err != nil { + return nil, fmt.Errorf("failed to create JWK from ECDSA key: %w", err) + } + if err := key.Set(jwk.AlgorithmKey, alg); err != nil { + return nil, fmt.Errorf("failed to set algorithm on ECDSA JWK: %w", err) + } + return key, nil +} + +func generateRSAKey(alg jwa.SignatureAlgorithm) (jwk.Key, error) { + const rsaBits = 2048 + rawKey, err := rsa.GenerateKey(rand.Reader, rsaBits) + if err != nil { + return nil, fmt.Errorf("failed to generate RSA key: %w", err) + } + key, err := jwk.FromRaw(rawKey) + if err != nil { + return nil, fmt.Errorf("failed to create JWK from RSA key: %w", err) + } + if err := key.Set(jwk.AlgorithmKey, alg); err != nil { + return nil, fmt.Errorf("failed to set algorithm on RSA JWK: %w", err) + } + return key, nil +} + +// loadDPoPKeyFromPEM parses a PEM-encoded private key and returns it as a jwk.Key. +// The DPoP algorithm is inferred from the key type: +// - EC P-256 → ES256, P-384 → ES384, P-521 → ES512 +// - RSA → RS256 +func loadDPoPKeyFromPEM(pemBytes []byte) (jwk.Key, error) { + key, err := jwk.ParseKey(pemBytes, jwk.WithPEM(true)) + if err != nil { + return nil, fmt.Errorf("failed to parse DPoP key PEM: %w", err) + } + + // Infer algorithm when not already set in the PEM + if key.Algorithm() == jwa.NoSignature || key.Algorithm().String() == "" { + alg, err := inferDPoPAlgorithm(key) + if err != nil { + return nil, err + } + if err := key.Set(jwk.AlgorithmKey, alg); err != nil { + return nil, fmt.Errorf("failed to set inferred algorithm on DPoP JWK: %w", err) + } + } + + return key, nil +} + +func inferDPoPAlgorithm(key jwk.Key) (jwa.SignatureAlgorithm, error) { + switch key.KeyType() { //nolint:exhaustive // only EC and RSA are valid for DPoP (RFC 9449 §4.2) + case jwa.EC: + var rawKey ecdsa.PrivateKey + if err := key.Raw(&rawKey); err != nil { + return "", fmt.Errorf("failed to get raw EC key for algorithm inference: %w", err) + } + switch rawKey.Curve { + case elliptic.P256(): + return jwa.ES256, nil + case elliptic.P384(): + return jwa.ES384, nil + case elliptic.P521(): + return jwa.ES512, nil + default: + return "", errors.New("unsupported EC curve for DPoP") + } + case jwa.RSA: + return jwa.RS256, nil + default: + return "", fmt.Errorf("unsupported key type %q for DPoP; only EC and RSA keys are supported", key.KeyType()) + } +} + +// resolveDPoPKey returns the jwk.Key to use for DPoP based on the config. +// Priority: dpopJWK (already set/cached) → dpopKeyPEM (load from PEM) → dpopAlgorithm (generate). +// The resolved key is cached in c.dpopJWK after first resolution. +// Returns (nil, nil) when no custom DPoP key is configured; callers fall back to auto-generated RSA. +// +//nolint:nilnil // nil key signals "use auto-generated RSA path" — not an error condition +func resolveDPoPKey(c *config) (jwk.Key, error) { + if c.dpopJWK != nil { + return c.dpopJWK, nil + } + if len(c.dpopKeyPEM) > 0 { //nolint:nestif // linear priority chain with nested error handling — complexity is inherent + key, err := loadDPoPKeyFromPEM(c.dpopKeyPEM) + if err != nil { + return nil, fmt.Errorf("failed to load DPoP key from PEM: %w", err) + } + if c.dpopAlgorithm != "" { + var algVal jwa.SignatureAlgorithm + if err := algVal.Accept(c.dpopAlgorithm); err != nil { + return nil, fmt.Errorf("invalid DPoP algorithm override %q: %w", c.dpopAlgorithm, err) + } + if err := key.Set(jwk.AlgorithmKey, algVal); err != nil { + return nil, fmt.Errorf("failed to apply DPoP algorithm override: %w", err) + } + } + c.dpopJWK = key + return key, nil + } + if c.dpopAlgorithm != "" { + key, err := generateDPoPKeyForAlg(c.dpopAlgorithm) + if err != nil { + return nil, fmt.Errorf("failed to generate DPoP key: %w", err) + } + c.dpopJWK = key + return key, nil + } + return nil, nil +} diff --git a/sdk/dpop_key_test.go b/sdk/dpop_key_test.go new file mode 100644 index 0000000000..1a8606f5eb --- /dev/null +++ b/sdk/dpop_key_test.go @@ -0,0 +1,147 @@ +package sdk + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "testing" + + "github.com/lestrrat-go/jwx/v2/jwa" +) + +// jwkToPEMForTest converts a jwk.Key to PEM for round-trip testing. +func jwkToPEMForTest(t *testing.T, key interface{ Raw(any) error }) []byte { + t.Helper() + var raw any + if err := key.Raw(&raw); err != nil { + t.Fatalf("failed to get raw key: %v", err) + } + der, err := x509.MarshalPKCS8PrivateKey(raw) + if err != nil { + t.Fatalf("failed to marshal key to PKCS8: %v", err) + } + return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der}) +} + +func TestGenerateDPoPKeyForAlg_EC(t *testing.T) { + tests := []struct { + alg string + wantAlg jwa.SignatureAlgorithm + curve elliptic.Curve + }{ + {dpopAlgES256, jwa.ES256, elliptic.P256()}, + {dpopAlgES384, jwa.ES384, elliptic.P384()}, + {dpopAlgES512, jwa.ES512, elliptic.P521()}, + } + + for _, tt := range tests { + t.Run(tt.alg, func(t *testing.T) { + key, err := generateDPoPKeyForAlg(tt.alg) + if err != nil { + t.Fatalf("generateDPoPKeyForAlg(%q) error = %v", tt.alg, err) + } + if key.Algorithm() != tt.wantAlg { + t.Errorf("algorithm = %v, want %v", key.Algorithm(), tt.wantAlg) + } + var rawKey *ecdsa.PrivateKey + if err := key.Raw(&rawKey); err != nil { + t.Fatalf("failed to get raw EC key: %v", err) + } + if rawKey.Curve != tt.curve { + t.Errorf("curve = %v, want %v", rawKey.Curve, tt.curve) + } + }) + } +} + +func TestGenerateDPoPKeyForAlg_RSA(t *testing.T) { + tests := []struct { + alg string + wantAlg jwa.SignatureAlgorithm + }{ + {dpopAlgRS256, jwa.RS256}, + {dpopAlgRS384, jwa.RS384}, + {dpopAlgRS512, jwa.RS512}, + } + + for _, tt := range tests { + t.Run(tt.alg, func(t *testing.T) { + key, err := generateDPoPKeyForAlg(tt.alg) + if err != nil { + t.Fatalf("generateDPoPKeyForAlg(%q) error = %v", tt.alg, err) + } + if key.Algorithm() != tt.wantAlg { + t.Errorf("algorithm = %v, want %v", key.Algorithm(), tt.wantAlg) + } + var rawKey *rsa.PrivateKey + if err := key.Raw(&rawKey); err != nil { + t.Fatalf("failed to get raw RSA key: %v", err) + } + }) + } +} + +func TestGenerateDPoPKeyForAlg_Invalid(t *testing.T) { + for _, alg := range []string{"INVALID", "", "HS256", "PS256"} { + t.Run(alg, func(t *testing.T) { + _, err := generateDPoPKeyForAlg(alg) + if err == nil { + t.Errorf("expected error for alg %q, got nil", alg) + } + }) + } +} + +func TestLoadDPoPKeyFromPEM_RSA(t *testing.T) { + generated, err := generateDPoPKeyForAlg(dpopAlgRS256) + if err != nil { + t.Fatalf("failed to generate RSA test key: %v", err) + } + pemBytes := jwkToPEMForTest(t, generated) + + loaded, err := loadDPoPKeyFromPEM(pemBytes) + if err != nil { + t.Fatalf("loadDPoPKeyFromPEM error = %v", err) + } + if loaded.Algorithm() != jwa.RS256 { + t.Errorf("algorithm = %v, want RS256", loaded.Algorithm()) + } +} + +func TestLoadDPoPKeyFromPEM_EC(t *testing.T) { + tests := []struct { + alg string + wantAlg jwa.SignatureAlgorithm + }{ + {dpopAlgES256, jwa.ES256}, + {dpopAlgES384, jwa.ES384}, + {dpopAlgES512, jwa.ES512}, + } + + for _, tt := range tests { + t.Run(tt.alg, func(t *testing.T) { + generated, err := generateDPoPKeyForAlg(tt.alg) + if err != nil { + t.Fatalf("failed to generate EC test key: %v", err) + } + pemBytes := jwkToPEMForTest(t, generated) + + loaded, err := loadDPoPKeyFromPEM(pemBytes) + if err != nil { + t.Fatalf("loadDPoPKeyFromPEM error = %v", err) + } + if loaded.Algorithm() != tt.wantAlg { + t.Errorf("algorithm = %v, want %v", loaded.Algorithm(), tt.wantAlg) + } + }) + } +} + +func TestLoadDPoPKeyFromPEM_InvalidPEM(t *testing.T) { + _, err := loadDPoPKeyFromPEM([]byte("not valid PEM")) + if err == nil { + t.Error("expected error for invalid PEM, got nil") + } +} diff --git a/sdk/idp_access_token_source.go b/sdk/idp_access_token_source.go index 9b84b4022f..ac498cf694 100644 --- a/sdk/idp_access_token_source.go +++ b/sdk/idp_access_token_source.go @@ -108,3 +108,23 @@ func (t *IDPAccessTokenSource) AccessToken(_ context.Context, client *http.Clien func (t *IDPAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) { return tokenMaker(t.dpopKey) } + +// newIDPAccessTokenSourceFromJWK creates an IDPAccessTokenSource using a pre-built JWK key. +func newIDPAccessTokenSourceFromJWK( + credentials oauth.ClientCredentials, + idpTokenEndpoint string, + scopes []string, + key jwk.Key, +) (*IDPAccessTokenSource, error) { + endpoint, err := url.Parse(idpTokenEndpoint) + if err != nil { + return nil, fmt.Errorf("invalid url [%s]: %w", idpTokenEndpoint, err) + } + return &IDPAccessTokenSource{ + credentials: credentials, + idpTokenEndpoint: *endpoint, + scopes: scopes, + dpopKey: key, + tokenMutex: &sync.Mutex{}, + }, nil +} diff --git a/sdk/idp_cert_exchange.go b/sdk/idp_cert_exchange.go index 5e0d2f4ee9..68805c1915 100644 --- a/sdk/idp_cert_exchange.go +++ b/sdk/idp_cert_exchange.go @@ -60,3 +60,21 @@ func (c *CertExchangeTokenSource) AccessToken(ctx context.Context, _ *http.Clien func (c *CertExchangeTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) { return tokenMaker(c.key) } + +// newCertExchangeTokenSourceFromJWK creates a CertExchangeTokenSource using a pre-built JWK key. +func newCertExchangeTokenSourceFromJWK( + logger *slog.Logger, + info oauth.CertExchangeInfo, + credentials oauth.ClientCredentials, + idpTokenEndpoint string, + key jwk.Key, +) (auth.AccessTokenSource, error) { + return &CertExchangeTokenSource{ + logger: logger, + info: info, + IdpEndpoint: idpTokenEndpoint, + credentials: credentials, + tokenMutex: &sync.Mutex{}, + key: key, + }, nil +} diff --git a/sdk/idp_oauth_access_token_source.go b/sdk/idp_oauth_access_token_source.go index c9e39859d6..d43f06ebbd 100644 --- a/sdk/idp_oauth_access_token_source.go +++ b/sdk/idp_oauth_access_token_source.go @@ -58,3 +58,12 @@ func (t *OAuthAccessTokenSource) AccessToken(_ context.Context, _ *http.Client) func (t *OAuthAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) { return tokenMaker(t.dpopKey) } + +// newOAuthAccessTokenSourceFromJWK creates an OAuthAccessTokenSource using a pre-built JWK key. +func newOAuthAccessTokenSourceFromJWK(source oauth2.TokenSource, scopes []string, key jwk.Key) *OAuthAccessTokenSource { + return &OAuthAccessTokenSource{ + source: source, + scopes: scopes, + dpopKey: key, + } +} diff --git a/sdk/idp_token_exchange_token_source.go b/sdk/idp_token_exchange_token_source.go index e3295650c8..2884224b1f 100644 --- a/sdk/idp_token_exchange_token_source.go +++ b/sdk/idp_token_exchange_token_source.go @@ -51,3 +51,23 @@ func (i *IDPTokenExchangeTokenSource) AccessToken(ctx context.Context, client *h func (i *IDPTokenExchangeTokenSource) MakeToken(keyMaker func(jwk.Key) ([]byte, error)) ([]byte, error) { return i.IDPAccessTokenSource.MakeToken(keyMaker) } + +// newIDPTokenExchangeTokenSourceFromJWK creates an IDPTokenExchangeTokenSource using a pre-built JWK key. +func newIDPTokenExchangeTokenSourceFromJWK( + logger *slog.Logger, + exchangeInfo oauth.TokenExchangeInfo, + credentials oauth.ClientCredentials, + idpTokenEndpoint string, + scopes []string, + key jwk.Key, +) (*IDPTokenExchangeTokenSource, error) { + idpSource, err := newIDPAccessTokenSourceFromJWK(credentials, idpTokenEndpoint, scopes, key) + if err != nil { + return nil, err + } + return &IDPTokenExchangeTokenSource{ + logger: logger, + IDPAccessTokenSource: *idpSource, + TokenExchangeInfo: exchangeInfo, + }, nil +} diff --git a/sdk/options.go b/sdk/options.go index ba63bb092a..cd15db9ae0 100644 --- a/sdk/options.go +++ b/sdk/options.go @@ -7,6 +7,7 @@ import ( "net/http" "connectrpc.com/connect" + "github.com/lestrrat-go/jwx/v2/jwk" "github.com/opentdf/platform/lib/ocrypto" "github.com/opentdf/platform/sdk/auth" "github.com/opentdf/platform/sdk/auth/oauth" @@ -35,6 +36,9 @@ type config struct { certExchange *oauth.CertExchangeInfo kasSessionKey *ocrypto.RsaKeyPair dpopKey *ocrypto.RsaKeyPair + dpopJWK jwk.Key + dpopAlgorithm string + dpopKeyPEM []byte ipc bool tdfFeatures tdfFeatures customAccessTokenSource auth.AccessTokenSource @@ -228,3 +232,30 @@ func WithLogger(logger *slog.Logger) Option { c.logger = logger } } + +// WithDPoPAlgorithm enables DPoP with an ephemeral key generated for the given algorithm. +// Supported: ES256 (default), ES384, ES512, RS256, RS384, RS512. +// Overrides the auto-generated RSA key used by default. +func WithDPoPAlgorithm(alg string) Option { + return func(c *config) { + c.dpopAlgorithm = alg + } +} + +// WithDPoPKeyPEM enables DPoP using a PEM-encoded private key. Algorithm is inferred +// from the key type unless also overridden via WithDPoPAlgorithm. +// Enables DPoP even without specifying an algorithm. +func WithDPoPKeyPEM(pemBytes []byte) Option { + return func(c *config) { + c.dpopKeyPEM = pemBytes + } +} + +// WithDPoPJWK enables DPoP using a pre-built JWK private key. The JWK must have its +// Algorithm field set. This is the lowest-level DPoP key injection; prefer +// WithDPoPAlgorithm or WithDPoPKeyPEM for most use cases. +func WithDPoPJWK(key jwk.Key) Option { + return func(c *config) { + c.dpopJWK = key + } +} diff --git a/sdk/sdk.go b/sdk/sdk.go index 28eaf86c2c..6766bc5922 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -201,14 +201,22 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { return nil, err } - // Wrap HTTP client with DPoP transport for resource requests + // Wrap HTTP client with DPoP transport for resource requests. + // cfg.dpopJWK is populated by resolveDPoPKey (called inside buildIDPTokenSource). httpClient := cfg.httpClient - if accessTokenSource != nil && cfg.dpopKey != nil { - dpopKey, err := getDPoPJWK(cfg.dpopKey) - if err != nil { - return nil, fmt.Errorf("failed to create DPoP JWK: %w", err) + if accessTokenSource != nil { + var dpopKey jwk.Key + if cfg.dpopJWK != nil { + dpopKey = cfg.dpopJWK + } else if cfg.dpopKey != nil { + dpopKey, err = getDPoPJWK(cfg.dpopKey) + if err != nil { + return nil, fmt.Errorf("failed to create DPoP JWK: %w", err) + } + } + if dpopKey != nil { + httpClient = auth.NewDPoPHTTPClient(cfg.httpClient, dpopKey, accessTokenSource, cfg.tokenEndpoint) } - httpClient = auth.NewDPoPHTTPClient(cfg.httpClient, dpopKey, accessTokenSource, cfg.tokenEndpoint) } if accessTokenSource != nil { @@ -292,6 +300,19 @@ func buildIDPTokenSource(c *config) (auth.AccessTokenSource, error) { return nil, errors.New("cannot do both token exchange and certificate exchange") } + // Use a user-supplied custom DPoP key (JWK) when present; otherwise fall back to the + // auto-generated RSA key pair (existing behaviour). resolveDPoPKey caches the result + // in c.dpopJWK, so the transport setup in sdk.New() reuses the same key. + customKey, err := resolveDPoPKey(c) + if err != nil { + return nil, fmt.Errorf("failed to resolve DPoP key: %w", err) + } + + if customKey != nil { + return buildIDPTokenSourceFromJWK(c, customKey) + } + + // RSA auto-generation path (no custom DPoP key configured). if c.dpopKey == nil { rsaKeyPair, err := ocrypto.NewRSAKeyPair(dpopKeySize) if err != nil { @@ -301,7 +322,6 @@ func buildIDPTokenSource(c *config) (auth.AccessTokenSource, error) { } var ts auth.AccessTokenSource - var err error switch { case c.oauthAccessTokenSource != nil: @@ -329,6 +349,19 @@ func buildIDPTokenSource(c *config) (auth.AccessTokenSource, error) { return ts, err } +func buildIDPTokenSourceFromJWK(c *config, key jwk.Key) (auth.AccessTokenSource, error) { + switch { + case c.oauthAccessTokenSource != nil: + return newOAuthAccessTokenSourceFromJWK(c.oauthAccessTokenSource, c.scopes, key), nil + case c.certExchange != nil: + return newCertExchangeTokenSourceFromJWK(c.logger, *c.certExchange, *c.clientCredentials, c.tokenEndpoint, key) + case c.tokenExchange != nil: + return newIDPTokenExchangeTokenSourceFromJWK(c.logger, *c.tokenExchange, *c.clientCredentials, c.tokenEndpoint, c.scopes, key) + default: + return newIDPAccessTokenSourceFromJWK(*c.clientCredentials, c.tokenEndpoint, c.scopes, key) + } +} + func (s SDK) Close() error { return nil } diff --git a/sdk/version.go b/sdk/version.go index f22769fd2d..2e5b6477ba 100644 --- a/sdk/version.go +++ b/sdk/version.go @@ -14,7 +14,7 @@ const ( // Used by xtest integration harness for feature detection. func SupportedFeatures() []string { return []string{ - "dpop", // RFC 9449 DPoP (Demonstrating Proof-of-Possession) + "dpop", // RFC 9449 DPoP (Demonstrating Proof-of-Possession) "connectrpc", // Connect RPC protocol support } } From 37ed377321239db9fbaf28dd4befa7fda8bde559 Mon Sep 17 00:00:00 2001 From: Dave Mihalcik Date: Wed, 10 Jun 2026 08:33:20 -0400 Subject: [PATCH 3/3] fix(sdk): address code review issues in DPoP transport (DSPX-3397) Fixes critical and high-priority issues identified in PR review: - Fix request body consumed on retry: reset body using GetBody() before retrying - Fix data races: use local base variable instead of modifying t.Base - Fix nonce cache initialization: unconditional lock instead of double-checked lock - Fix missing HTTP client for token source: pass client with base transport to preserve custom configs - Optimize token endpoint URL parsing: cache parsed URL to avoid parsing on every request - Normalize origin casing: lowercase origin in cache to ensure consistent hits on uppercase URLs Co-Authored-By: Claude Sonnet 4.6 --- sdk/auth/dpop_transport.go | 78 +++++++++++++++++++++++++++----------- 1 file changed, 55 insertions(+), 23 deletions(-) diff --git a/sdk/auth/dpop_transport.go b/sdk/auth/dpop_transport.go index 0d9dd1a19a..766232375c 100644 --- a/sdk/auth/dpop_transport.go +++ b/sdk/auth/dpop_transport.go @@ -36,24 +36,24 @@ type DPoPTransport struct { // and do not include the ath claim. TokenEndpoint string - nonceMu sync.RWMutex - // nonceCache stores server-issued nonces by origin (scheme://host:port) - nonceCache map[string]string + nonceMu sync.RWMutex + nonceCache map[string]string + cachedTokenURL *url.URL + cachedTokenURLStr string } // RoundTrip implements http.RoundTripper, adding DPoP proofs to requests. func (t *DPoPTransport) RoundTrip(req *http.Request) (*http.Response, error) { - if t.Base == nil { - t.Base = http.DefaultTransport + base := t.Base + if base == nil { + base = http.DefaultTransport } + t.nonceMu.Lock() if t.nonceCache == nil { - t.nonceMu.Lock() - if t.nonceCache == nil { - t.nonceCache = make(map[string]string) - } - t.nonceMu.Unlock() + t.nonceCache = make(map[string]string) } + t.nonceMu.Unlock() // Clone request to avoid modifying the original req2 := cloneRequest(req) @@ -66,12 +66,12 @@ func (t *DPoPTransport) RoundTrip(req *http.Request) (*http.Response, error) { nonce := t.getCachedNonce(origin) // Generate and add DPoP proof - if err := t.addDPoPProof(req2, nonce, isTokenRequest); err != nil { + if err := t.addDPoPProof(req2, base, nonce, isTokenRequest); err != nil { return nil, fmt.Errorf("failed to add DPoP proof: %w", err) } // Make the request - resp, err := t.Base.RoundTrip(req2) + resp, err := base.RoundTrip(req2) if err != nil { return resp, err } @@ -94,13 +94,22 @@ func (t *DPoPTransport) RoundTrip(req *http.Request) (*http.Response, error) { // Clone the original request again for retry req3 := cloneRequest(req) + // Reset body using GetBody if available + if req.GetBody != nil { + body, err := req.GetBody() + if err != nil { + return nil, fmt.Errorf("failed to reset request body for retry: %w", err) + } + req3.Body = body + } + // Regenerate proof with nonce - if err := t.addDPoPProof(req3, newNonce, isTokenRequest); err != nil { + if err := t.addDPoPProof(req3, base, newNonce, isTokenRequest); err != nil { return nil, fmt.Errorf("failed to add DPoP proof with nonce: %w", err) } // Retry the request - return t.Base.RoundTrip(req3) + return base.RoundTrip(req3) } } @@ -115,7 +124,7 @@ func (t *DPoPTransport) RoundTrip(req *http.Request) (*http.Response, error) { } // addDPoPProof generates and adds DPoP proof to the request headers. -func (t *DPoPTransport) addDPoPProof(req *http.Request, nonce string, isTokenRequest bool) error { +func (t *DPoPTransport) addDPoPProof(req *http.Request, base http.RoundTripper, nonce string, isTokenRequest bool) error { // Normalize the htu (RFC 9449 HTTP URI Normalization) htu := normalizeURI(req.URL) @@ -134,7 +143,8 @@ func (t *DPoPTransport) addDPoPProof(req *http.Request, nonce string, isTokenReq // For resource requests (not token endpoint), add ath claim var accessToken string if !isTokenRequest && t.TokenSource != nil { - at, err := t.TokenSource.AccessToken(req.Context(), nil) + client := &http.Client{Transport: base} + at, err := t.TokenSource.AccessToken(req.Context(), client) if err != nil { return fmt.Errorf("failed to get access token: %w", err) } @@ -193,13 +203,35 @@ func (t *DPoPTransport) isTokenEndpointRequest(u *url.URL) bool { if t.TokenEndpoint == "" { return false } - tokenURL, err := url.Parse(t.TokenEndpoint) - if err != nil { + + t.nonceMu.RLock() + cachedURL := t.cachedTokenURL + cachedStr := t.cachedTokenURLStr + t.nonceMu.RUnlock() + + if cachedStr != t.TokenEndpoint { + t.nonceMu.Lock() + if t.cachedTokenURLStr != t.TokenEndpoint { + parsed, err := url.Parse(t.TokenEndpoint) + if err == nil { + t.cachedTokenURL = parsed + t.cachedTokenURLStr = t.TokenEndpoint + } else { + t.cachedTokenURL = nil + t.cachedTokenURLStr = "" + } + } + cachedURL = t.cachedTokenURL + t.nonceMu.Unlock() + } + + if cachedURL == nil { return false } - return u.Scheme == tokenURL.Scheme && - u.Host == tokenURL.Host && - u.Path == tokenURL.Path + + return u.Scheme == cachedURL.Scheme && + u.Host == cachedURL.Host && + u.Path == cachedURL.Path } // normalizeURI normalizes the URI per RFC 9449 HTTP URI Normalization: @@ -219,9 +251,9 @@ func normalizeURI(u *url.URL) string { return fmt.Sprintf("%s://%s%s", scheme, host, u.Path) } -// getOrigin returns the origin (scheme://host:port) from a URL. +// getOrigin returns the origin (scheme://host:port) from a URL, normalized to lowercase. func getOrigin(u *url.URL) string { - return fmt.Sprintf("%s://%s", u.Scheme, u.Host) + return strings.ToLower(fmt.Sprintf("%s://%s", u.Scheme, u.Host)) } // getCachedNonce retrieves the cached nonce for an origin.