Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions sdk/tdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -587,8 +587,15 @@ func (s SDK) prepareManifest(ctx context.Context, t *TDFObject, tdfConfig TDFCon
symKeys = append(symKeys, symKey)

// policy binding
policyBindingHash := hex.EncodeToString(ocrypto.CalculateSHA256Hmac(symKey, base64PolicyObject))
pbstring := string(ocrypto.Base64Encode([]byte(policyBindingHash)))
// Spec (>= 4.3.0) requires Base64(HMAC); pre-4.3.0 TDFs used Base64(hex(HMAC)).
// useHex tracks the same threshold as segment/root signatures via WithTargetMode.
hmacBytes := ocrypto.CalculateSHA256Hmac(symKey, base64PolicyObject)
var pbstring string
if tdfConfig.useHex {
pbstring = string(ocrypto.Base64Encode([]byte(hex.EncodeToString(hmacBytes))))
} else {
pbstring = string(ocrypto.Base64Encode(hmacBytes))
}
policyBinding := PolicyBinding{
Alg: "HS256",
Hash: pbstring,
Expand Down Expand Up @@ -1247,7 +1254,7 @@ func createRewrapRequest(_ context.Context, r *Reader) (map[string]*kas.Unsigned
invalidPolicy = !ok
alg, ok = policyBinding["alg"].(string)
invalidPolicy = invalidPolicy || !ok
case (PolicyBinding):
case PolicyBinding:
hash = policyBinding.Hash
alg = policyBinding.Alg
default:
Expand Down
51 changes: 36 additions & 15 deletions sdk/tdf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ func (s *TDFSuite) Test_SimpleTDF() {
WithSessionKeyType(ocrypto.EC256Key),
WithKasAllowlist([]string{s.kasTestURLLookup["https://d.kas/"]}),
},
expectedSize: 1970,
},
{
name: "target-mode-0",
Expand Down Expand Up @@ -659,6 +660,24 @@ func (s *TDFSuite) Test_SimpleTDF() {
s.Require().Error(err2)
}

// check that the policy binding matches the same hex/raw scheme as the
// other signatures. Spec >= 4.3.0 emits Base64(HMAC); pre-4.3.0 emits
// Base64(hex(HMAC)).
s.Require().NotEmpty(r.Manifest().KeyAccessObjs)
pb, ok := r.Manifest().KeyAccessObjs[0].PolicyBinding.(map[string]any)
s.Require().True(ok, "expected PolicyBinding to deserialize as map")
pbHash, ok := pb["hash"].(string)
s.Require().True(ok, "expected PolicyBinding.hash to be a string")
decodedPB, err := ocrypto.Base64Decode([]byte(pbHash))
s.Require().NoError(err)
if config.useHex {
s.Len(decodedPB, hex.EncodedLen(sha256.Size), "legacy policy binding should be hex-encoded HMAC")
_, err = hex.DecodeString(string(decodedPB))
s.Require().NoError(err)
} else {
s.Len(decodedPB, sha256.Size, "spec-compliant policy binding should be raw HMAC bytes")
}

// check version is present if usehex is false
if config.useHex {
s.Empty(r.Manifest().TDFVersion)
Expand Down Expand Up @@ -927,7 +946,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() {
},
verifiers: nil,
disableAssertionVerification: false,
expectedSize: 2689,
expectedSize: 2656,
},
{
assertions: []AssertionConfig{
Expand Down Expand Up @@ -990,7 +1009,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() {
DefaultKey: defaultKey,
},
disableAssertionVerification: false,
expectedSize: 2689,
expectedSize: 2656,
},
{
assertions: []AssertionConfig{
Expand Down Expand Up @@ -1039,7 +1058,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() {
},
},
disableAssertionVerification: false,
expectedSize: 2988,
expectedSize: 2955,
},
{
assertions: []AssertionConfig{
Expand Down Expand Up @@ -1079,7 +1098,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() {
},
},
disableAssertionVerification: false,
expectedSize: 2689,
expectedSize: 2656,
},
{
assertions: []AssertionConfig{
Expand All @@ -1096,7 +1115,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() {
},
},
disableAssertionVerification: true,
expectedSize: 2180,
expectedSize: 2147,
},
} {
expectedTdfSize := test.expectedSize
Expand Down Expand Up @@ -1333,7 +1352,7 @@ func (s *TDFSuite) Test_TDFWithAssertionNegativeTests() {
SigningKey: defaultKey,
},
},
expectedSize: 2689,
expectedSize: 2656,
},
{
assertions: []AssertionConfig{
Expand Down Expand Up @@ -1381,7 +1400,7 @@ func (s *TDFSuite) Test_TDFWithAssertionNegativeTests() {
},
},
},
expectedSize: 2988,
expectedSize: 2955,
},
{
assertions: []AssertionConfig{
Expand Down Expand Up @@ -1415,7 +1434,7 @@ func (s *TDFSuite) Test_TDFWithAssertionNegativeTests() {
verifiers: &AssertionVerificationKeys{
DefaultKey: defaultKey,
},
expectedSize: 2689,
expectedSize: 2656,
},
} {
expectedTdfSize := test.expectedSize
Expand Down Expand Up @@ -1955,7 +1974,7 @@ func (s *TDFSuite) Test_KeySplit_SameKas_SameAlgorithm() {
{
n: "multiple-keys-same-kas-same-algorithm",
fileSize: 5,
tdfFileSize: 2581,
tdfFileSize: 2445,
checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2",
},
} {
Expand Down Expand Up @@ -2034,7 +2053,7 @@ func (s *TDFSuite) Test_KeySplits() {
{
n: "shared",
fileSize: 5,
tdfFileSize: 2759,
tdfFileSize: 2635,
checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2",
splitPlan: []keySplitStep{
{KAS: s.kasTestURLLookup["https://a.kas/"], SplitID: "a"},
Expand All @@ -2045,7 +2064,7 @@ func (s *TDFSuite) Test_KeySplits() {
{
n: "split",
fileSize: 5,
tdfFileSize: 2759,
tdfFileSize: 2635,
checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2",
splitPlan: []keySplitStep{
{KAS: s.kasTestURLLookup["https://a.kas/"], SplitID: "a"},
Expand All @@ -2056,7 +2075,7 @@ func (s *TDFSuite) Test_KeySplits() {
{
n: "mixture",
fileSize: 5,
tdfFileSize: 3351,
tdfFileSize: 3191,
checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2",
splitPlan: []keySplitStep{
{KAS: s.kasTestURLLookup["https://a.kas/"], SplitID: "a"},
Expand Down Expand Up @@ -2738,7 +2757,7 @@ func (s *TDFSuite) testDecryptWithReader(sdk *SDK, tdfFile, decryptedTdfFileName
resultBuf := bytes.Repeat([]byte{char}, int(bufSize))

// read last 5 bytes
n, err := r.ReadAt(buf, test.fileSize-(bufSize))
n, err := r.ReadAt(buf, test.fileSize-bufSize)
if err != nil {
s.Require().ErrorIs(err, io.EOF)
}
Expand Down Expand Up @@ -2858,7 +2877,8 @@ func (s *TDFSuite) startBackend() {

ats := getTokenSource(s.T())

sdk, err := New(sdkPlatformURL,
sdk, err := New(
sdkPlatformURL,
WithClientCredentials("test", "test", nil),
withCustomAccessTokenSource(&ats),
WithTokenEndpoint("http://localhost:65432/auth/token"),
Expand Down Expand Up @@ -2894,7 +2914,8 @@ func (f *FakeAttributes) GetAttributeValuesByFqns(_ context.Context, in *connect
for _, fqn := range in.Msg.GetFqns() {
av, err := NewAttributeValueFQN(fqn)
if err != nil {
slog.Error("invalid fqn",
slog.Error(
"invalid fqn",
slog.String("fqn", fqn),
slog.Any("error", err),
)
Expand Down
41 changes: 23 additions & 18 deletions service/kas/access/rewrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,12 +431,7 @@ func verifyPolicyBinding(ctx context.Context, policy []byte, kao *kaspb.Unsigned
}

policyBinding := kao.GetKeyAccessObject().GetPolicyBinding().GetHash()
expectedHMAC := make([]byte, base64.StdEncoding.DecodedLen(len(policyBinding)))
n, err := base64.StdEncoding.Decode(expectedHMAC, []byte(policyBinding))
if err == nil {
n, err = hex.Decode(expectedHMAC, expectedHMAC[:n])
}
expectedHMAC = expectedHMAC[:n]
expectedHMAC, err := decodePolicyBinding(policyBinding)
if err != nil {
logger.WarnContext(ctx, "invalid policy binding", slog.Any("error", err))
return err400("bad request")
Expand All @@ -450,6 +445,27 @@ func verifyPolicyBinding(ctx context.Context, policy []byte, kao *kaspb.Unsigned
return nil
}

// decodePolicyBinding decodes the policy binding hash from its on-wire form to
// the raw HMAC bytes used for comparison. It accepts both the spec-compliant
// Base64(HMAC) encoding (>= TDF 4.3.0) and the legacy Base64(hex(HMAC))
// encoding emitted by pre-4.3.0 writers. The two are unambiguous by length
// after base64 decode (32 vs 64 bytes), so no version signal is needed.
func decodePolicyBinding(b64Hash string) ([]byte, error) {
decoded := make([]byte, base64.StdEncoding.DecodedLen(len(b64Hash)))
n, err := base64.StdEncoding.Decode(decoded, []byte(b64Hash))
if err != nil {
return nil, err
}
decoded = decoded[:n]
if n == hex.EncodedLen(sha256.Size) {
dehexed := make([]byte, hex.DecodedLen(n))
if _, decErr := hex.Decode(dehexed, decoded); decErr == nil {
return dehexed, nil
}
}
return decoded, nil
}

func extractPolicyBinding(policyBinding interface{}) (string, error) {
switch v := policyBinding.(type) {
case string:
Expand Down Expand Up @@ -805,23 +821,12 @@ func (p *Provider) verifyRewrapRequests(ctx context.Context, req *kaspb.Unsigned
}

// Store policy binding in context for verification
policyBindingB64Encoded := kao.GetKeyAccessObject().GetPolicyBinding().GetHash()
policyBinding := make([]byte, base64.StdEncoding.DecodedLen(len(policyBindingB64Encoded)))
n, err := base64.StdEncoding.Decode(policyBinding, []byte(policyBindingB64Encoded))
policyBinding, err := decodePolicyBinding(kao.GetKeyAccessObject().GetPolicyBinding().GetHash())
if err != nil {
p.Logger.WarnContext(ctx, "invalid policy binding encoding", slog.Any("error", err))
failedKAORewrap(results, kao, err400("bad request")) // Generic: malformed binding may indicate tamper
continue
}
if n == 64 { //nolint:mnd // 32 bytes of hex encoded data = 256 bit sha-2
// Sometimes the policy binding is a b64 encoded hex encoded string
// Decode it again if so.
dehexed := make([]byte, hex.DecodedLen(n))
_, err = hex.Decode(dehexed, policyBinding[:n])
if err == nil {
policyBinding = dehexed
}
}

// Verify policy binding using the UnwrappedKeyData interface
if err := dek.VerifyBinding(ctx, []byte(req.GetPolicy().GetBody()), policyBinding); err != nil {
Expand Down
39 changes: 39 additions & 0 deletions service/kas/access/rewrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1431,3 +1431,42 @@ func TestVerifyRewrapRequests(t *testing.T) {
})
}
}

func TestDecodePolicyBinding(t *testing.T) {
rawHMAC := bytes.Repeat([]byte{0xab}, 32)

tests := []struct {
name string
input string
want []byte
wantErr bool
}{
{
name: "spec-compliant raw HMAC (Base64(HMAC))",
input: base64.StdEncoding.EncodeToString(rawHMAC),
want: rawHMAC,
},
{
name: "legacy hex-encoded HMAC (Base64(hex(HMAC)))",
input: base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(rawHMAC))),
want: rawHMAC,
},
{
name: "invalid base64",
input: "not!valid!base64",
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := decodePolicyBinding(tt.input)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}
Loading