diff --git a/bolt12/decode.go b/bolt12/decode.go new file mode 100644 index 00000000000..4307455b808 --- /dev/null +++ b/bolt12/decode.go @@ -0,0 +1,29 @@ +package bolt12 + +import ( + "bytes" + "fmt" + + "github.com/lightningnetwork/lnd/tlv" +) + +// decodeStream runs a single typed-stream pass over data and returns the +// canonical TypeMap. Records may be passed in any order; NewStream requires +// them sorted, so SortRecords runs first. +func decodeStream(data []byte, records ...tlv.Record) (tlv.TypeMap, error) { + tlv.SortRecords(records) + + stream, err := tlv.NewStream(records...) + if err != nil { + return nil, fmt.Errorf("create stream: %w", err) + } + + typeMap, err := stream.DecodeWithParsedTypesP2P( + bytes.NewReader(data), + ) + if err != nil { + return nil, fmt.Errorf("decode stream: %w", err) + } + + return typeMap, nil +} diff --git a/bolt12/doc.go b/bolt12/doc.go new file mode 100644 index 00000000000..c58a1ced1d3 --- /dev/null +++ b/bolt12/doc.go @@ -0,0 +1,19 @@ +// Package bolt12 implements encoding, decoding, and validation for BOLT 12 +// Offers, Invoice Requests, and Invoices. It provides a pure codec library +// with no LND daemon dependencies. +// +// BOLT 12 messages use TLV streams encoded with a checksumless bech32 variant +// and signed with BIP-340 Schnorr signatures over a Merkle tree of TLV fields. +// +// Human-readable prefixes: +// - lno: Offer +// - lnr: Invoice Request +// - lni: Invoice +// +// # Codec Contract +// +// Encode validates before serialising and refuses to emit bytes that would fail +// the writer requirements, invalid bytes are unrepresentable on the wire. +// Low-level decoders stay permissive so diagnostic and fuzz harnesses can +// inspect malformed input. +package bolt12 diff --git a/bolt12/helpers_test.go b/bolt12/helpers_test.go new file mode 100644 index 00000000000..78bbdfdffd2 --- /dev/null +++ b/bolt12/helpers_test.go @@ -0,0 +1,24 @@ +package bolt12 + +import ( + "bytes" + + "github.com/btcsuite/btcd/btcec/v2" +) + +// bobKey returns the deterministic spec test key for Bob, whose 32-byte scalar +// is 0x42 repeated. Used across signature and round-trip tests so the same key +// is not reconstructed in every callsite. +func bobKey() (*btcec.PrivateKey, *btcec.PublicKey) { + priv, pub := btcec.PrivKeyFromBytes(bytes.Repeat([]byte{0x42}, 32)) + + return priv, pub +} + +// aliceKey returns the deterministic spec test key for Alice, whose 32-byte +// scalar is 0x41 repeated. +func aliceKey() (*btcec.PrivateKey, *btcec.PublicKey) { + priv, pub := btcec.PrivKeyFromBytes(bytes.Repeat([]byte{0x41}, 32)) + + return priv, pub +} diff --git a/bolt12/offer.go b/bolt12/offer.go new file mode 100644 index 00000000000..476efe7488a --- /dev/null +++ b/bolt12/offer.go @@ -0,0 +1,167 @@ +package bolt12 + +import ( + "bytes" + "fmt" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +// Offer represents a BOLT 12 offer message. An offer is a long-lived, reusable +// payment template that can generate multiple invoices. +type Offer struct { + // OfferChains specifies which chains this offer is valid for. If + // absent, bitcoin is implied. + OfferChains tlv.OptionalRecordT[tlv.TlvType2, ChainsRecord] + + // OfferMetadata is opaque data set by the offer creator for its own + // use. + OfferMetadata tlv.OptionalRecordT[tlv.TlvType4, tlv.Blob] + + // OfferCurrency is the ISO 4217 currency code for the offer amount, if + // the amount is not in the chain's native unit. + OfferCurrency tlv.OptionalRecordT[tlv.TlvType6, tlv.Blob] + + // OfferAmount is the amount expected per item, encoded as a tu64. The + // unit depends on OfferCurrency (msat if absent). + OfferAmount tlv.OptionalRecordT[tlv.TlvType8, TUint64] + + // OfferDescription is a UTF-8 description of the purpose of the + // payment. + OfferDescription tlv.OptionalRecordT[tlv.TlvType10, tlv.Blob] + + // OfferFeatures is the feature bit vector for this offer. + OfferFeatures tlv.OptionalRecordT[tlv.TlvType12, + lnwire.RawFeatureVector] + + // OfferAbsoluteExpiry is the time (seconds since epoch) after which the + // offer should not be used, encoded as a tu64. + OfferAbsoluteExpiry tlv.OptionalRecordT[tlv.TlvType14, TUint64] + + // OfferPaths contains one or more blinded paths to the offer issuer. + OfferPaths tlv.OptionalRecordT[tlv.TlvType16, lnwire.BlindedPaths] + + // OfferIssuer is a UTF-8 string identifying the issuer. + OfferIssuer tlv.OptionalRecordT[tlv.TlvType18, tlv.Blob] + + // OfferQuantityMax is the maximum number of items that can be requested + // in a single invoice, encoded as a tu64. A value of 0 means unlimited. + OfferQuantityMax tlv.OptionalRecordT[tlv.TlvType20, TUint64] + + // OfferIssuerID is the public key of the offer issuer. The codec + // parses the 33-byte SEC1 compressed point on decode, so a struct + // holding a key has already passed both the length and on-curve + // checks. + OfferIssuerID tlv.OptionalRecordT[tlv.TlvType22, *btcec.PublicKey] + + // decodedTLVs is the canonical TypeMap produced by decoding this offer. + // Handled types map to nil; unhandled types map to their value bytes. + // Encoding and validation both derive their view from this single field + // so they cannot drift apart, and so signed-range extras the decoder + // did not understand are re-emitted on encode and preserve offer_id. + decodedTLVs tlv.TypeMap +} + +var _ lnwire.PureTLVMessage = (*Offer)(nil) + +// AllRecords returns the canonical sorted record list for this offer, merging +// the typed records with any extra signed-range fields that the decoder +// preserved. +func (o *Offer) AllRecords() []tlv.Record { + return allRecordsFromTypeMap( + o.allRecordProducers(), o.decodedTLVs, + ) +} + +// allRecordProducers returns record producers for every set optional field, in +// declaration order. +func (o *Offer) allRecordProducers() []tlv.RecordProducer { + var p []tlv.RecordProducer + + lnwire.AddOpt(&p, o.OfferChains) + lnwire.AddOpt(&p, o.OfferMetadata) + lnwire.AddOpt(&p, o.OfferCurrency) + lnwire.AddOpt(&p, o.OfferAmount) + lnwire.AddOpt(&p, o.OfferDescription) + lnwire.AddOpt(&p, o.OfferFeatures) + lnwire.AddOpt(&p, o.OfferAbsoluteExpiry) + lnwire.AddOpt(&p, o.OfferPaths) + lnwire.AddOpt(&p, o.OfferIssuer) + lnwire.AddOpt(&p, o.OfferQuantityMax) + lnwire.AddOpt(&p, o.OfferIssuerID) + + return p +} + +// Encode serialises the offer into a canonical TLV byte stream. +func (o *Offer) Encode() ([]byte, error) { + if err := ValidateOfferWrite(o); err != nil { + return nil, fmt.Errorf("validate offer: %w", err) + } + + var buf bytes.Buffer + if err := lnwire.EncodePureTLVMessage(o, &buf); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// DecodeOffer parses a TLV byte stream into an Offer. Decoding is permissive — +// the spec writer requirements are not enforced here, so callers that need a +// valid offer must run ValidateOfferRead. Unknown TLVs are preserved on the +// returned offer so a later Encode can re-emit signed-range extras and keep +// offer_id stable. +func DecodeOffer(data []byte) (*Offer, error) { + var o Offer + + // Prepare zero-valued records for all optional fields so the TLV + // decoder can populate them. + chains := tlv.ZeroRecordT[tlv.TlvType2, ChainsRecord]() + metadata := tlv.ZeroRecordT[tlv.TlvType4, tlv.Blob]() + currency := tlv.ZeroRecordT[tlv.TlvType6, tlv.Blob]() + amount := tlv.ZeroRecordT[tlv.TlvType8, TUint64]() + desc := tlv.ZeroRecordT[tlv.TlvType10, tlv.Blob]() + features := tlv.ZeroRecordT[tlv.TlvType12, lnwire.RawFeatureVector]() + expiry := tlv.ZeroRecordT[tlv.TlvType14, TUint64]() + paths := tlv.ZeroRecordT[tlv.TlvType16, lnwire.BlindedPaths]() + issuer := tlv.ZeroRecordT[tlv.TlvType18, tlv.Blob]() + qtyMax := tlv.ZeroRecordT[tlv.TlvType20, TUint64]() + issuerID := tlv.ZeroRecordT[tlv.TlvType22, *btcec.PublicKey]() + + tm, err := decodeStream( + data, + chains.Record(), + metadata.Record(), + currency.Record(), + amount.Record(), + desc.Record(), + features.Record(), + expiry.Record(), + paths.Record(), + issuer.Record(), + qtyMax.Record(), + issuerID.Record(), + ) + if err != nil { + return nil, fmt.Errorf("decode offer: %w", err) + } + + lnwire.SetOptFromMap(tm, &o.OfferChains, chains) + lnwire.SetOptFromMap(tm, &o.OfferMetadata, metadata) + lnwire.SetOptFromMap(tm, &o.OfferCurrency, currency) + lnwire.SetOptFromMap(tm, &o.OfferAmount, amount) + lnwire.SetOptFromMap(tm, &o.OfferDescription, desc) + lnwire.SetOptFromMap(tm, &o.OfferFeatures, features) + lnwire.SetOptFromMap(tm, &o.OfferAbsoluteExpiry, expiry) + lnwire.SetOptFromMap(tm, &o.OfferPaths, paths) + lnwire.SetOptFromMap(tm, &o.OfferIssuer, issuer) + lnwire.SetOptFromMap(tm, &o.OfferQuantityMax, qtyMax) + lnwire.SetOptFromMap(tm, &o.OfferIssuerID, issuerID) + + o.decodedTLVs = tm + + return &o, nil +} diff --git a/bolt12/offer_test.go b/bolt12/offer_test.go new file mode 100644 index 00000000000..98dddda4248 --- /dev/null +++ b/bolt12/offer_test.go @@ -0,0 +1,49 @@ +package bolt12 + +import ( + "testing" + + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// TestOfferRoundTrip pins encode→decode→re-encode for an Offer with a +// representative subset of optional fields. A byte-identical re-encode is the +// invariant that keeps offer_id stable across the codec boundary. +func TestOfferRoundTrip(t *testing.T) { + t.Parallel() + + desc := tlv.Blob("coffee") + issuer := tlv.Blob("alice") + _, bobPub := bobKey() + + o := &Offer{ + OfferAmount: tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType8](TUint64(1500)), + ), + OfferDescription: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType10](desc), + ), + OfferIssuer: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType18](issuer), + ), + OfferIssuerID: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType22](bobPub), + ), + } + + encoded, err := o.Encode() + require.NoError(t, err) + require.NotEmpty(t, encoded) + + decoded, err := DecodeOffer(encoded) + require.NoError(t, err) + + require.Equal(t, TUint64(1500), decoded.OfferAmount.UnwrapOrFailV(t)) + require.Equal(t, desc, decoded.OfferDescription.UnwrapOrFailV(t)) + require.Equal(t, issuer, decoded.OfferIssuer.UnwrapOrFailV(t)) + + reencoded, err := decoded.Encode() + require.NoError(t, err) + require.Equal(t, encoded, reencoded) +} diff --git a/bolt12/pure_tlv.go b/bolt12/pure_tlv.go new file mode 100644 index 00000000000..9399eea26de --- /dev/null +++ b/bolt12/pure_tlv.go @@ -0,0 +1,52 @@ +package bolt12 + +import ( + "sort" + + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +// bolt12InUnsignedRange reports whether a TLV type is excluded from the BOLT 12 +// Merkle tree. The spec reserves types 240-1000 for signature TLVs (the BIP-340 +// Schnorr signatures over the tree itself); every other allowed type sits in +// the signed range. +func bolt12InUnsignedRange(t tlv.Type) bool { + return t >= 240 && t <= 1000 +} + +// allRecordsFromTypeMap merges the typed-record producers with the signed-range +// subset of the supplied TypeMap (preserved unknown TLVs) and returns the +// canonical sorted record list. The signed-range subset is derived on demand +// from the same TypeMap that drives the validators, so the two views cannot +// drift apart. +func allRecordsFromTypeMap(producers []tlv.RecordProducer, + tm tlv.TypeMap) []tlv.Record { + + if len(tm) > 0 { + extra := lnwire.ExtraSignedFieldsFromTypeMapFn( + tm, bolt12InUnsignedRange, + ) + if len(extra) > 0 { + producers = append( + producers, lnwire.RecordsAsProducers( + tlv.MapToRecords(extra), + )..., + ) + } + } + + return lnwire.ProduceRecordsSorted(producers...) +} + +// sortedTypes returns the keys of tm in ascending order. Validators iterate the +// result for deterministic out-of-range and unknown-even error messages. +func sortedTypes(tm tlv.TypeMap) []tlv.Type { + out := make([]tlv.Type, 0, len(tm)) + for t := range tm { + out = append(out, t) + } + sort.Slice(out, func(i, j int) bool { return out[i] < out[j] }) + + return out +} diff --git a/bolt12/subtypes.go b/bolt12/subtypes.go new file mode 100644 index 00000000000..b9b041777e2 --- /dev/null +++ b/bolt12/subtypes.go @@ -0,0 +1,83 @@ +package bolt12 + +import ( + "errors" + "fmt" + "io" + + "github.com/lightningnetwork/lnd/tlv" +) + +// ErrTooManyChains is returned when offer_chains declares more entries than +// maxOfferChains. +var ErrTooManyChains = errors.New("offer_chains exceeds maxOfferChains") + +const ( + // chainHashLen is the length of a chain hash (32 bytes). + chainHashLen = 32 + + // maxOfferChains caps decoded offer_chains entries. + maxOfferChains = 32 +) + +// ChainsRecord holds one or more chain hashes for the offer_chains field. +type ChainsRecord struct { + Chains [][chainHashLen]byte +} + +var _ tlv.RecordProducer = (*ChainsRecord)(nil) + +// Record returns a TLV record for ChainsRecord. +func (c *ChainsRecord) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, c, + func() uint64 { + return uint64(len(c.Chains)) * chainHashLen + }, + encodeChainsRecord, + decodeChainsRecord, + ) +} + +func encodeChainsRecord(w io.Writer, val any, _ *[8]byte) error { + c, ok := val.(*ChainsRecord) + if !ok { + return fmt.Errorf("expected *ChainsRecord, got %T", val) + } + + for _, chain := range c.Chains { + if _, err := w.Write(chain[:]); err != nil { + return err + } + } + + return nil +} + +// decodeChainsRecord caps the count at maxOfferChains to bound allocation. +func decodeChainsRecord(r io.Reader, val any, _ *[8]byte, l uint64) error { + c, ok := val.(*ChainsRecord) + if !ok { + return fmt.Errorf("expected *ChainsRecord, got %T", val) + } + + if l%chainHashLen != 0 { + return fmt.Errorf("chains length %d not a multiple of %d", l, + chainHashLen) + } + + numChains := l / chainHashLen + if numChains > maxOfferChains { + return fmt.Errorf("%w: %d > %d", ErrTooManyChains, numChains, + maxOfferChains) + } + + c.Chains = make([][chainHashLen]byte, numChains) + for i := range c.Chains { + if _, err := io.ReadFull(r, c.Chains[i][:]); err != nil { + return err + } + } + + return nil +} diff --git a/bolt12/subtypes_test.go b/bolt12/subtypes_test.go new file mode 100644 index 00000000000..ceecd14a084 --- /dev/null +++ b/bolt12/subtypes_test.go @@ -0,0 +1,154 @@ +package bolt12 + +import ( + "bytes" + "encoding/hex" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestDecodeChainsRecord pins the chain-array decoder's structural rejections. +func TestDecodeChainsRecord(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + data []byte + wantErr error + wantMsg string + }{ + { + name: "length not multiple of 32", + data: append( + bytes.Repeat( + []byte{0xaa}, chainHashLen, + ), + 187, + ), + wantMsg: "not a multiple of", + }, + { + name: "exceeds cap", + data: bytes.Repeat( + []byte{0x00}, (maxOfferChains+1)*chainHashLen, + ), + wantErr: ErrTooManyChains, + }, + } + + for _, tc := range tests { + t.Run( + tc.name, + func(t *testing.T) { + t.Parallel() + + var c ChainsRecord + err := decodeChainsRecord( + bytes.NewReader(tc.data), &c, + new([8]byte), + uint64( + len(tc.data), + ), + ) + require.Error(t, err) + + if tc.wantErr != nil { + require.ErrorIs(t, err, tc.wantErr) + } + + if tc.wantMsg != "" { + require.Contains( + t, err.Error(), tc.wantMsg, + ) + } + }, + ) + } +} + +// TestChainsRecordRoundTrip pins decode→re-encode against the BOLT 12 offer +// test vectors. +func TestChainsRecordRoundTrip(t *testing.T) { + t.Parallel() + + // bitcoinHash is the bitcoin mainnet genesis hash hex-decoded into a + // fixed array. Defined locally so the test does not depend on constants + // introduced by later commits. + bitcoinHashHex := "6fe28c0ab6f1b372c1a6a246ae63f74f931e8365" + + "e15a089c68d6190000000000" + + var bitcoinHash [chainHashLen]byte + bitcoinHashBytes, err := hex.DecodeString(bitcoinHashHex) + require.NoError(t, err) + copy(bitcoinHash[:], bitcoinHashBytes) + + tests := []struct { + name string + // hex is the on-wire bytes of the offer_chains TLV value + // (concatenated 32-byte chain hashes), copied from + // bolt12/offers-test.json. + hex string + wantLen int + wantHash [chainHashLen]byte + }{ + { + name: "single testnet chain", + hex: "43497fd7f826957108f4a30fd9cec3ae" + + "ba79972084e90ead01ea330900000000", + wantLen: 1, + }, + { + name: "single bitcoin chain", + hex: bitcoinHashHex, + wantLen: 1, + wantHash: bitcoinHash, + }, + { + name: "two chains liquidv1 then bitcoin", + hex: "1466275836220db2944ca059a3a10ef6fd2ea684b" + + "0688d2c379296888a206003" + bitcoinHashHex, + wantLen: 2, + // Second chain in the list is bitcoin mainnet. + wantHash: bitcoinHash, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString(tc.hex) + require.NoError(t, err) + + var c ChainsRecord + err = decodeChainsRecord( + bytes.NewReader(data), &c, new([8]byte), + uint64( + len(data), + ), + ) + require.NoError(t, err) + require.Len(t, c.Chains, tc.wantLen) + + // Cross-check the canonical bitcoin chain hash where + // the row knows which slot it lives in. + var zero [chainHashLen]byte + if tc.wantHash != zero { + idx := tc.wantLen - 1 + require.Equal( + t, tc.wantHash, c.Chains[idx], + "bitcoin hash mismatch in slot %d", + idx, + ) + } + + var buf bytes.Buffer + require.NoError( + t, encodeChainsRecord(&buf, &c, new([8]byte)), + ) + + require.Equal(t, data, buf.Bytes()) + }) + } +} diff --git a/bolt12/tlv_types.go b/bolt12/tlv_types.go new file mode 100644 index 00000000000..b956477c112 --- /dev/null +++ b/bolt12/tlv_types.go @@ -0,0 +1,22 @@ +package bolt12 + +import ( + "github.com/lightningnetwork/lnd/tlv" +) + +// TUint64 is a uint64 that serializes using truncated encoding (tu64) +// as required by BOLT 12. Leading zero bytes are omitted. +type TUint64 uint64 + +// Record returns a TLV record using truncated uint64 encoding. +// +// NOTE: This implements the tlv.RecordProducer interface. +func (t *TUint64) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, (*uint64)(t), + func() uint64 { + return tlv.SizeTUint64(uint64(*t)) + }, + tlv.ETUint64, tlv.DTUint64, + ) +} diff --git a/bolt12/validate.go b/bolt12/validate.go new file mode 100644 index 00000000000..245ab5a1c04 --- /dev/null +++ b/bolt12/validate.go @@ -0,0 +1,372 @@ +package bolt12 + +import ( + "errors" + "fmt" + "slices" + "time" + "unicode/utf8" + + "github.com/btcsuite/btcd/chaincfg" + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" + "golang.org/x/text/currency" +) + +var ( + // ErrOutOfRangeType is returned when a TLV type falls outside the + // allowed offer ranges (1-79 and 1000000000-1999999999). + ErrOutOfRangeType = errors.New("TLV type outside allowed range") + + // ErrUnknownEvenType is returned when an unknown even TLV type is + // present in an allowed range. Per BOLT 1, even types are + // must-understand: if the reader does not recognise the type, it MUST + // reject the message rather than silently ignoring the field. + ErrUnknownEvenType = errors.New("unknown even TLV type") + + // ErrUnknownEvenFeature is returned when an unknown even feature + // bit is set. + ErrUnknownEvenFeature = errors.New("unknown even feature bit set") + + // ErrMissingDescription is returned when offer_amount is set but + // offer_description is absent. + ErrMissingDescription = errors.New( + "offer_amount set without offer_description", + ) + + // ErrCurrencyWithoutAmount is returned when offer_currency is set + // but offer_amount is absent. + ErrCurrencyWithoutAmount = errors.New( + "offer_currency set without offer_amount", + ) + + // ErrEmptyBlindedPaths is returned when a blinded paths field is + // present on a BOLT 12 message but its list of paths is empty. The + // spec writer requirements treat "present" as implying at least one + // usable path. + ErrEmptyBlindedPaths = errors.New("blinded paths field present but " + + "empty") + + // ErrNoIssuerIdentity is returned when neither offer_issuer_id + // nor offer_paths is set. + ErrNoIssuerIdentity = errors.New( + "neither offer_issuer_id nor offer_paths set", + ) + + // ErrOfferExpired is returned when the current time is after + // offer_absolute_expiry. + ErrOfferExpired = errors.New("offer has expired") + + // ErrEmptyChains is returned when offer_chains is present but + // contains no entries. + ErrEmptyChains = errors.New( + "offer_chains present but empty", + ) + + // ErrUnsupportedChain is returned when offer_chains does not + // contain our active chain. + ErrUnsupportedChain = errors.New( + "offer does not support our chain", + ) + + // ErrInvalidUTF8 is returned when a UTF-8 field contains invalid + // sequences. + ErrInvalidUTF8 = errors.New("invalid UTF-8") + + // ErrInvalidCurrency is returned when offer_currency is not a + // valid ISO 4217 code. + ErrInvalidCurrency = errors.New("invalid offer_currency") +) + +// offerAllowedRange returns true if the TLV type falls within the allowed +// ranges for offer messages: 1-79 and 1000000000-1999999999. +func offerAllowedRange(typ tlv.Type) bool { + return (typ >= 1 && typ <= 79) || + (typ >= 1000000000 && typ <= 1999999999) +} + +// isKnownOfferTLVType returns true for TLV types that are defined in the offer +// spec (even types 2-22). +func isKnownOfferTLVType(typ tlv.Type) bool { + switch typ { + case 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22: + return true + default: + return false + } +} + +// ValidateOfferRead validates an offer per the BOLT 12 offer reader +// requirements. The now parameter is used for expiry checks and can be +// overridden in tests. activeChain is required: per spec, absent offer_chains +// defaults to Bitcoin mainnet, and the reader must reject offers that do not +// list a chain it operates on. Pass the genesis hash of the chain the receiver +// is willing to settle on. +func ValidateOfferRead(o *Offer, now time.Time, activeChain [32]byte) error { + // Check TLV types are in allowed range and that unknown even types are + // rejected (even = must-understand). + for _, t := range sortedTypes(o.decodedTLVs) { + if !offerAllowedRange(t) { + return fmt.Errorf("%w: type %d", ErrOutOfRangeType, t) + } + + if !isKnownOfferTLVType(t) && t%2 == 0 { + return fmt.Errorf("%w: type %d", ErrUnknownEvenType, t) + } + } + + // Check for unknown even feature bits. + if err := checkFeatures(o.OfferFeatures); err != nil { + return err + } + + // offer_chains present but empty. + var chainsEmpty bool + o.OfferChains.WhenSome( + func(r tlv.RecordT[tlv.TlvType2, ChainsRecord]) { + if len(r.Val.Chains) == 0 { + chainsEmpty = true + } + }, + ) + if chainsEmpty { + return ErrEmptyChains + } + + // Validate the offer's chain against the active chain. An absent + // offer_chains TLV means "Bitcoin mainnet" per spec, normalised by + // getOfferChains. + offerChains := getOfferChains(o) + found := slices.Contains(offerChains, activeChain) + if !found { + return ErrUnsupportedChain + } + + // offer_amount set requires offer_description. + hasAmount := o.OfferAmount.IsSome() + if hasAmount && !o.OfferDescription.IsSome() { + return ErrMissingDescription + } + + // offer_currency requires offer_amount. + if o.OfferCurrency.IsSome() && !hasAmount { + return ErrCurrencyWithoutAmount + } + + // Must have either offer_issuer_id or offer_paths. + if !o.OfferIssuerID.IsSome() && !o.OfferPaths.IsSome() { + return ErrNoIssuerIdentity + } + + // Check blinded paths have at least one hop. + if err := checkBlindedPaths(o.OfferPaths); err != nil { + return err + } + + // Expiry check. A present-but-zero offer_absolute_expiry historically + // meant "no expiry" but that conflicts with the spec (zero is a valid + // past timestamp); treat it as already expired rather than ambiguous so + // a misuse fails closed. + var ( + expiry uint64 + hasExpiry bool + ) + o.OfferAbsoluteExpiry.WhenSome( + func(r tlv.RecordT[tlv.TlvType14, TUint64]) { + expiry = uint64(r.Val) + hasExpiry = true + }, + ) + if hasExpiry && uint64(now.Unix()) >= expiry { + return ErrOfferExpired + } + + // Validate UTF-8 fields. + if err := checkUTF8(o.OfferCurrency, "offer_currency"); err != nil { + return err + } + + if err := checkUTF8( + o.OfferDescription, "offer_description", + ); err != nil { + return err + } + + if err := checkUTF8(o.OfferIssuer, "offer_issuer"); err != nil { + return err + } + + if err := checkISO4217(o.OfferCurrency); err != nil { + return err + } + + return nil +} + +// bitcoinMainnetGenesisHash is the genesis hash for Bitcoin mainnet, used as +// the default when offer_chains is absent per the spec. +var bitcoinMainnetGenesisHash = [32]byte(*chaincfg.MainNetParams.GenesisHash) + +// getOfferChains returns the chains an offer is valid for. If offer_chains is +// absent, the spec defaults to Bitcoin mainnet. +func getOfferChains(o *Offer) [][32]byte { + chains := fn.MapOptionZ( + o.OfferChains.ValOpt(), + func(r ChainsRecord) [][32]byte { return r.Chains }, + ) + + if len(chains) == 0 { + chains = [][32]byte{bitcoinMainnetGenesisHash} + } + + return chains +} + +// ValidateOfferWrite validates an offer per the BOLT 12 offer writer +// requirements. +func ValidateOfferWrite(o *Offer) error { + // Writer MUST NOT set TLV fields outside allowed ranges. This check + // catches a decoded-then-mutated offer: a freshly-built struct has no + // decodedTLVs (Decode is the only writer of that field). The typed + // field set already excludes out-of-range types by construction, so a + // freshly-built offer cannot violate the range rule in the first place. + for _, t := range sortedTypes(o.decodedTLVs) { + if !offerAllowedRange(t) { + return fmt.Errorf("%w: type %d", + ErrOutOfRangeType, t) + } + } + + // offer_amount requires offer_description. + if o.OfferAmount.IsSome() && !o.OfferDescription.IsSome() { + return ErrMissingDescription + } + + // offer_currency requires offer_amount. + if o.OfferCurrency.IsSome() && !o.OfferAmount.IsSome() { + return ErrCurrencyWithoutAmount + } + + // Without offer_paths, MUST set offer_issuer_id. + if !o.OfferPaths.IsSome() && !o.OfferIssuerID.IsSome() { + return ErrNoIssuerIdentity + } + + // Defense in depth: writer-side mirrors of reader rejections for + // present-but-empty offer_chains and offer_paths. + var chainsEmpty bool + o.OfferChains.WhenSome( + func(r tlv.RecordT[tlv.TlvType2, ChainsRecord]) { + if len(r.Val.Chains) == 0 { + chainsEmpty = true + } + }, + ) + if chainsEmpty { + return ErrEmptyChains + } + + if err := checkBlindedPaths(o.OfferPaths); err != nil { + return err + } + + // Defense in depth: writer-side mirrors of the reader UTF-8 checks + // for offer_currency, offer_description, and offer_issuer. + if err := checkUTF8(o.OfferCurrency, "offer_currency"); err != nil { + return err + } + + if err := checkUTF8( + o.OfferDescription, "offer_description", + ); err != nil { + return err + } + + if err := checkUTF8(o.OfferIssuer, "offer_issuer"); err != nil { + return err + } + + if err := checkISO4217(o.OfferCurrency); err != nil { + return err + } + + return nil +} + +// checkISO4217 verifies that offer_currency, if set, parses as an ISO 4217 +// code. The upstream parser is case-insensitive and rejects both malformed and +// unrecognised codes. +func checkISO4217[T tlv.TlvType](opt tlv.OptionalRecordT[T, tlv.Blob]) error { + return fn.MapOptionZ(opt.ValOpt(), func(data tlv.Blob) error { + if _, err := currency.ParseISO(string(data)); err != nil { + return fmt.Errorf("%w: %w", ErrInvalidCurrency, err) + } + + return nil + }) +} + +// checkFeatures rejects any unknown even (must-understand) feature bit. +func checkFeatures[T tlv.TlvType]( + opt tlv.OptionalRecordT[T, lnwire.RawFeatureVector]) error { + + return fn.MapOptionZ( + opt.ValOpt(), + func(fv lnwire.RawFeatureVector) error { + // SerializeSize bounds the largest set bit, so walking + // up to SerializeSize*8 covers every bit the wire can + // express. + size := fv.SerializeSize() * 8 + for i := 0; i < size; i++ { + bit := lnwire.FeatureBit(i) + if !fv.IsSet(bit) { + continue + } + if bit%2 == 0 { + return fmt.Errorf("%w: bit %d", + ErrUnknownEvenFeature, bit) + } + } + + return nil + }, + ) +} + +// checkBlindedPaths walks each path in a blinded paths field and rejects empty +// Paths slices and paths with zero hops. +func checkBlindedPaths[T tlv.TlvType]( + opt tlv.OptionalRecordT[T, lnwire.BlindedPaths]) error { + + return fn.MapOptionZ( + opt.ValOpt(), + func(paths lnwire.BlindedPaths) error { + if len(paths.Paths) == 0 { + return ErrEmptyBlindedPaths + } + + for i, p := range paths.Paths { + if len(p.Hops) == 0 { + return fmt.Errorf("%w: path %d", + lnwire.ErrEmptyBlindedPath, i) + } + } + + return nil + }, + ) +} + +// checkUTF8 validates that a blob field contains valid UTF-8. +func checkUTF8[T tlv.TlvType](opt tlv.OptionalRecordT[T, tlv.Blob], + name string) error { + + return fn.MapOptionZ(opt.ValOpt(), func(data tlv.Blob) error { + if !utf8.Valid(data) { + return fmt.Errorf("%w: %s", ErrInvalidUTF8, name) + } + + return nil + }) +} diff --git a/bolt12/validate_test.go b/bolt12/validate_test.go new file mode 100644 index 00000000000..b3183352327 --- /dev/null +++ b/bolt12/validate_test.go @@ -0,0 +1,371 @@ +package bolt12 + +import ( + "testing" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// validBobOffer is the spec-minimal happy-path offer that each table row +// mutates to isolate the rule under test. +func validBobOffer(t *testing.T) *Offer { + t.Helper() + + _, pub := bobKey() + + return &Offer{ + OfferIssuerID: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType22](pub), + ), + } +} + +// TestValidateOfferWrite pins the BOLT 12 writer-side MUSTs that the codec can +// enforce. +func TestValidateOfferWrite(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mutate func(*Offer) + wantErr error + }{ + { + name: "happy path with issuer_id only", + mutate: func(*Offer) {}, + wantErr: nil, + }, + { + name: "amount without description", + mutate: func(o *Offer) { + o.OfferAmount = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType8]( + TUint64(1000), + ), + ) + }, + wantErr: ErrMissingDescription, + }, + { + name: "currency without amount", + mutate: func(o *Offer) { + o.OfferCurrency = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, + tlv.Blob](tlv.Blob("USD")), + ) + }, + wantErr: ErrCurrencyWithoutAmount, + }, + { + name: "no issuer or paths", + mutate: func(o *Offer) { + o.OfferIssuerID = tlv.OptionalRecordT[ + tlv.TlvType22, *btcec.PublicKey]{} + }, + wantErr: ErrNoIssuerIdentity, + }, + { + name: "empty offer_chains", + mutate: func(o *Offer) { + o.OfferChains = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType2]( + ChainsRecord{Chains: nil}, + ), + ) + }, + wantErr: ErrEmptyChains, + }, + { + name: "currency wrong length", + mutate: func(o *Offer) { + addAmountAndDescription(o) + o.OfferCurrency = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, + tlv.Blob](tlv.Blob("US")), + ) + }, + wantErr: ErrInvalidCurrency, + }, + { + name: "currency unknown ISO 4217 code", + mutate: func(o *Offer) { + addAmountAndDescription(o) + o.OfferCurrency = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, + tlv.Blob](tlv.Blob("ZZZ")), + ) + }, + wantErr: ErrInvalidCurrency, + }, + { + // Pins the docstring claim that ValidateOfferWrite's + // offerAllowedRange loop exists to catch a + // decoded-then-mutated offer with an out-of-range TLV + // resurfacing via decodedTLVs. + name: "out-of-range TLV in decoded extras", + mutate: func(o *Offer) { + o.decodedTLVs = tlv.TypeMap{200: nil} + }, + wantErr: ErrOutOfRangeType, + }, + { + name: "empty blinded paths list", + mutate: func(o *Offer) { + o.OfferPaths = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType16]( + lnwire.BlindedPaths{Paths: nil}, + ), + ) + }, + wantErr: ErrEmptyBlindedPaths, + }, + { + name: "blinded path with zero hops", + mutate: func(o *Offer) { + _, intro := aliceKey() + _, blinding := bobKey() + pk := lnwire.PubkeyIntro{Pubkey: intro} + o.OfferPaths = tlv.SomeRecordT( + //nolint:ll + tlv.NewRecordT[tlv.TlvType16]( + lnwire.BlindedPaths{ + Paths: []lnwire.BlindedPath{{ + IntroductionNode: pk, + BlindingPoint: blinding, + Hops: nil, + }}, + }, + ), + ) + }, + wantErr: lnwire.ErrEmptyBlindedPath, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + o := validBobOffer(t) + tc.mutate(o) + + err := ValidateOfferWrite(o) + if tc.wantErr == nil { + require.NoError(t, err) + return + } + require.ErrorIs(t, err, tc.wantErr) + }) + } +} + +// TestValidateOfferRead pins the BOLT 12 reader-side MUSTs so a malformed or +// unsafe offer is rejected before any invoice request reaches the wire. +func TestValidateOfferRead(t *testing.T) { + t.Parallel() + + now := time.Unix(1_700_000_000, 0) + + var nonBitcoin [32]byte + nonBitcoin[0] = 0x01 + + tests := []struct { + name string + mutate func(*Offer) + activeChain [32]byte + wantErr error + }{ + { + name: "happy path on bitcoin mainnet", + mutate: func(*Offer) {}, + activeChain: bitcoinMainnetGenesisHash, + }, + { + name: "out-of-range TLV in decoded extras", + mutate: func(o *Offer) { + o.decodedTLVs = tlv.TypeMap{200: nil} + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrOutOfRangeType, + }, + { + name: "unknown even TLV type in range rejected", + mutate: func(o *Offer) { + o.decodedTLVs = tlv.TypeMap{24: nil} + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrUnknownEvenType, + }, + { + name: "unknown even feature bit rejected", + mutate: func(o *Offer) { + o.OfferFeatures = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType12]( + *lnwire.NewRawFeatureVector(0), + ), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrUnknownEvenFeature, + }, + { + name: "unknown odd feature bit ignored", + mutate: func(o *Offer) { + o.OfferFeatures = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType12]( + *lnwire.NewRawFeatureVector(1), + ), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: nil, + }, + { + name: "non-bitcoin chain rejected when " + + "offer_chains absent", + mutate: func(*Offer) {}, + activeChain: nonBitcoin, + wantErr: ErrUnsupportedChain, + }, + { + name: "explicit chain list missing active chain", + mutate: func(o *Offer) { + var c [32]byte + c[0] = 0xaa + o.OfferChains = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType2]( + ChainsRecord{ + Chains: [][32]byte{c}, + }, + ), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrUnsupportedChain, + }, + { + name: "empty offer_chains list", + mutate: func(o *Offer) { + o.OfferChains = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType2]( + ChainsRecord{Chains: nil}, + ), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrEmptyChains, + }, + { + name: "amount without description", + mutate: func(o *Offer) { + o.OfferAmount = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType8]( + TUint64(1000), + ), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrMissingDescription, + }, + { + name: "currency without amount", + mutate: func(o *Offer) { + o.OfferCurrency = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, + tlv.Blob](tlv.Blob("USD")), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrCurrencyWithoutAmount, + }, + { + name: "missing issuer and paths", + mutate: func(o *Offer) { + o.OfferIssuerID = tlv.OptionalRecordT[ + tlv.TlvType22, *btcec.PublicKey]{} + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrNoIssuerIdentity, + }, + { + name: "blinded path with zero hops", + mutate: func(o *Offer) { + _, intro := aliceKey() + _, blinding := bobKey() + pk := lnwire.PubkeyIntro{Pubkey: intro} + o.OfferPaths = tlv.SomeRecordT( + //nolint:ll + tlv.NewRecordT[tlv.TlvType16]( + lnwire.BlindedPaths{ + Paths: []lnwire.BlindedPath{{ + IntroductionNode: pk, + BlindingPoint: blinding, + Hops: nil, + }}, + }, + ), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: lnwire.ErrEmptyBlindedPath, + }, + { + name: "expired offer", + mutate: func(o *Offer) { + expiry := uint64(now.Unix()) - 1 + o.OfferAbsoluteExpiry = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType14]( + TUint64(expiry), + ), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrOfferExpired, + }, + { + name: "currency wrong length", + mutate: func(o *Offer) { + addAmountAndDescription(o) + o.OfferCurrency = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, + tlv.Blob](tlv.Blob("US")), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrInvalidCurrency, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + o := validBobOffer(t) + tc.mutate(o) + + err := ValidateOfferRead(o, now, tc.activeChain) + if tc.wantErr == nil { + require.NoError(t, err) + return + } + require.ErrorIs(t, err, tc.wantErr) + }) + } +} + +// addAmountAndDescription satisfies the dependency rules so currency-shape rows +// are not short-circuited before the ISO 4217 check runs. +func addAmountAndDescription(o *Offer) { + o.OfferAmount = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType8](TUint64(1000)), + ) + o.OfferDescription = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType10, tlv.Blob]( + tlv.Blob("a tip"), + ), + ) +} diff --git a/go.mod b/go.mod index 8457d8a8d2a..5822672ebca 100644 --- a/go.mod +++ b/go.mod @@ -185,7 +185,7 @@ require ( golang.org/x/mod v0.30.0 // indirect golang.org/x/net v0.48.0 // indirect golang.org/x/sys v0.39.0 // indirect - golang.org/x/text v0.32.0 // indirect + golang.org/x/text v0.32.0 golang.org/x/tools v0.39.0 // indirect google.golang.org/genproto v0.0.0-20231016165738-49dd2c1f3d0b // indirect google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect diff --git a/lnwire/blinded_path.go b/lnwire/blinded_path.go new file mode 100644 index 00000000000..cbf64372177 --- /dev/null +++ b/lnwire/blinded_path.go @@ -0,0 +1,325 @@ +package lnwire + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // ErrInvalidIntroNode is returned when a blinded path's introduction + // node discriminator is not one of the spec-defined values. + ErrInvalidIntroNode = errors.New("invalid blinded-path introduction " + + "node discriminator") + + // ErrEmptyBlindedPath is returned when a blinded path has zero hops. + ErrEmptyBlindedPath = errors.New("blinded path with zero hops") +) + +// BlindedPath holds the introduction node, blinding point, and encrypted hops +// of a single blinded path. +type BlindedPath struct { + // IntroductionNode is the variant-defined introduction node for this + // blinded path. + IntroductionNode IntroductionNode + + // BlindingPoint is the blinding point for this path, used to derive the + // blinded node IDs and encrypt the hop payloads. + BlindingPoint *btcec.PublicKey + + // Hops is the ordered list of blinded hops in this path. + Hops []BlindedHop +} + +// BlindedPaths holds one or more blinded paths. +type BlindedPaths struct { + Paths []BlindedPath +} + +// BlindedHop represents a single hop in a blinded path. +type BlindedHop struct { + // BlindedNodeID is the blinded public key for this hop. + BlindedNodeID *btcec.PublicKey + + // EncryptedData is the encrypted payload for this hop. + EncryptedData []byte +} + +var ( + _ tlv.RecordProducer = (*BlindedPath)(nil) + _ tlv.RecordProducer = (*BlindedPaths)(nil) +) + +// Record returns a TLV record for a single BlindedPath at the BOLT 4 reply_path +// TLV type. Used directly by OnionMessagePayload's reply_path encoding. +func (p *BlindedPath) Record() tlv.Record { + return tlv.MakeDynamicRecord( + replyPathType, p, + func() uint64 { + return blindedPathSize(p) + }, + encodeBlindedPath, + decodeBlindedPath, + ) +} + +// blindedPathSize returns the on-wire size of a single BlindedPath. +func blindedPathSize(p *BlindedPath) uint64 { + var introLen uint64 + if p.IntroductionNode != nil { + introLen = p.IntroductionNode.encodedLen() + } + + // introduction_node (variant-defined) + blinding_point (33) + + // num_hops (1). + size := introLen + pubKeyLen + 1 + for _, h := range p.Hops { + // blinded_node_id (33) + enclen (2) + enc_data. + size += pubKeyLen + 2 + uint64(len(h.EncryptedData)) + } + + return size +} + +// encodeBlindedPath writes a single blinded path. No bytes are written if the +// path fails validation. +func encodeBlindedPath(w io.Writer, val any, buf *[8]byte) error { + p, ok := val.(*BlindedPath) + if !ok { + return fmt.Errorf("expected *BlindedPath, got %T", val) + } + + return writeBlindedPath(w, p, buf) +} + +// writeBlindedPath validates the path and writes a single blinded path to w. +func writeBlindedPath(w io.Writer, p *BlindedPath, buf *[8]byte) error { + if p.IntroductionNode == nil { + return fmt.Errorf("nil intro node") + } + + if err := p.IntroductionNode.validate(); err != nil { + return err + } + + if p.BlindingPoint == nil { + return fmt.Errorf("nil blinding point") + } + + if len(p.Hops) == 0 { + return ErrEmptyBlindedPath + } + if len(p.Hops) > maxBlindedPathHops { + return fmt.Errorf("%d hops exceeds limit %d", len(p.Hops), + maxBlindedPathHops) + } + + if err := p.IntroductionNode.encode(w); err != nil { + return err + } + blindingBytes := p.BlindingPoint.SerializeCompressed() + if _, err := w.Write(blindingBytes); err != nil { + return err + } + + buf[0] = uint8(len(p.Hops)) + if _, err := w.Write(buf[:1]); err != nil { + return err + } + + for hIdx := range p.Hops { + if err := writeBlindedHop(w, &p.Hops[hIdx], buf); err != nil { + return fmt.Errorf("hop %d: %w", hIdx, err) + } + } + + return nil +} + +// decodeBlindedPath reads a single blinded path framed at the TLV-value level. +func decodeBlindedPath(r io.Reader, val any, buf *[8]byte, l uint64) error { + p, ok := val.(*BlindedPath) + if !ok { + return fmt.Errorf("expected *BlindedPath, got %T", val) + } + + lr := &io.LimitedReader{R: r, N: int64(l)} + + if err := readBlindedPath(lr, p, buf); err != nil { + return err + } + + if lr.N != 0 { + return fmt.Errorf("trailing %d bytes after blinded path", lr.N) + } + + return nil +} + +// readBlindedPath decodes a single blinded path from lr. +func readBlindedPath(lr *io.LimitedReader, p *BlindedPath, + buf *[8]byte) error { + + intro, err := decodeIntroductionNode(lr, buf) + if err != nil { + return err + } + p.IntroductionNode = intro + + var blindingBytes [pubKeyLen]byte + if _, err := io.ReadFull(lr, blindingBytes[:]); err != nil { + return fmt.Errorf("read blinding point: %w", err) + } + blinding, err := btcec.ParsePubKey(blindingBytes[:]) + if err != nil { + return fmt.Errorf("blinding point: %w", err) + } + p.BlindingPoint = blinding + + if _, err := io.ReadFull(lr, buf[:1]); err != nil { + return fmt.Errorf("read num_hops: %w", err) + } + numHops := int(buf[0]) + if numHops == 0 { + return ErrEmptyBlindedPath + } + + if int64(numHops)*minBlindedHopBytes > lr.N { + return fmt.Errorf("num_hops %d exceeds remaining %d bytes", + numHops, lr.N) + } + + p.Hops = make([]BlindedHop, numHops) + for i := range p.Hops { + if err := readBlindedHop(lr, &p.Hops[i], buf); err != nil { + return err + } + } + + return nil +} + +// Record returns a TLV record for BlindedPaths. +func (bp *BlindedPaths) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, bp, + func() uint64 { + return blindedPathsSize(bp) + }, + encodeBlindedPaths, + decodeBlindedPaths, + ) +} + +// blindedPathsSize returns the on-wire size of multiple BlindedPaths. +func blindedPathsSize(bp *BlindedPaths) uint64 { + var size uint64 + for i := range bp.Paths { + size += blindedPathSize(&bp.Paths[i]) + } + + return size +} + +// encodeBlindedPaths writes the multi-path TLV value as concatenated paths. +// Fails closed under the same conditions as encodeBlindedPath. +func encodeBlindedPaths(w io.Writer, val any, buf *[8]byte) error { + bp, ok := val.(*BlindedPaths) + if !ok { + return fmt.Errorf("expected *BlindedPaths, got %T", val) + } + + for pIdx := range bp.Paths { + err := writeBlindedPath(w, &bp.Paths[pIdx], buf) + if err != nil { + return fmt.Errorf("blinded path %d: %w", pIdx, err) + } + } + + return nil +} + +// decodeBlindedPaths reads concatenated blinded paths. The LimitedReader gates +// each variable-length subfield against the bytes still on the wire, so an +// oversize hop count cannot force a large allocation before io.ReadFull +// notices the bytes are absent. +func decodeBlindedPaths(r io.Reader, val any, buf *[8]byte, l uint64) error { + bp, ok := val.(*BlindedPaths) + if !ok { + return fmt.Errorf("expected *BlindedPaths, got %T", val) + } + + lr := &io.LimitedReader{R: r, N: int64(l)} + + for lr.N > 0 { + var p BlindedPath + if err := readBlindedPath(lr, &p, buf); err != nil { + return err + } + bp.Paths = append(bp.Paths, p) + } + + return nil +} + +// writeBlindedHop emits BlindedNodeID + enclen + encrypted data. The size cap +// is checked first so no bytes hit the writer on rejection. +func writeBlindedHop(w io.Writer, h *BlindedHop, buf *[8]byte) error { + if h.BlindedNodeID == nil { + return fmt.Errorf("nil blinded node id") + } + + if len(h.EncryptedData) > maxEncryptedDataLen { + return fmt.Errorf("encrypted data %d exceeds limit %d", + len(h.EncryptedData), maxEncryptedDataLen) + } + + nodeIDBytes := h.BlindedNodeID.SerializeCompressed() + if _, err := w.Write(nodeIDBytes); err != nil { + return err + } + + binary.BigEndian.PutUint16(buf[:2], uint16(len(h.EncryptedData))) + if _, err := w.Write(buf[:2]); err != nil { + return err + } + if _, err := w.Write(h.EncryptedData); err != nil { + return err + } + + return nil +} + +// readBlindedHop decodes a single blinded hop. The enclen guard against lr.N +// bounds the EncryptedData allocation. +func readBlindedHop(lr *io.LimitedReader, h *BlindedHop, buf *[8]byte) error { + var nodeBytes [pubKeyLen]byte + if _, err := io.ReadFull(lr, nodeBytes[:]); err != nil { + return fmt.Errorf("read blinded node: %w", err) + } + node, err := btcec.ParsePubKey(nodeBytes[:]) + if err != nil { + return fmt.Errorf("blinded node id: %w", err) + } + h.BlindedNodeID = node + + if _, err := io.ReadFull(lr, buf[:2]); err != nil { + return fmt.Errorf("read enclen: %w", err) + } + encLen := binary.BigEndian.Uint16(buf[:2]) + if int64(encLen) > lr.N { + return fmt.Errorf("enclen %d exceeds remaining %d", encLen, + lr.N) + } + + h.EncryptedData = make([]byte, encLen) + if _, err := io.ReadFull(lr, h.EncryptedData); err != nil { + return fmt.Errorf("read encrypted data: %w", err) + } + + return nil +} diff --git a/lnwire/blinded_path_test.go b/lnwire/blinded_path_test.go new file mode 100644 index 00000000000..207e08a45a0 --- /dev/null +++ b/lnwire/blinded_path_test.go @@ -0,0 +1,414 @@ +package lnwire + +import ( + "bytes" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/stretchr/testify/require" +) + +// validPubkeyIntro returns an on-curve PubkeyIntro plus the matching +// *btcec.PublicKey for assertions. +func validPubkeyIntro(t *testing.T) (PubkeyIntro, *btcec.PublicKey) { + t.Helper() + + priv, err := btcec.NewPrivateKey() + require.NoError(t, err) + pub := priv.PubKey() + + return PubkeyIntro{Pubkey: pub}, pub +} + +// validBlindingPoint returns an on-curve pubkey suitable for use as a +// BlindingPoint or BlindedNodeID in tests. +func validBlindingPoint(t *testing.T) *btcec.PublicKey { + t.Helper() + + priv, err := btcec.NewPrivateKey() + require.NoError(t, err) + + return priv.PubKey() +} + +// oversizeEncDataPaths returns a BlindedPaths with a single hop whose +// EncryptedData is one byte over the wire-format limit, used by the +// encode-rejects test. +func oversizeEncDataPaths(t *testing.T, intro IntroductionNode) *BlindedPaths { + t.Helper() + + return &BlindedPaths{ + Paths: []BlindedPath{{ + IntroductionNode: intro, + BlindingPoint: validBlindingPoint(t), + Hops: []BlindedHop{{ + BlindedNodeID: validBlindingPoint(t), + EncryptedData: make( + []byte, maxEncryptedDataLen+1, + ), + }}, + }}, + } +} + +// TestBlindedPathRoundTrip pins encode→decode parity across both +// IntroductionNode variants and across single- and multi-path framings, so +// concrete variant types survive the round-trip with byte-identical output. +func TestBlindedPathRoundTrip(t *testing.T) { + t.Parallel() + + pubkeyIntro, _ := validPubkeyIntro(t) + sciddirIntro := SciddirIntro{ + Direction: 0x01, + SCID: [8]byte{ + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, + }, + } + + hop := func(payload byte) BlindedHop { + return BlindedHop{ + BlindedNodeID: validBlindingPoint(t), + EncryptedData: []byte{payload, payload ^ 0xff}, + } + } + + pubkeyPath := BlindedPath{ + IntroductionNode: pubkeyIntro, + BlindingPoint: validBlindingPoint(t), + Hops: []BlindedHop{ + hop(0xde), + hop(0xad), + }, + } + sciddirPath := BlindedPath{ + IntroductionNode: sciddirIntro, + BlindingPoint: validBlindingPoint(t), + Hops: []BlindedHop{hop(0xbe)}, + } + + tests := []struct { + name string + paths []BlindedPath + }{ + { + name: "single pubkey path", + paths: []BlindedPath{pubkeyPath}, + }, + { + name: "single sciddir path", + paths: []BlindedPath{sciddirPath}, + }, + { + name: "mixed multi-path", + paths: []BlindedPath{pubkeyPath, sciddirPath}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + bp := &BlindedPaths{Paths: tc.paths} + + var buf bytes.Buffer + require.NoError(t, encodeBlindedPaths( + &buf, bp, new([8]byte), + )) + + var decoded BlindedPaths + err := decodeBlindedPaths( + bytes.NewReader(buf.Bytes()), &decoded, + new([8]byte), uint64(buf.Len()), + ) + require.NoError(t, err) + require.Equal(t, bp.Paths, decoded.Paths) + + // Single-path framing must round-trip too: the + // reply_path TLV carries one BlindedPath, not a list. + if len(tc.paths) == 1 { + var single bytes.Buffer + require.NoError(t, encodeBlindedPath( + &single, &tc.paths[0], new([8]byte), + )) + + var decodedSingle BlindedPath + err := decodeBlindedPath( + bytes.NewReader(single.Bytes()), + &decodedSingle, new([8]byte), + uint64(single.Len()), + ) + require.NoError(t, err) + require.Equal( + t, tc.paths[0], decodedSingle, + ) + } + }) + } +} + +// TestDecodeBlindedPathsRejects covers every malformed-input branch the +// decoder must refuse: bad discriminators, allocation bombs, and short reads. +// The catch-all is that the decoder never allocates more memory than the +// remaining wire bytes can justify. +func TestDecodeBlindedPathsRejects(t *testing.T) { + t.Parallel() + + // validKey is a 33-byte compressed SEC1 pubkey that the on-curve + // decoder accepts; reused as both intro pubkey and blinding point so + // the tests can exercise post-pubkey decode branches. + validKey := validBlindingPoint(t).SerializeCompressed() + + // hopAllocOverflow declares num_hops=255 with no hop payload. Without + // the remaining-bytes guard the decoder would make([]BlindedHop, 255) + // before io.ReadFull notices the bytes are absent. + hopAllocOverflow := func() []byte { + out := make([]byte, 0, 67) + out = append(out, validKey...) + out = append(out, validKey...) + out = append(out, 0xff) + + return out + } + + // enclenOverflow declares enclen=65535 on a hop with no payload. The + // guard against lr.N must reject before make([]byte, 65535). + enclenOverflow := func() []byte { + out := make([]byte, 0, 70) + out = append(out, validKey...) + out = append(out, validKey...) + out = append(out, 0x01) + out = append(out, validKey...) + out = append(out, 0xff, 0xff) + + return out + } + + // shortIntroPubkey truncates after the discriminator + 5 of 33 bytes + // of intro pubkey, exercising io.ReadFull's short-read error. + shortIntroPubkey := func() []byte { + return append([]byte{0x02}, bytes.Repeat([]byte{0x00}, 5)...) + } + + // shortBlindingPoint truncates after a full intro pubkey plus 5 of the + // 33 blinding-point bytes, exercising io.ReadFull's short-read path + // past the discriminator. + shortBlindingPoint := func() []byte { + out := make([]byte, 0, pubKeyLen+5) + out = append(out, validKey...) + out = append(out, bytes.Repeat([]byte{0x00}, 5)...) + + return out + } + + tests := []struct { + name string + data []byte + wantErr error + wantMsg []string + }{ + { + name: "invalid discriminator 0x04", + data: []byte{0x04}, + wantErr: ErrInvalidIntroNode, + }, + { + name: "invalid discriminator 0x05", + data: []byte{0x05}, + wantErr: ErrInvalidIntroNode, + }, + { + name: "invalid discriminator 0xff", + data: []byte{0xff}, + wantErr: ErrInvalidIntroNode, + }, + { + name: "hop alloc overflow", + data: hopAllocOverflow(), + wantMsg: []string{"num_hops", "exceeds remaining"}, + }, + { + name: "enclen alloc overflow", + data: enclenOverflow(), + wantMsg: []string{"enclen", "exceeds remaining"}, + }, + { + name: "short intro pubkey", + data: shortIntroPubkey(), + wantMsg: []string{"read intro pubkey"}, + }, + { + name: "short blinding point", + data: shortBlindingPoint(), + wantMsg: []string{"read blinding point"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var bp BlindedPaths + err := decodeBlindedPaths( + bytes.NewReader(tc.data), &bp, new([8]byte), + uint64(len(tc.data)), + ) + require.Error(t, err) + + if tc.wantErr != nil { + require.ErrorIs(t, err, tc.wantErr) + } + for _, msg := range tc.wantMsg { + require.Contains(t, err.Error(), msg) + } + }) + } +} + +// TestEncodeBlindedPathsRejects pins the encoder's fail-closed guards. Any +// case here must not emit bytes — invalid input cannot be retracted from the +// wire once flushed. +func TestEncodeBlindedPathsRejects(t *testing.T) { + t.Parallel() + + validIntro, _ := validPubkeyIntro(t) + validHop := BlindedHop{BlindedNodeID: validBlindingPoint(t)} + + tests := []struct { + name string + paths *BlindedPaths + wantErr error + wantMsg []string + wantNoWrite bool + }{ + { + name: "nil intro", + paths: &BlindedPaths{ + Paths: []BlindedPath{{ + BlindingPoint: validBlindingPoint(t), + Hops: []BlindedHop{validHop}, + }}, + }, + wantMsg: []string{"nil intro node"}, + wantNoWrite: true, + }, + { + name: "nil pubkey in PubkeyIntro", + paths: &BlindedPaths{ + Paths: []BlindedPath{{ + IntroductionNode: PubkeyIntro{}, + BlindingPoint: validBlindingPoint(t), + Hops: []BlindedHop{ + validHop, + }, + }}, + }, + wantErr: ErrInvalidIntroNode, + wantNoWrite: true, + }, + { + name: "invalid sciddir direction 0x02", + paths: &BlindedPaths{ + Paths: []BlindedPath{{ + IntroductionNode: SciddirIntro{ + Direction: 0x02, + }, + BlindingPoint: validBlindingPoint(t), + Hops: []BlindedHop{validHop}, + }}, + }, + wantErr: ErrInvalidIntroNode, + wantNoWrite: true, + }, + { + name: "invalid sciddir direction 0xff", + paths: &BlindedPaths{ + Paths: []BlindedPath{{ + IntroductionNode: SciddirIntro{ + Direction: 0xff, + }, + BlindingPoint: validBlindingPoint(t), + Hops: []BlindedHop{validHop}, + }}, + }, + wantErr: ErrInvalidIntroNode, + wantNoWrite: true, + }, + { + name: "nil blinding point", + paths: &BlindedPaths{ + Paths: []BlindedPath{{ + IntroductionNode: validIntro, + Hops: []BlindedHop{ + validHop, + }, + }}, + }, + wantMsg: []string{"nil blinding point"}, + wantNoWrite: true, + }, + { + name: "zero hops", + paths: &BlindedPaths{ + Paths: []BlindedPath{{ + IntroductionNode: validIntro, + BlindingPoint: validBlindingPoint(t), + Hops: nil, + }}, + }, + wantErr: ErrEmptyBlindedPath, + wantNoWrite: true, + }, + { + name: "hop overflow", + paths: &BlindedPaths{ + Paths: []BlindedPath{{ + IntroductionNode: validIntro, + BlindingPoint: validBlindingPoint(t), + Hops: func() []BlindedHop { + hops := make([]BlindedHop, + maxBlindedPathHops+1) + pub := validBlindingPoint(t) + for i := range hops { + // Write to hop. + h := &hops[i] + h.BlindedNodeID = pub + } + + return hops + }(), + }}, + }, + wantMsg: []string{"exceeds limit"}, + wantNoWrite: true, + }, + { + name: "oversize encrypted data", + paths: oversizeEncDataPaths(t, validIntro), + wantMsg: []string{"exceeds limit"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + err := encodeBlindedPaths( + &buf, tc.paths, new([8]byte), + ) + require.Error(t, err) + + if tc.wantErr != nil { + require.ErrorIs(t, err, tc.wantErr) + } + for _, msg := range tc.wantMsg { + require.Contains(t, err.Error(), msg) + } + if tc.wantNoWrite { + require.Equal(t, 0, buf.Len(), + "encoder wrote bytes on fail-closed "+ + "path") + } + }) + } +} diff --git a/lnwire/bounds.go b/lnwire/bounds.go new file mode 100644 index 00000000000..8cb79bb4f98 --- /dev/null +++ b/lnwire/bounds.go @@ -0,0 +1,35 @@ +package lnwire + +import ( + "math" + + "github.com/btcsuite/btcd/btcec/v2" +) + +// BOLT 4 blinded-path field bounds. Each constant matches the format ceiling +// imposed by the spec encoding (uint8 num_hops, uint16 enclen). +const ( + // pubKeyLen aliases the upstream compressed-pubkey length for shorter + // usage in this package. + pubKeyLen = btcec.PubKeyBytesLenCompressed + + // sciddirLen is the on-wire length of a sciddir introduction node + // (1-byte direction + 8-byte SCID). + sciddirLen = 9 + + // scidLen is the byte length of a short channel ID. + scidLen = 8 + + // maxBlindedPathHops bounds the number of hops a single blinded path + // may declare. The spec encodes num_hops as a uint8, so 255 is the + // format's absolute ceiling. + maxBlindedPathHops = math.MaxUint8 + + // maxEncryptedDataLen bounds the encrypted-data field in a single + // blinded hop. The spec encodes the length as a uint16. + maxEncryptedDataLen = math.MaxUint16 + + // minBlindedHopBytes is the on-wire footprint of the smallest possible + // blinded hop: BlindedNodeID(33) + enclen(2) + 0 enc_data. + minBlindedHopBytes = pubKeyLen + 2 +) diff --git a/lnwire/custom_records.go b/lnwire/custom_records.go index de5ff4a2302..90f99d264e6 100644 --- a/lnwire/custom_records.go +++ b/lnwire/custom_records.go @@ -263,6 +263,32 @@ func DecodeRecordsP2P(r *bytes.Reader, return tlvStream.DecodeWithParsedTypesP2P(r) } +// AddOpt appends a record producer for the given optional record to producers +// when the optional is set, leaving producers unchanged otherwise. +func AddOpt[T tlv.TlvType, V any](producers *[]tlv.RecordProducer, + opt tlv.OptionalRecordT[T, V]) { + + opt.WhenSome( + func(r tlv.RecordT[T, V]) { + *producers = append(*producers, &r) + }, + ) +} + +// SetOptFromMap marks target as Some(record) when record's TLV type appeared +// on the wire (i.e., is a key in the decoded TypeMap). +// +// The caller must have passed record to the underlying Stream before decoding; +// otherwise record.Val will not have been populated, and wrapping it as Some +// would yield a zero-valued field. +func SetOptFromMap[T tlv.TlvType, V any](typeMap tlv.TypeMap, + target *tlv.OptionalRecordT[T, V], record tlv.RecordT[T, V]) { + + if _, ok := typeMap[record.TlvType()]; ok { + *target = tlv.SomeRecordT(record) + } +} + // AssertUniqueTypes asserts that the given records have unique types. func AssertUniqueTypes(r []tlv.Record) error { seen := make(fn.Set[tlv.Type], len(r)) diff --git a/lnwire/custom_records_test.go b/lnwire/custom_records_test.go index d4aad2e5462..d14586b8e17 100644 --- a/lnwire/custom_records_test.go +++ b/lnwire/custom_records_test.go @@ -249,3 +249,46 @@ func TestCustomRecordsMergedCopy(t *testing.T) { }) } } + +// TestAddOptAppendsOnlyWhenSet checks that AddOpt is a no-op for an empty +// optional and appends a producer when the optional is populated. +func TestAddOptAppendsOnlyWhenSet(t *testing.T) { + t.Parallel() + + var producers []tlv.RecordProducer + + emptyOpt := tlv.OptionalRecordT[tlv.TlvType1, uint16]{} + AddOpt(&producers, emptyOpt) + require.Empty(t, producers) + + setOpt := tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType1, uint16](42), + ) + AddOpt(&producers, setOpt) + require.Len(t, producers, 1) + + rec := producers[0].Record() + require.Equal(t, tlv.Type(1), rec.Type()) +} + +// TestSetOptFromMapUsesTypeMapPresence verifies that SetOptFromMap populates +// only when the TLV type is present in the TypeMap. +func TestSetOptFromMapUsesTypeMapPresence(t *testing.T) { + t.Parallel() + + present := tlv.TypeMap{tlv.Type(1): nil} + missing := tlv.TypeMap{} + + var target tlv.OptionalRecordT[tlv.TlvType1, uint16] + SetOptFromMap( + missing, &target, + tlv.NewPrimitiveRecord[tlv.TlvType1, uint16](7), + ) + require.True(t, target.IsNone()) + + SetOptFromMap( + present, &target, + tlv.NewPrimitiveRecord[tlv.TlvType1, uint16](7), + ) + require.True(t, target.IsSome()) +} diff --git a/lnwire/intro_node.go b/lnwire/intro_node.go new file mode 100644 index 00000000000..9e4b5a8017c --- /dev/null +++ b/lnwire/intro_node.go @@ -0,0 +1,145 @@ +package lnwire + +import ( + "bytes" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec/v2" +) + +// IntroductionNode is the sealed sum-type for a blinded path's introduction +// node. {0x02, 0x03} → PubkeyIntro; {0x00, 0x01} → SciddirIntro. The unexported +// method seals the variant set so foreign packages cannot satisfy the interface +// with an unrecognised wire form. +type IntroductionNode interface { + isIntroductionNode() + + encodedLen() uint64 + + encode(w io.Writer) error + + // validate checks that the discriminator byte is valid for the variant. + validate() error + + // Bytes returns the wire-format encoding of the introduction node for + // callers that need it outside an io.Writer (RPC surfaces). + Bytes() []byte +} + +// PubkeyIntro is the 33-byte compressed-pubkey variant. The SEC1 parity byte +// (0x02 or 0x03) doubles as the wire discriminator. +type PubkeyIntro struct { + Pubkey *btcec.PublicKey +} + +// SciddirIntro is the 9-byte sciddir variant. Direction is the wire +// discriminator; SCID is the 8-byte short channel ID. +type SciddirIntro struct { + Direction byte + SCID [scidLen]byte +} + +var ( + _ IntroductionNode = PubkeyIntro{} + _ IntroductionNode = SciddirIntro{} +) + +// decodeIntroductionNode reads the discriminator byte and dispatches to the +// matching variant. +func decodeIntroductionNode(r io.Reader, + buf *[8]byte) (IntroductionNode, error) { + + if _, err := io.ReadFull(r, buf[:1]); err != nil { + return nil, fmt.Errorf("read intro node type: %w", err) + } + + disc := buf[0] + switch disc { + case 0x00, 0x01: + s := SciddirIntro{Direction: disc} + if _, err := io.ReadFull(r, s.SCID[:]); err != nil { + return nil, fmt.Errorf("read sciddir: %w", err) + } + + return s, nil + + case 0x02, 0x03: + var b [pubKeyLen]byte + b[0] = disc + if _, err := io.ReadFull(r, b[1:]); err != nil { + return nil, fmt.Errorf("read intro pubkey: %w", err) + } + pub, err := btcec.ParsePubKey(b[:]) + if err != nil { + return nil, fmt.Errorf("%w: %w", + ErrInvalidIntroNode, err) + } + + return PubkeyIntro{Pubkey: pub}, nil + + default: + return nil, fmt.Errorf("%w: 0x%02x", ErrInvalidIntroNode, disc) + } +} + +func (PubkeyIntro) isIntroductionNode() {} + +func (p PubkeyIntro) encodedLen() uint64 { return pubKeyLen } + +func (p PubkeyIntro) encode(w io.Writer) error { + if p.Pubkey == nil { + return fmt.Errorf("nil intro pubkey") + } + _, err := w.Write(p.Pubkey.SerializeCompressed()) + + return err +} + +func (p PubkeyIntro) validate() error { + if p.Pubkey == nil { + return fmt.Errorf("%w: nil pubkey", ErrInvalidIntroNode) + } + + return nil +} + +// Bytes returns the wire-format encoding of the pubkey variant. +func (p PubkeyIntro) Bytes() []byte { + var buf bytes.Buffer + buf.Grow(pubKeyLen) + _ = p.encode(&buf) + + return buf.Bytes() +} + +func (SciddirIntro) isIntroductionNode() {} + +func (s SciddirIntro) encodedLen() uint64 { return sciddirLen } + +func (s SciddirIntro) encode(w io.Writer) error { + if _, err := w.Write([]byte{s.Direction}); err != nil { + return err + } + _, err := w.Write(s.SCID[:]) + + return err +} + +func (s SciddirIntro) validate() error { + switch s.Direction { + case 0x00, 0x01: + return nil + } + + return fmt.Errorf("%w: 0x%02x", ErrInvalidIntroNode, s.Direction) +} + +// Bytes returns the wire-format encoding of the sciddir variant. +func (s SciddirIntro) Bytes() []byte { + var buf bytes.Buffer + buf.Grow(sciddirLen) + _ = s.encode(&buf) + + return buf.Bytes() +} diff --git a/lnwire/onion_msg_payload.go b/lnwire/onion_msg_payload.go index 64c5abadec3..7aeeece5957 100644 --- a/lnwire/onion_msg_payload.go +++ b/lnwire/onion_msg_payload.go @@ -7,7 +7,6 @@ import ( "io" "sort" - sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/tlv" ) @@ -36,22 +35,15 @@ const ( InvoiceErrorNamespaceType tlv.Type = 68 ) -var ( - // ErrNotFinalPayload is returned when a final hop payload is not - // within the correct range. - ErrNotFinalPayload = errors.New("final hop payloads type should be " + - ">= 64") - - // ErrNoHops is returned when we handle a reply path that does not - // have any hops (this makes no sense). - ErrNoHops = errors.New("reply path requires hops") -) +// ErrNotFinalPayload is returned when a final hop payload is not within the +// correct range. +var ErrNotFinalPayload = errors.New("final hop payloads type should be >= 64") // OnionMessagePayload contains the contents of an onion message payload. type OnionMessagePayload struct { // ReplyPath contains a blinded path that can be used to respond to an // onion message. - ReplyPath *sphinx.BlindedPath + ReplyPath *BlindedPath // EncryptedData contains encrypted data for the recipient. EncryptedData []byte @@ -73,7 +65,7 @@ func (o *OnionMessagePayload) Encode() ([]byte, error) { var records []tlv.Record if o.ReplyPath != nil { - records = append(records, replyPathRecord(o.ReplyPath)) + records = append(records, o.ReplyPath.Record()) } if len(o.EncryptedData) != 0 { @@ -132,10 +124,10 @@ func (o *OnionMessagePayload) Decode(r io.Reader) (map[tlv.Type][]byte, error) { } ) // Create a non-nil entry so that we can directly decode into it. - o.ReplyPath = &sphinx.BlindedPath{} + o.ReplyPath = &BlindedPath{} records := []tlv.Record{ - replyPathRecord(o.ReplyPath), + o.ReplyPath.Record(), tlv.MakePrimitiveRecord( encryptedDataTLVType, &o.EncryptedData, ), @@ -258,155 +250,3 @@ func (f *FinalHopTLV) Validate() error { return nil } - -// replyPathRecord produces a tlv record for a reply path. -func replyPathRecord(r *sphinx.BlindedPath) tlv.Record { - return tlv.MakeDynamicRecord( - replyPathType, r, replyPathSize(r), encodeReplyPath, - decodeReplyPath, - ) -} - -// replyPathSize returns the encoded size of a reply path. -func replyPathSize(r *sphinx.BlindedPath) func() uint64 { - return func() uint64 { - // First node pubkey 33 + blinding point pubkey 33 + 1 byte for - // uint8 for our hop count. - size := uint64(33 + 33 + 1) - - // Add each hop's size to our total. - for _, hop := range r.BlindedHops { - size += blindedHopSize(hop) - } - - return size - } -} - -// encodeReplyPath encodes a reply path tlv. -func encodeReplyPath(w io.Writer, val interface{}, buf *[8]byte) error { - if p, ok := val.(*sphinx.BlindedPath); ok { - err := tlv.EPubKey(w, &p.IntroductionPoint, buf) - if err != nil { - return fmt.Errorf("encode first node id: %w", err) - } - - if err := tlv.EPubKey(w, &p.BlindingPoint, buf); err != nil { - return fmt.Errorf("encode blinding point: %w", err) - } - - hopCount := uint8(len(p.BlindedHops)) - if hopCount == 0 { - return ErrNoHops - } - - if err := tlv.EUint8(w, &hopCount, buf); err != nil { - return fmt.Errorf("encode hop count: %w", err) - } - - for i, hop := range p.BlindedHops { - if err := encodeBlindedHop(w, hop, buf); err != nil { - return fmt.Errorf("hop %v: %w", i, err) - } - } - - return nil - } - - return tlv.NewTypeForEncodingErr(val, "*sphinx.BlindedPath") -} - -// decodeReplyPath decodes a reply path tlv. -func decodeReplyPath(r io.Reader, val interface{}, buf *[8]byte, - l uint64) error { - - // If we have the correct type, and the length exceeds the fixed header - // size (first node pubkey (33) + blinding point (33) + hop count (1) = - // 67 bytes) to accommodate at least one hop, decode the reply path. - if p, ok := val.(*sphinx.BlindedPath); ok && l > 67 { - err := tlv.DPubKey(r, &p.IntroductionPoint, buf, 33) - if err != nil { - return fmt.Errorf("decode first id: %w", err) - } - - err = tlv.DPubKey(r, &p.BlindingPoint, buf, 33) - if err != nil { - return fmt.Errorf("decode blinding point: %w", err) - } - - var hopCount uint8 - if err := tlv.DUint8(r, &hopCount, buf, 1); err != nil { - return fmt.Errorf("decode hop count: %w", err) - } - - if hopCount == 0 { - return ErrNoHops - } - - for i := 0; i < int(hopCount); i++ { - hop := &sphinx.BlindedHopInfo{} - if err := decodeBlindedHop(r, hop, buf); err != nil { - return fmt.Errorf("decode hop: %w", err) - } - - p.BlindedHops = append(p.BlindedHops, hop) - } - - return nil - } - - return tlv.NewTypeForDecodingErr(val, "*sphinx.BlindedPath", l, l) -} - -// blindedHopSize returns the encoded size of a blinded hop. -func blindedHopSize(b *sphinx.BlindedHopInfo) uint64 { - // 33 byte pubkey + 2 bytes uint16 length + var bytes. - return uint64(33 + 2 + len(b.CipherText)) -} - -// encodeBlindedHop encodes a blinded hop tlv. -func encodeBlindedHop(w io.Writer, val interface{}, buf *[8]byte) error { - if b, ok := val.(*sphinx.BlindedHopInfo); ok { - if err := tlv.EPubKey(w, &b.BlindedNodePub, buf); err != nil { - return fmt.Errorf("encode blinded id: %w", err) - } - - dataLen := uint16(len(b.CipherText)) - if err := tlv.EUint16(w, &dataLen, buf); err != nil { - return fmt.Errorf("data len: %w", err) - } - - if err := tlv.EVarBytes(w, &b.CipherText, buf); err != nil { - return fmt.Errorf("encode encrypted data: %w", err) - } - - return nil - } - - return tlv.NewTypeForEncodingErr(val, "*sphinx.BlindedHopInfo") -} - -// decodeBlindedHop decodes a blinded hop tlv. -func decodeBlindedHop(r io.Reader, val interface{}, buf *[8]byte) error { - if b, ok := val.(*sphinx.BlindedHopInfo); ok { - err := tlv.DPubKey(r, &b.BlindedNodePub, buf, 33) - if err != nil { - return fmt.Errorf("decode blinded id: %w", err) - } - - var dataLen uint16 - err = tlv.DUint16(r, &dataLen, buf, 2) - if err != nil { - return fmt.Errorf("decode data len: %w", err) - } - - err = tlv.DVarBytes(r, &b.CipherText, buf, uint64(dataLen)) - if err != nil { - return fmt.Errorf("decode data: %w", err) - } - - return nil - } - - return tlv.NewTypeForDecodingErr(val, "*sphinx.BlindedHopInfo", 0, 0) -} diff --git a/lnwire/onion_msg_payload_test.go b/lnwire/onion_msg_payload_test.go index 3a85ec60861..6871f9926bb 100644 --- a/lnwire/onion_msg_payload_test.go +++ b/lnwire/onion_msg_payload_test.go @@ -5,15 +5,14 @@ import ( "fmt" "testing" - sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" "pgregory.net/rapid" ) // makeBlindedPath creates a BlindedPath with the given number of hops for -// testing. Each hop has a random blinded node pub and some cipher text. -func makeBlindedPath(t *testing.T, numHops int) *sphinx.BlindedPath { +// testing. Each hop has a random blinded node ID and some cipher text. +func makeBlindedPath(t *testing.T, numHops int) *BlindedPath { t.Helper() introKey, err := randPubKey() @@ -22,55 +21,48 @@ func makeBlindedPath(t *testing.T, numHops int) *sphinx.BlindedPath { blindingKey, err := randPubKey() require.NoError(t, err) - hops := make([]*sphinx.BlindedHopInfo, numHops) + hops := make([]BlindedHop, numHops) for i := range hops { nodePub, err := randPubKey() require.NoError(t, err) - hops[i] = &sphinx.BlindedHopInfo{ - BlindedNodePub: nodePub, - CipherText: bytes.Repeat([]byte{byte(i + 1)}, 32), - } + hops[i].BlindedNodeID = nodePub + hops[i].EncryptedData = bytes.Repeat([]byte{byte(i + 1)}, 32) } - return &sphinx.BlindedPath{ - IntroductionPoint: introKey, - BlindingPoint: blindingKey, - BlindedHops: hops, + return &BlindedPath{ + IntroductionNode: PubkeyIntro{Pubkey: introKey}, + BlindingPoint: blindingKey, + Hops: hops, } } -// assertBlindedPathEqual compares two BlindedPaths for equality, checking each -// field. -func assertBlindedPathEqual(t *testing.T, expected, - actual *sphinx.BlindedPath) { - +// assertBlindedPathEqual compares two BlindedPaths field-by-field. Direct +// require.Equal would also work, but the per-field assertions surface +// localised mismatches for easier triage. +func assertBlindedPathEqual(t *testing.T, expected, actual *BlindedPath) { t.Helper() - require.True( - t, - expected.IntroductionPoint.IsEqual(actual.IntroductionPoint), - "IntroductionPoint mismatch", + require.Equal( + t, expected.IntroductionNode, actual.IntroductionNode, + "IntroductionNode mismatch", ) - require.True( - t, expected.BlindingPoint.IsEqual(actual.BlindingPoint), + require.Equal( + t, expected.BlindingPoint, actual.BlindingPoint, "BlindingPoint mismatch", ) - require.Len(t, actual.BlindedHops, len(expected.BlindedHops)) - - for i, expectedHop := range expected.BlindedHops { - actualHop := actual.BlindedHops[i] + require.Len(t, actual.Hops, len(expected.Hops)) - require.True( - t, - expectedHop.BlindedNodePub.IsEqual( - actualHop.BlindedNodePub, - ), - "hop %d: BlindedNodePub mismatch", i, + for i := range expected.Hops { + require.Equal( + t, expected.Hops[i].BlindedNodeID, + actual.Hops[i].BlindedNodeID, + "hop %d: BlindedNodeID mismatch", i, ) require.Equal( - t, expectedHop.CipherText, actualHop.CipherText, - "hop %d: CipherText mismatch", i, + t, expected.Hops[i].EncryptedData, + actual.Hops[i].EncryptedData, + "hop %d: EncryptedData mismatch", i, ) } } @@ -112,6 +104,29 @@ func TestOnionMessagePayloadRoundTrip(t *testing.T) { require.Empty(t, decoded.FinalHopTLVs) }) + t.Run("sciddir intro reply path", func(t *testing.T) { + t.Parallel() + + path := makeBlindedPath(t, 2) + path.IntroductionNode = SciddirIntro{ + Direction: 0x01, + SCID: [scidLen]byte{ + 0x00, 0x11, 0x22, 0x33, + 0x44, 0x55, 0x66, 0x77, + }, + } + + original := &OnionMessagePayload{ReplyPath: path} + + decoded := encodeAndDecode(t, original) + + require.NotNil(t, decoded.ReplyPath) + require.IsType( + t, SciddirIntro{}, decoded.ReplyPath.IntroductionNode, + ) + assertBlindedPathEqual(t, original.ReplyPath, decoded.ReplyPath) + }) + t.Run("only encrypted data", func(t *testing.T) { t.Parallel() @@ -351,15 +366,15 @@ func TestOnionMessagePayloadEncodeReplyPathNoHops(t *testing.T) { require.NoError(t, err) payload := &OnionMessagePayload{ - ReplyPath: &sphinx.BlindedPath{ - IntroductionPoint: introKey, - BlindingPoint: blindingKey, - BlindedHops: nil, + ReplyPath: &BlindedPath{ + IntroductionNode: PubkeyIntro{Pubkey: introKey}, + BlindingPoint: blindingKey, + Hops: nil, }, } _, err = payload.Encode() - require.ErrorIs(t, err, ErrNoHops) + require.ErrorIs(t, err, ErrEmptyBlindedPath) } // TestOnionMessagePayloadEmpty tests that an empty payload roundtrips @@ -442,35 +457,9 @@ func TestOnionMessagePayloadRoundTripQuickCheck(t *testing.T) { require.Nil(t, decoded.ReplyPath) } else { require.NotNil(t, decoded.ReplyPath) - require.True( - t, - original.ReplyPath.IntroductionPoint.IsEqual( - decoded.ReplyPath.IntroductionPoint, - ), - ) - require.True( - t, - original.ReplyPath.BlindingPoint.IsEqual( - decoded.ReplyPath.BlindingPoint, - ), - ) - require.Len( - t, decoded.ReplyPath.BlindedHops, - len(original.ReplyPath.BlindedHops), + require.Equal( + t, original.ReplyPath, decoded.ReplyPath, ) - for i, hop := range original.ReplyPath.BlindedHops { - dHop := decoded.ReplyPath.BlindedHops[i] - require.True( - t, - hop.BlindedNodePub.IsEqual( - dHop.BlindedNodePub, - ), - ) - require.Equal( - t, hop.CipherText, - dHop.CipherText, - ) - } } // Verify encrypted data. diff --git a/lnwire/pure_tlv.go b/lnwire/pure_tlv.go index 8e6f7bd9fc3..6692ac2f53c 100644 --- a/lnwire/pure_tlv.go +++ b/lnwire/pure_tlv.go @@ -23,12 +23,12 @@ const ( ) // PureTLVMessage describes an LN message that is a pure TLV stream. If the -// message includes a signature, it will sign all the TLV records in the -// inclusive ranges: 0 to 159 and 1000000000 to 2999999999. +// message includes a signature, the signature covers a subset of the records, +// which subset is determined by the protocol's signed/unsigned range (see +// SerialiseFieldsToSignFn). type PureTLVMessage interface { - // AllRecords returns all the TLV records for the message. This will - // include all the records we know about along with any that we don't - // know about but that fall in the signed TLV range. + // AllRecords returns all the TLV records for the message, including + // both records we know about and unknown records that we preserve. AllRecords() []tlv.Record } @@ -37,13 +37,27 @@ func EncodePureTLVMessage(msg PureTLVMessage, buf *bytes.Buffer) error { return EncodeRecordsTo(buf, msg.AllRecords()) } +// UnsignedRangeFunc returns true when a TLV type is in the unsigned range of a +// pure-TLV message (i.e., excluded from the signature). Each protocol supplies +// its own predicate to encode the boundary between signed and unsigned types. +type UnsignedRangeFunc func(tlv.Type) bool + // SerialiseFieldsToSign serialises all the records from the given -// PureTLVMessage that fall within the signed TLV range. +// PureTLVMessage that fall within the BOLT 7 v2 signed TLV range. Use +// SerialiseFieldsToSignFn for a protocol with a different boundary. func SerialiseFieldsToSign(msg PureTLVMessage) ([]byte, error) { - // Filter out all the fields not in the signed ranges. + return SerialiseFieldsToSignFn(msg, InUnsignedRange) +} + +// SerialiseFieldsToSignFn serialises all the records from the given +// PureTLVMessage that the supplied predicate keeps in the signed range. A type +// for which isUnsigned returns true is excluded from the digest. +func SerialiseFieldsToSignFn(msg PureTLVMessage, + isUnsigned UnsignedRangeFunc) ([]byte, error) { + var signedRecords []tlv.Record for _, record := range msg.AllRecords() { - if InUnsignedRange(record.Type()) { + if isUnsigned(record.Type()) { continue } @@ -58,8 +72,9 @@ func SerialiseFieldsToSign(msg PureTLVMessage) ([]byte, error) { return buf.Bytes(), nil } -// InUnsignedRange returns true if the given TLV type falls outside the TLV -// ranges that the signature of a pure TLV message will cover. +// InUnsignedRange is the BOLT 7 v2 UnsignedRangeFunc: it returns true for types +// in 160-999_999_999 or 3_000_000_000+, which sit outside the BOLT 7 v2 signed +// ranges (0-159 and 1_000_000_000-2_999_999_999). func InUnsignedRange(t tlv.Type) bool { return (t >= pureTLVUnsignedRangeOneStart && t < pureTLVSignedSecondRangeStart) || @@ -72,32 +87,41 @@ func InUnsignedRange(t tlv.Type) bool { // for re-composing the wire message since the signature covers these fields. type ExtraSignedFields map[uint64][]byte -// ExtraSignedFieldsFromTypeMap is a helper that can be used alongside calls to -// the tlv.Stream DecodeWithParsedTypesP2P or DecodeWithParsedTypes methods to -// extract the tlv type and value pairs in the defined PureTLVMessage signed -// range which we have not handled with any of our defined Records. These -// methods will return a tlv.TypeMap containing the records that were extracted -// from an io.Reader. If the record was know and handled by a defined record, -// then the value accompanying the record's type in the map will be nil. -// Otherwise, if the record was unhandled, it will be non-nil. +// ExtraSignedFieldsFromTypeMap returns the unhandled signed-range entries from +// a tlv.TypeMap (as returned by DecodeWithParsedTypes(P2P)) so the caller can +// re-emit them and keep the message signature valid. It uses the BOLT 7 v2 +// signed range; use ExtraSignedFieldsFromTypeMapFn for a different boundary. func ExtraSignedFieldsFromTypeMap(m tlv.TypeMap) ExtraSignedFields { + return ExtraSignedFieldsFromTypeMapFn(m, InUnsignedRange) +} + +// ExtraSignedFieldsFromTypeMapFn returns the unhandled entries from a +// tlv.TypeMap that the supplied predicate keeps in the signed range, so the +// caller can re-emit them and keep the message signature valid. Entries for +// which isUnsigned returns true are dropped. +func ExtraSignedFieldsFromTypeMapFn(m tlv.TypeMap, + isUnsigned UnsignedRangeFunc) ExtraSignedFields { + extraFields := make(ExtraSignedFields) for t, v := range m { - // If the value in the type map is nil, then it indicates that - // we know this type, and it was handled by one of the records - // we passed to the decode function vai the TLV stream. + // A nil value signals that this type was consumed by one of the + // typed records passed to the TLV stream decoder, so its bytes + // are already represented elsewhere and do not need to be + // tracked here. if v == nil { continue } - // No need to keep this field if it is unknown to us and is not - // in the sign range. - if InUnsignedRange(t) { + // Types the predicate places outside the signed range fall + // outside the signature's coverage, so they do not need to + // survive into re-encoding. + if isUnsigned(t) { continue } - // Otherwise, this is an un-handled type, so we keep track of - // it for signature validation and re-encoding later on. + // The remaining types are unhandled but within the signed + // range; preserve their raw bytes so the message can re-emit + // them verbatim and the signature stays valid. extraFields[uint64(t)] = v } diff --git a/lnwire/pure_tlv_test.go b/lnwire/pure_tlv_test.go index a81a89ecb6d..9148678d2a0 100644 --- a/lnwire/pure_tlv_test.go +++ b/lnwire/pure_tlv_test.go @@ -387,3 +387,89 @@ func (g *MsgV2) AllRecords() []tlv.Record { return ProduceRecordsSorted(recordProducers...) } + +// mockPureTLVMessage is a minimal PureTLVMessage backed by a fixed record +// slice, used to exercise the predicate-driven helpers. +type mockPureTLVMessage struct { + records []tlv.Record +} + +func (m *mockPureTLVMessage) AllRecords() []tlv.Record { + return m.records +} + +// TestSerialiseFieldsToSignFn verifies that the serialiser correctly filters +// records based on the provided predicate before encoding. +func TestSerialiseFieldsToSignFn(t *testing.T) { + t.Parallel() + + var ( + signedVal uint16 = 11 + unsignedVal uint16 = 22 + ) + + msg := &mockPureTLVMessage{ + records: []tlv.Record{ + tlv.MakePrimitiveRecord(5, &signedVal), + tlv.MakePrimitiveRecord(10, &unsignedVal), + }, + } + + // Predicate that defines type 10 as unsigned (excluded). + isUnsigned := func(typ tlv.Type) bool { + return typ == 10 + } + + encoded, err := SerialiseFieldsToSignFn(msg, isUnsigned) + require.NoError(t, err) + + // Only type 5 should be encoded (type 5, length 2, value 11). + require.Equal(t, []byte{0x05, 0x02, 0x00, 0x0b}, encoded) +} + +// TestExtraSignedFieldsFromTypeMapFn confirms the predicate-driven variant +// keeps and drops the right type ranges for callers whose signed range is not +// the BOLT 7 v2 default. It also locks in the round-trip identity with the +// convenience wrapper. +func TestExtraSignedFieldsFromTypeMapFn(t *testing.T) { + t.Parallel() + + // Bolt12 signature TLVs sit at 240-1000 and are excluded from the + // signed Merkle tree. Everything else is signed. + bolt12Unsigned := func(typ tlv.Type) bool { + return typ >= 240 && typ <= 1000 + } + + typeMap := tlv.TypeMap{ + // Handled by a typed record on the receiver. + tlv.Type(2): nil, + + // Unknown type in the bolt12 signed range — must survive. + tlv.Type(99): { + 0x01, + }, + + // Bolt12 signature TLV — must be dropped. + tlv.Type(240): { + 0x02, + }, + + // Bolt12 second-range type — signed for bolt12, signed for the + // BOLT 7 v2 default too. + tlv.Type(1_500_000_000): { + 0x03, + }, + } + + gotBolt12 := ExtraSignedFieldsFromTypeMapFn(typeMap, bolt12Unsigned) + require.Len(t, gotBolt12, 2) + require.Equal(t, []byte{0x01}, gotBolt12[99]) + require.Equal(t, []byte{0x03}, gotBolt12[1_500_000_000]) + + gotDefault := ExtraSignedFieldsFromTypeMap(typeMap) + // In the BOLT 7 v2 range, type 99 is signed but type 240 is unsigned. + require.Len(t, gotDefault, 2) + require.Equal(t, []byte{0x01}, gotDefault[99]) + require.Equal(t, []byte{0x03}, gotDefault[1_500_000_000]) + require.NotContains(t, gotDefault, uint64(240)) +} diff --git a/lnwire/test_utils.go b/lnwire/test_utils.go index 602724a9dd6..ae70aeeb543 100644 --- a/lnwire/test_utils.go +++ b/lnwire/test_utils.go @@ -9,7 +9,6 @@ import ( "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/fn/v2" "github.com/stretchr/testify/require" "pgregory.net/rapid" @@ -59,30 +58,48 @@ func RandPubKey(t *rapid.T) *btcec.PublicKey { return pub } -// RandBlindedPath generates a random blinded path with 1-5 hops. -func RandBlindedPath(t *rapid.T) *sphinx.BlindedPath { - introKey := RandPubKey(t) - blindingKey := RandPubKey(t) +// RandBlindedPath generates a random blinded path with 1-5 hops, alternating +// between the pubkey and sciddir introduction-node variants per draw. +func RandBlindedPath(t *rapid.T) *BlindedPath { + useSciddir := rapid.Bool().Draw(t, "introIsSciddir") + + var intro IntroductionNode + if useSciddir { + var scid [scidLen]byte + copy(scid[:], rapid.SliceOfN( + rapid.Byte(), scidLen, scidLen, + ).Draw(t, "introScid")) + + intro = SciddirIntro{ + Direction: byte( + rapid.IntRange(0, 1).Draw(t, "introDir"), + ), + SCID: scid, + } + } else { + intro = PubkeyIntro{Pubkey: RandPubKey(t)} + } + + blindingPoint := RandPubKey(t) numHops := rapid.IntRange(1, 5).Draw(t, "numBlindedHops") - hops := make([]*sphinx.BlindedHopInfo, numHops) + hops := make([]BlindedHop, numHops) for i := range hops { cipherLen := rapid.IntRange(1, 128).Draw( t, fmt.Sprintf("cipherLen-%d", i), ) - hops[i] = &sphinx.BlindedHopInfo{ - BlindedNodePub: RandPubKey(t), - CipherText: rapid.SliceOfN( - rapid.Byte(), cipherLen, cipherLen, - ).Draw(t, fmt.Sprintf("cipherText-%d", i)), - } + hops[i].BlindedNodeID = RandPubKey(t) + + hops[i].EncryptedData = rapid.SliceOfN( + rapid.Byte(), cipherLen, cipherLen, + ).Draw(t, fmt.Sprintf("cipherText-%d", i)) } - return &sphinx.BlindedPath{ - IntroductionPoint: introKey, - BlindingPoint: blindingKey, - BlindedHops: hops, + return &BlindedPath{ + IntroductionNode: intro, + BlindingPoint: blindingPoint, + Hops: hops, } } diff --git a/onionmessage/onion_endpoint.go b/onionmessage/onion_endpoint.go index f6a32d2ebd9..0c9829e2463 100644 --- a/onionmessage/onion_endpoint.go +++ b/onionmessage/onion_endpoint.go @@ -1,7 +1,7 @@ package onionmessage import ( - sphinx "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" ) @@ -26,7 +26,7 @@ type OnionMessageUpdate struct { CustomRecords record.CustomSet // ReplyPath contains the reply path information for the onion message. - ReplyPath *sphinx.BlindedPath + ReplyPath *lnwire.BlindedPath // EncryptedRecipientData contains the encrypted recipient data for the // onion message, created by the creator of the blinded route. This is diff --git a/routing/route/blindedroute.go b/routing/route/blindedroute.go index 2b8120ad6e2..35ad07031a5 100644 --- a/routing/route/blindedroute.go +++ b/routing/route/blindedroute.go @@ -13,7 +13,7 @@ import ( // payloads used to encoding the routing data for each hop in the route. This // method also accepts final hop payloads. func OnionMessageBlindedPathToSphinxPath(blindedPath *sphinx.BlindedPath, - replyPath *sphinx.BlindedPath, finalHopTLVs []*lnwire.FinalHopTLV) ( + replyPath *lnwire.BlindedPath, finalHopTLVs []*lnwire.FinalHopTLV) ( *sphinx.PaymentPath, error) { var path sphinx.PaymentPath diff --git a/rpcserver.go b/rpcserver.go index d4a9ce51bb5..00d6263a04b 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -9570,15 +9570,16 @@ func (r *rpcServer) SubscribeOnionMessages( //nolint:ll if oMsg.ReplyPath != nil { - bp.IntroductionNode = oMsg.ReplyPath.IntroductionPoint.SerializeCompressed() + bp.IntroductionNode = oMsg.ReplyPath.IntroductionNode.Bytes() bp.BlindingPoint = oMsg.ReplyPath.BlindingPoint.SerializeCompressed() - for _, hop := range oMsg.ReplyPath.BlindedHops { - rpcHop := &lnrpc.BlindedHop{ - BlindedNode: hop.BlindedNodePub.SerializeCompressed(), - EncryptedData: hop.CipherText, - } - bp.BlindedHops = append(bp.BlindedHops, rpcHop) + for _, hop := range oMsg.ReplyPath.Hops { + bp.BlindedHops = append( + bp.BlindedHops, &lnrpc.BlindedHop{ + BlindedNode: hop.BlindedNodeID.SerializeCompressed(), + EncryptedData: hop.EncryptedData, + }, + ) } }