From 6b2aa5bb39d6501965400e181944614886daa505 Mon Sep 17 00:00:00 2001 From: Tyler Biscoe Date: Wed, 10 Jun 2026 12:10:38 -0400 Subject: [PATCH] policy binding hex fix --- sdk/tdf.go | 13 ++++++-- sdk/tdf_test.go | 51 ++++++++++++++++++++++--------- service/kas/access/rewrap.go | 41 ++++++++++++++----------- service/kas/access/rewrap_test.go | 39 +++++++++++++++++++++++ 4 files changed, 108 insertions(+), 36 deletions(-) diff --git a/sdk/tdf.go b/sdk/tdf.go index 1d6ef1182c..299c08264f 100644 --- a/sdk/tdf.go +++ b/sdk/tdf.go @@ -587,8 +587,15 @@ func (s SDK) prepareManifest(ctx context.Context, t *TDFObject, tdfConfig TDFCon symKeys = append(symKeys, symKey) // policy binding - policyBindingHash := hex.EncodeToString(ocrypto.CalculateSHA256Hmac(symKey, base64PolicyObject)) - pbstring := string(ocrypto.Base64Encode([]byte(policyBindingHash))) + // Spec (>= 4.3.0) requires Base64(HMAC); pre-4.3.0 TDFs used Base64(hex(HMAC)). + // useHex tracks the same threshold as segment/root signatures via WithTargetMode. + hmacBytes := ocrypto.CalculateSHA256Hmac(symKey, base64PolicyObject) + var pbstring string + if tdfConfig.useHex { + pbstring = string(ocrypto.Base64Encode([]byte(hex.EncodeToString(hmacBytes)))) + } else { + pbstring = string(ocrypto.Base64Encode(hmacBytes)) + } policyBinding := PolicyBinding{ Alg: "HS256", Hash: pbstring, @@ -1247,7 +1254,7 @@ func createRewrapRequest(_ context.Context, r *Reader) (map[string]*kas.Unsigned invalidPolicy = !ok alg, ok = policyBinding["alg"].(string) invalidPolicy = invalidPolicy || !ok - case (PolicyBinding): + case PolicyBinding: hash = policyBinding.Hash alg = policyBinding.Alg default: diff --git a/sdk/tdf_test.go b/sdk/tdf_test.go index 650fce17b1..40346b5762 100644 --- a/sdk/tdf_test.go +++ b/sdk/tdf_test.go @@ -524,6 +524,7 @@ func (s *TDFSuite) Test_SimpleTDF() { WithSessionKeyType(ocrypto.EC256Key), WithKasAllowlist([]string{s.kasTestURLLookup["https://d.kas/"]}), }, + expectedSize: 1970, }, { name: "target-mode-0", @@ -659,6 +660,24 @@ func (s *TDFSuite) Test_SimpleTDF() { s.Require().Error(err2) } + // check that the policy binding matches the same hex/raw scheme as the + // other signatures. Spec >= 4.3.0 emits Base64(HMAC); pre-4.3.0 emits + // Base64(hex(HMAC)). + s.Require().NotEmpty(r.Manifest().KeyAccessObjs) + pb, ok := r.Manifest().KeyAccessObjs[0].PolicyBinding.(map[string]any) + s.Require().True(ok, "expected PolicyBinding to deserialize as map") + pbHash, ok := pb["hash"].(string) + s.Require().True(ok, "expected PolicyBinding.hash to be a string") + decodedPB, err := ocrypto.Base64Decode([]byte(pbHash)) + s.Require().NoError(err) + if config.useHex { + s.Len(decodedPB, hex.EncodedLen(sha256.Size), "legacy policy binding should be hex-encoded HMAC") + _, err = hex.DecodeString(string(decodedPB)) + s.Require().NoError(err) + } else { + s.Len(decodedPB, sha256.Size, "spec-compliant policy binding should be raw HMAC bytes") + } + // check version is present if usehex is false if config.useHex { s.Empty(r.Manifest().TDFVersion) @@ -927,7 +946,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() { }, verifiers: nil, disableAssertionVerification: false, - expectedSize: 2689, + expectedSize: 2656, }, { assertions: []AssertionConfig{ @@ -990,7 +1009,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() { DefaultKey: defaultKey, }, disableAssertionVerification: false, - expectedSize: 2689, + expectedSize: 2656, }, { assertions: []AssertionConfig{ @@ -1039,7 +1058,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() { }, }, disableAssertionVerification: false, - expectedSize: 2988, + expectedSize: 2955, }, { assertions: []AssertionConfig{ @@ -1079,7 +1098,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() { }, }, disableAssertionVerification: false, - expectedSize: 2689, + expectedSize: 2656, }, { assertions: []AssertionConfig{ @@ -1096,7 +1115,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() { }, }, disableAssertionVerification: true, - expectedSize: 2180, + expectedSize: 2147, }, } { expectedTdfSize := test.expectedSize @@ -1333,7 +1352,7 @@ func (s *TDFSuite) Test_TDFWithAssertionNegativeTests() { SigningKey: defaultKey, }, }, - expectedSize: 2689, + expectedSize: 2656, }, { assertions: []AssertionConfig{ @@ -1381,7 +1400,7 @@ func (s *TDFSuite) Test_TDFWithAssertionNegativeTests() { }, }, }, - expectedSize: 2988, + expectedSize: 2955, }, { assertions: []AssertionConfig{ @@ -1415,7 +1434,7 @@ func (s *TDFSuite) Test_TDFWithAssertionNegativeTests() { verifiers: &AssertionVerificationKeys{ DefaultKey: defaultKey, }, - expectedSize: 2689, + expectedSize: 2656, }, } { expectedTdfSize := test.expectedSize @@ -1955,7 +1974,7 @@ func (s *TDFSuite) Test_KeySplit_SameKas_SameAlgorithm() { { n: "multiple-keys-same-kas-same-algorithm", fileSize: 5, - tdfFileSize: 2581, + tdfFileSize: 2445, checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", }, } { @@ -2034,7 +2053,7 @@ func (s *TDFSuite) Test_KeySplits() { { n: "shared", fileSize: 5, - tdfFileSize: 2759, + tdfFileSize: 2635, checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", splitPlan: []keySplitStep{ {KAS: s.kasTestURLLookup["https://a.kas/"], SplitID: "a"}, @@ -2045,7 +2064,7 @@ func (s *TDFSuite) Test_KeySplits() { { n: "split", fileSize: 5, - tdfFileSize: 2759, + tdfFileSize: 2635, checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", splitPlan: []keySplitStep{ {KAS: s.kasTestURLLookup["https://a.kas/"], SplitID: "a"}, @@ -2056,7 +2075,7 @@ func (s *TDFSuite) Test_KeySplits() { { n: "mixture", fileSize: 5, - tdfFileSize: 3351, + tdfFileSize: 3191, checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", splitPlan: []keySplitStep{ {KAS: s.kasTestURLLookup["https://a.kas/"], SplitID: "a"}, @@ -2738,7 +2757,7 @@ func (s *TDFSuite) testDecryptWithReader(sdk *SDK, tdfFile, decryptedTdfFileName resultBuf := bytes.Repeat([]byte{char}, int(bufSize)) // read last 5 bytes - n, err := r.ReadAt(buf, test.fileSize-(bufSize)) + n, err := r.ReadAt(buf, test.fileSize-bufSize) if err != nil { s.Require().ErrorIs(err, io.EOF) } @@ -2858,7 +2877,8 @@ func (s *TDFSuite) startBackend() { ats := getTokenSource(s.T()) - sdk, err := New(sdkPlatformURL, + sdk, err := New( + sdkPlatformURL, WithClientCredentials("test", "test", nil), withCustomAccessTokenSource(&ats), WithTokenEndpoint("http://localhost:65432/auth/token"), @@ -2894,7 +2914,8 @@ func (f *FakeAttributes) GetAttributeValuesByFqns(_ context.Context, in *connect for _, fqn := range in.Msg.GetFqns() { av, err := NewAttributeValueFQN(fqn) if err != nil { - slog.Error("invalid fqn", + slog.Error( + "invalid fqn", slog.String("fqn", fqn), slog.Any("error", err), ) diff --git a/service/kas/access/rewrap.go b/service/kas/access/rewrap.go index aba7ffa84e..e088adaa24 100644 --- a/service/kas/access/rewrap.go +++ b/service/kas/access/rewrap.go @@ -431,12 +431,7 @@ func verifyPolicyBinding(ctx context.Context, policy []byte, kao *kaspb.Unsigned } policyBinding := kao.GetKeyAccessObject().GetPolicyBinding().GetHash() - expectedHMAC := make([]byte, base64.StdEncoding.DecodedLen(len(policyBinding))) - n, err := base64.StdEncoding.Decode(expectedHMAC, []byte(policyBinding)) - if err == nil { - n, err = hex.Decode(expectedHMAC, expectedHMAC[:n]) - } - expectedHMAC = expectedHMAC[:n] + expectedHMAC, err := decodePolicyBinding(policyBinding) if err != nil { logger.WarnContext(ctx, "invalid policy binding", slog.Any("error", err)) return err400("bad request") @@ -450,6 +445,27 @@ func verifyPolicyBinding(ctx context.Context, policy []byte, kao *kaspb.Unsigned return nil } +// decodePolicyBinding decodes the policy binding hash from its on-wire form to +// the raw HMAC bytes used for comparison. It accepts both the spec-compliant +// Base64(HMAC) encoding (>= TDF 4.3.0) and the legacy Base64(hex(HMAC)) +// encoding emitted by pre-4.3.0 writers. The two are unambiguous by length +// after base64 decode (32 vs 64 bytes), so no version signal is needed. +func decodePolicyBinding(b64Hash string) ([]byte, error) { + decoded := make([]byte, base64.StdEncoding.DecodedLen(len(b64Hash))) + n, err := base64.StdEncoding.Decode(decoded, []byte(b64Hash)) + if err != nil { + return nil, err + } + decoded = decoded[:n] + if n == hex.EncodedLen(sha256.Size) { + dehexed := make([]byte, hex.DecodedLen(n)) + if _, decErr := hex.Decode(dehexed, decoded); decErr == nil { + return dehexed, nil + } + } + return decoded, nil +} + func extractPolicyBinding(policyBinding interface{}) (string, error) { switch v := policyBinding.(type) { case string: @@ -805,23 +821,12 @@ func (p *Provider) verifyRewrapRequests(ctx context.Context, req *kaspb.Unsigned } // Store policy binding in context for verification - policyBindingB64Encoded := kao.GetKeyAccessObject().GetPolicyBinding().GetHash() - policyBinding := make([]byte, base64.StdEncoding.DecodedLen(len(policyBindingB64Encoded))) - n, err := base64.StdEncoding.Decode(policyBinding, []byte(policyBindingB64Encoded)) + policyBinding, err := decodePolicyBinding(kao.GetKeyAccessObject().GetPolicyBinding().GetHash()) if err != nil { p.Logger.WarnContext(ctx, "invalid policy binding encoding", slog.Any("error", err)) failedKAORewrap(results, kao, err400("bad request")) // Generic: malformed binding may indicate tamper continue } - if n == 64 { //nolint:mnd // 32 bytes of hex encoded data = 256 bit sha-2 - // Sometimes the policy binding is a b64 encoded hex encoded string - // Decode it again if so. - dehexed := make([]byte, hex.DecodedLen(n)) - _, err = hex.Decode(dehexed, policyBinding[:n]) - if err == nil { - policyBinding = dehexed - } - } // Verify policy binding using the UnwrappedKeyData interface if err := dek.VerifyBinding(ctx, []byte(req.GetPolicy().GetBody()), policyBinding); err != nil { diff --git a/service/kas/access/rewrap_test.go b/service/kas/access/rewrap_test.go index bd62647483..a2f9d4676c 100644 --- a/service/kas/access/rewrap_test.go +++ b/service/kas/access/rewrap_test.go @@ -1431,3 +1431,42 @@ func TestVerifyRewrapRequests(t *testing.T) { }) } } + +func TestDecodePolicyBinding(t *testing.T) { + rawHMAC := bytes.Repeat([]byte{0xab}, 32) + + tests := []struct { + name string + input string + want []byte + wantErr bool + }{ + { + name: "spec-compliant raw HMAC (Base64(HMAC))", + input: base64.StdEncoding.EncodeToString(rawHMAC), + want: rawHMAC, + }, + { + name: "legacy hex-encoded HMAC (Base64(hex(HMAC)))", + input: base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(rawHMAC))), + want: rawHMAC, + }, + { + name: "invalid base64", + input: "not!valid!base64", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := decodePolicyBinding(tt.input) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +}