Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 69 additions & 22 deletions examples/cmd/benchmark_experimental.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
package cmd

import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"fmt"
"io"
"net/http"
"os"
"sync"
"time"

Expand All @@ -13,7 +18,6 @@ import (
kasp "github.com/opentdf/platform/protocol/go/kas"
"github.com/opentdf/platform/protocol/go/kas/kasconnect"
"github.com/opentdf/platform/protocol/go/policy"

"github.com/opentdf/platform/sdk/experimental/tdf"
"github.com/opentdf/platform/sdk/httputil"
"github.com/spf13/cobra"
Expand All @@ -35,7 +39,7 @@ func init() {
//nolint: mnd // no magic number, this is just default value for payload size
benchmarkCmd.Flags().IntVar(&payloadSize, "payload-size", 1024*1024, "Payload size in bytes") // Default 1MB
//nolint: mnd // same as above
benchmarkCmd.Flags().IntVar(&segmentChunk, "segment-chunks", 16*1024, "segment chunks ize") // Default 16 segments
benchmarkCmd.Flags().IntVar(&segmentChunk, "segment-chunks", 16*1024, "segment chunk size") // Default 16KB
ExamplesCmd.AddCommand(benchmarkCmd)
}

Expand All @@ -46,16 +50,21 @@ func runExperimentalWriterBenchmark(_ *cobra.Command, _ []string) error {
return fmt.Errorf("failed to generate random payload: %w", err)
}

http := httputil.SafeHTTPClient()
var httpClient *http.Client
if insecureSkipVerify {
httpClient = httputil.SafeHTTPClientWithTLSConfig(&tls.Config{InsecureSkipVerify: true}) //nolint:gosec // user-requested flag
} else {
httpClient = httputil.SafeHTTPClient()
}
fmt.Println("endpoint:", platformEndpoint)
serviceClient := kasconnect.NewAccessServiceClient(http, platformEndpoint)
serviceClient := kasconnect.NewAccessServiceClient(httpClient, platformEndpoint)
resp, err := serviceClient.PublicKey(context.Background(), connect.NewRequest(&kasp.PublicKeyRequest{Algorithm: string(ocrypto.RSA2048Key)}))
if err != nil {
return fmt.Errorf("failed to get public key from KAS: %w", err)
}
var attrs []*policy.Value

simpleyKey := &policy.SimpleKasKey{
simpleKey := &policy.SimpleKasKey{
KasUri: platformEndpoint,
KasId: "id",
PublicKey: &policy.SimpleKasPublicKey{
Expand All @@ -65,29 +74,31 @@ func runExperimentalWriterBenchmark(_ *cobra.Command, _ []string) error {
},
}

attrs = append(attrs, &policy.Value{Fqn: testAttr, KasKeys: []*policy.SimpleKasKey{simpleyKey}, Attribute: &policy.Attribute{Namespace: &policy.Namespace{Name: "example.com"}, Fqn: testAttr}})
writer, err := tdf.NewWriter(context.Background(), tdf.WithDefaultKASForWriter(simpleyKey), tdf.WithInitialAttributes(attrs), tdf.WithSegmentIntegrityAlgorithm(tdf.HS256))
attrs = append(attrs, &policy.Value{Fqn: testAttr, KasKeys: []*policy.SimpleKasKey{simpleKey}, Attribute: &policy.Attribute{Namespace: &policy.Namespace{Name: "example.com"}, Fqn: testAttr}})
writer, err := tdf.NewWriter(context.Background(), tdf.WithDefaultKASForWriter(simpleKey), tdf.WithInitialAttributes(attrs), tdf.WithSegmentIntegrityAlgorithm(tdf.HS256))
if err != nil {
return fmt.Errorf("failed to create writer: %w", err)
}
i := 0
segs := (len(payload) + segmentChunk - 1) / segmentChunk
segResults := make([]*tdf.SegmentResult, segs)
wg := sync.WaitGroup{}
segs := len(payload) / segmentChunk
wg.Add(segs)
start := time.Now()
for i < segs {
segment := i
go func() {
start := i * segmentChunk
end := min(start+segmentChunk, len(payload))
_, err = writer.WriteSegment(context.Background(), segment, payload[start:end])
if err != nil {
fmt.Println(err)
panic(err)
for i := 0; i < segs; i++ {
segStart := i * segmentChunk
segEnd := min(segStart+segmentChunk, len(payload))
// Copy the chunk: EncryptInPlace overwrites the input buffer and
// appends a 16-byte auth tag, which would corrupt adjacent segments.
chunk := make([]byte, segEnd-segStart)
copy(chunk, payload[segStart:segEnd])
go func(index int, data []byte) {
defer wg.Done()
sr, serr := writer.WriteSegment(context.Background(), index, data)
if serr != nil {
panic(serr)
}
wg.Done()
}()
i++
segResults[index] = sr
}(i, chunk)
}
wg.Wait()

Expand All @@ -98,12 +109,48 @@ func runExperimentalWriterBenchmark(_ *cobra.Command, _ []string) error {
}
totalTime := end.Sub(start)

// Assemble the complete TDF: segment data (in order) + finalize data
var tdfBuf bytes.Buffer
for i, sr := range segResults {
if _, err := io.Copy(&tdfBuf, sr.TDFData); err != nil {
return fmt.Errorf("failed to read segment %d TDF data: %w", i, err)
}
}
tdfBuf.Write(result.Data)

outPath := "/tmp/benchmark-experimental.tdf"
if err := os.WriteFile(outPath, tdfBuf.Bytes(), 0o600); err != nil {
return fmt.Errorf("failed to write TDF: %w", err)
}

fmt.Printf("# Benchmark Experimental TDF Writer Results:\n")
fmt.Printf("| Metric | Value |\n")
fmt.Printf("|--------------------|--------------|\n")
fmt.Printf("| Payload Size (B) | %d |\n", payloadSize)
fmt.Printf("| Output Size (B) | %d |\n", len(result.Data))
fmt.Printf("| Output Size (B) | %d |\n", tdfBuf.Len())
fmt.Printf("| Total Time | %s |\n", totalTime)
fmt.Printf("| TDF saved to | %s |\n", outPath)

// Decrypt with production SDK to verify interoperability
s, err := newSDK()
if err != nil {
return fmt.Errorf("failed to create SDK: %w", err)
}
defer s.Close()
tdfReader, err := s.LoadTDF(bytes.NewReader(tdfBuf.Bytes()))
if err != nil {
return fmt.Errorf("failed to load TDF with production SDK: %w", err)
}
var decrypted bytes.Buffer
if _, err = io.Copy(&decrypted, tdfReader); err != nil {
return fmt.Errorf("failed to decrypt TDF with production SDK: %w", err)
}

if bytes.Equal(payload, decrypted.Bytes()) {
fmt.Println("| Decrypt Verify | PASS - roundtrip matches |")
} else {
fmt.Printf("| Decrypt Verify | FAIL - payload %d bytes, decrypted %d bytes |\n", len(payload), decrypted.Len())
}

return nil
}
16 changes: 16 additions & 0 deletions sdk/experimental/tdf/keysplit/xor_splitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,22 @@ func (x *XORSplitter) GenerateSplits(_ context.Context, attrs []*policy.Value, d
// 4. Collect all public keys from assignments
allKeys := collectAllPublicKeys(assignments)

// 5. Merge the default KAS public key if not already present.
// Attribute grants may reference the default KAS URL without including the public key
// (e.g., legacy grants with only a URI). The default KAS key fills this gap.
if x.config.defaultKAS != nil && x.config.defaultKAS.GetPublicKey() != nil {
kasURL := x.config.defaultKAS.GetKasUri()
if _, exists := allKeys[kasURL]; !exists {
pubKey := x.config.defaultKAS.GetPublicKey()
allKeys[kasURL] = KASPublicKey{
URL: kasURL,
KID: pubKey.GetKid(),
PEM: pubKey.GetPem(),
Algorithm: formatAlgorithm(pubKey.GetAlgorithm()),
}
}
}
Comment thread
pflynn-virtru marked this conversation as resolved.

slog.Debug("completed key split generation",
slog.Int("num_splits", len(splits)),
slog.Int("num_kas_keys", len(allKeys)))
Expand Down
70 changes: 70 additions & 0 deletions sdk/experimental/tdf/keysplit/xor_splitter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,3 +526,73 @@ func TestXORSplitter_ComplexScenarios(t *testing.T) {
assert.True(t, found, "Should find split with multiple KAS URLs")
})
}

// TestXORSplitter_DefaultKASMergedForURIOnlyGrant is a regression test
// ensuring that when an attribute grant references a KAS URL without
// embedding the public key (URI-only legacy grant), the default KAS's
// full public key info is merged into the result. Without the merge fix
// in GenerateSplits, collectAllPublicKeys returns an incomplete map and
// key wrapping fails.
func TestXORSplitter_DefaultKASMergedForURIOnlyGrant(t *testing.T) {
defaultKAS := &policy.SimpleKasKey{
KasUri: kasUs,
PublicKey: &policy.SimpleKasPublicKey{
Algorithm: policy.Algorithm_ALGORITHM_RSA_2048,
Kid: "default-kid",
Pem: mockRSAPublicKey1,
},
}
splitter := NewXORSplitter(WithDefaultKAS(defaultKAS))

dek := make([]byte, 32)
_, err := rand.Read(dek)
require.NoError(t, err)

// Create an attribute whose grant references kasUs by URI only (no KasKeys).
attr := createMockValue("https://test.com/attr/level/value/secret", "", "", policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF)
attr.Grants = []*policy.KeyAccessServer{
{Uri: kasUs}, // URI-only, no embedded public key
}

result, err := splitter.GenerateSplits(t.Context(), []*policy.Value{attr}, dek)
require.NoError(t, err)
require.NotNil(t, result)

// The default KAS public key must be merged into the result.
require.Contains(t, result.KASPublicKeys, kasUs, "default KAS key should be merged for URI-only grant")
pubKey := result.KASPublicKeys[kasUs]
assert.Equal(t, "default-kid", pubKey.KID)
assert.Equal(t, mockRSAPublicKey1, pubKey.PEM)
assert.Equal(t, "rsa:2048", pubKey.Algorithm)
}

// TestXORSplitter_DefaultKASDoesNotOverwriteExistingKey verifies that when
// an attribute grant already embeds a full public key for the same KAS URL
// as the default, the grant's key is preserved and not overwritten.
func TestXORSplitter_DefaultKASDoesNotOverwriteExistingKey(t *testing.T) {
defaultKAS := &policy.SimpleKasKey{
KasUri: kasUs,
PublicKey: &policy.SimpleKasPublicKey{
Algorithm: policy.Algorithm_ALGORITHM_RSA_2048,
Kid: "default-kid",
Pem: mockRSAPublicKey1,
},
}
splitter := NewXORSplitter(WithDefaultKAS(defaultKAS))

dek := make([]byte, 32)
_, err := rand.Read(dek)
require.NoError(t, err)

// Create an attribute with a fully-embedded grant for the same KAS URL
// but with a different KID.
attr := createMockValue("https://test.com/attr/level/value/secret", kasUs, "grant-kid", policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF)

result, err := splitter.GenerateSplits(t.Context(), []*policy.Value{attr}, dek)
require.NoError(t, err)
require.NotNil(t, result)

require.Contains(t, result.KASPublicKeys, kasUs)
pubKey := result.KASPublicKeys[kasUs]
assert.Equal(t, "grant-kid", pubKey.KID, "grant's key should not be overwritten by default KAS")
}
6 changes: 5 additions & 1 deletion sdk/experimental/tdf/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,11 @@ func (w *Writer) WriteSegment(ctx context.Context, index int, data []byte) (*Seg
if err != nil {
return nil, err
}
segmentSig, err := calculateSignature(segmentCipher, w.dek, w.segmentIntegrityAlgorithm, false) // Don't ever hex encode new tdf's
// Hash must cover nonce + cipher to match the standard SDK reader's verification.
// The standard SDK's Encrypt() returns nonce prepended to cipher and hashes that;
// EncryptInPlace() returns them separately, so we must concatenate for hashing.
segmentData := append(nonce, segmentCipher...) //nolint:gocritic // nonce cap == len, so always allocates
segmentSig, err := calculateSignature(segmentData, w.dek, w.segmentIntegrityAlgorithm, false)
if err != nil {
return nil, err
}
Expand Down
Loading
Loading