Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions payments/db/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ type PaymentControl interface {
// DBMPPayment is an interface that represents the payment state during a
// payment lifecycle.
type DBMPPayment interface {
// GetSequenceNum returns the payment's unique sequence number.
GetSequenceNum() uint64

// GetState returns the current state of the payment.
GetState() *MPPaymentState

Expand Down
5 changes: 5 additions & 0 deletions payments/db/payment.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,11 @@ type MPPayment struct {
State *MPPaymentState
}

// GetSequenceNum returns the payment's unique sequence number.
func (m *MPPayment) GetSequenceNum() uint64 {
return m.SequenceNum
}

// Terminated returns a bool to specify whether the payment is in a terminal
// state.
func (m *MPPayment) Terminated() bool {
Expand Down
14 changes: 12 additions & 2 deletions routing/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ type failPaymentArgs struct {
}

type testPayment struct {
sequence uint64
info paymentsdb.PaymentCreationInfo
attempts []paymentsdb.HTLCAttempt
}
Expand All @@ -283,6 +284,7 @@ type mockControlTowerOld struct {
failAttempt chan failAttemptArgs
failPayment chan failPaymentArgs
fetchInFlight chan struct{}
sequence uint64

sync.Mutex
}
Expand Down Expand Up @@ -320,9 +322,11 @@ func (m *mockControlTowerOld) InitPayment(_ context.Context,
return paymentsdb.ErrPaymentInFlight
}

m.sequence++
delete(m.failed, phash)
m.payments[phash] = &testPayment{
info: *c,
sequence: m.sequence,
info: *c,
}

return nil
Expand Down Expand Up @@ -531,7 +535,8 @@ func (m *mockControlTowerOld) fetchPayment(phash lntypes.Hash) (
}

mp := &paymentsdb.MPPayment{
Info: &p.info,
SequenceNum: p.sequence,
Info: &p.info,
}

reason, ok := m.failed[phash]
Expand Down Expand Up @@ -834,6 +839,11 @@ type mockMPPayment struct {

var _ paymentsdb.DBMPPayment = (*mockMPPayment)(nil)

func (m *mockMPPayment) GetSequenceNum() uint64 {
args := m.Called()
return args.Get(0).(uint64)
}

func (m *mockMPPayment) GetState() *paymentsdb.MPPaymentState {
args := m.Called()
return args.Get(0).(*paymentsdb.MPPaymentState)
Expand Down
41 changes: 38 additions & 3 deletions routing/payment_lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,15 @@ import (
"github.com/lightningnetwork/lnd/tlv"
)

// ErrPaymentLifecycleExiting is used when waiting for htlc attempt result, but
// the payment lifecycle is exiting .
var ErrPaymentLifecycleExiting = errors.New("payment lifecycle exiting")
var (
// ErrPaymentLifecycleExiting is used when waiting for htlc attempt
// result, but the payment lifecycle is exiting .
ErrPaymentLifecycleExiting = errors.New("payment lifecycle exiting")

// ErrPaymentLifecycleStale is returned when a payment lifecycle detects
// that its payment hash now points to a newer payment sequence number.
ErrPaymentLifecycleStale = errors.New("payment lifecycle stale")
)

// switchResult is the result sent back from the switch after processing the
// HTLC.
Expand All @@ -41,6 +47,7 @@ type paymentLifecycle struct {
router *ChannelRouter
feeLimit lnwire.MilliSatoshi
identifier lntypes.Hash
sequenceNum uint64
paySession PaymentSession
shardTracker shards.ShardTracker
currentHeight int32
Expand Down Expand Up @@ -214,6 +221,10 @@ func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte,
return [32]byte{}, nil, err
}

if err := p.checkPaymentSequenceNum(payment); err != nil {
return [32]byte{}, nil, err
}

// Get the payment status.
status := payment.GetStatus()

Expand Down Expand Up @@ -359,6 +370,27 @@ lifecycle:
return [32]byte{}, nil, *failure
}

// checkPaymentSequenceNum binds the lifecycle to its first observed payment
// sequence number, then returns an error if later reloads observe a different
// sequence number for the same payment hash.
func (p *paymentLifecycle) checkPaymentSequenceNum(
payment paymentsdb.DBMPPayment) error {

sequenceNum := payment.GetSequenceNum()
if p.sequenceNum == 0 {
p.sequenceNum = sequenceNum
return nil
}

if sequenceNum == p.sequenceNum {
return nil
}

return fmt.Errorf("%w: payment %v sequence number changed from %d "+
"to %d", ErrPaymentLifecycleStale, p.identifier,
p.sequenceNum, sequenceNum)
}

// checkContext checks whether the payment context has been canceled.
// Cancellation occurs manually or if the context times out.
func (p *paymentLifecycle) checkContext(ctx context.Context) error {
Expand Down Expand Up @@ -1173,6 +1205,9 @@ func (p *paymentLifecycle) reloadPayment(
if err != nil {
return nil, nil, err
}
if err := p.checkPaymentSequenceNum(payment); err != nil {
return nil, nil, err
}

ps := payment.GetState()
remainingFees := p.calcFeeBudget(ps.FeesPaid)
Expand Down
24 changes: 24 additions & 0 deletions routing/payment_lifecycle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ func setupTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) {

htlcs := []paymentsdb.HTLCAttempt{}
m.payment.On("InFlightHTLCs").Return(htlcs).Once()
m.payment.On("GetSequenceNum").Return(uint64(1))

return p, m
}
Expand Down Expand Up @@ -801,6 +802,29 @@ func TestResumePaymentFailOnFetchPayment(t *testing.T) {
require.Zero(t, m.collectResultsCount)
}

// TestResumePaymentFailOnStaleSequenceNum checks that a lifecycle exits before
// making more attempts when its payment hash has been recycled by a newer
// InitPayment.
func TestResumePaymentFailOnStaleSequenceNum(t *testing.T) {
t.Parallel()

p, m := setupTestPaymentLifecycle(t)

m.payment.On("GetStatus").Return(paymentsdb.StatusInFlight).Once()
newPayment := &mockMPPayment{}
m.control.On("FetchPayment", p.identifier).Return(
newPayment, nil,
).Once()
newPayment.On("GetSequenceNum").Return(uint64(2)).Once()
defer newPayment.AssertExpectations(t)

sendPaymentAndAssertError(
t, t.Context(), p, ErrPaymentLifecycleStale,
)

require.Zero(t, m.collectResultsCount)
}

// TestResumePaymentFailOnTimeout checks that when timeout is reached, the
// payment is failed.
//
Expand Down
35 changes: 34 additions & 1 deletion routing/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -1145,12 +1145,42 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
FirstHopCustomRecords: firstHopCustomRecords,
}

err := r.cfg.Control.InitPayment(ctx, paymentIdentifier, info)
var dbPayment paymentsdb.DBMPPayment
err := r.cfg.Control.InitPayment(
ctx, paymentIdentifier, info,
)
switch {
case err == nil:
payment, err := r.cfg.Control.FetchPayment(
ctx, paymentIdentifier,
)
if err != nil {
return nil, err
}

dbPayment = payment

// If this is an MPP attempt and the hash is already registered with
// the database, we can go on to launch the shard.
case mpp != nil && errors.Is(err, paymentsdb.ErrPaymentInFlight):
payment, err := r.cfg.Control.FetchPayment(
ctx, paymentIdentifier,
)
if err != nil {
return nil, err
}

dbPayment = payment

case mpp != nil && errors.Is(err, paymentsdb.ErrPaymentExists):
payment, err := r.cfg.Control.FetchPayment(
ctx, paymentIdentifier,
)
if err != nil {
return nil, err
}

dbPayment = payment

// Any other error is not tolerated.
case err != nil:
Expand All @@ -1176,6 +1206,9 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
r, 0, paymentIdentifier, nil, shardTracker, 0,
firstHopCustomRecords,
)
if err := p.checkPaymentSequenceNum(dbPayment); err != nil {
return nil, err
}

// Allow the traffic shaper to add custom records to the outgoing HTLC
// and also adjust the amount if needed.
Expand Down
19 changes: 15 additions & 4 deletions routing/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2220,6 +2220,9 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) {

// Register mockers with the expected method calls.
controlTower.On("InitPayment", payHash, mock.Anything).Return(nil)
payment := &mockMPPayment{}
controlTower.On("FetchPayment", payHash).Return(payment, nil).Once()
payment.On("GetSequenceNum").Return(uint64(1)).Once()
controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil)
controlTower.On("SettleAttempt",
payHash, mock.Anything, mock.Anything,
Expand Down Expand Up @@ -2253,6 +2256,7 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) {
controlTower.AssertExpectations(t)
payer.AssertExpectations(t)
missionControl.AssertExpectations(t)
payment.AssertExpectations(t)
}

// TestSendToRouteSkipTempErrNonMPP checks that an error is return when
Expand Down Expand Up @@ -2366,6 +2370,9 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) {

// Register mockers with the expected method calls.
controlTower.On("InitPayment", payHash, mock.Anything).Return(nil)
payment := &mockMPPayment{}
controlTower.On("FetchPayment", payHash).Return(payment, nil).Once()
payment.On("GetSequenceNum").Return(uint64(1)).Once()
controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil)
controlTower.On("FailAttempt",
payHash, mock.Anything, mock.Anything,
Expand All @@ -2392,6 +2399,7 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) {
controlTower.AssertExpectations(t)
payer.AssertExpectations(t)
missionControl.AssertExpectations(t)
payment.AssertExpectations(t)
}

// TestSendToRouteSkipTempErrPermanentFailure validates a permanent failure
Expand Down Expand Up @@ -2446,6 +2454,9 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) {

// Register mockers with the expected method calls.
controlTower.On("InitPayment", payHash, mock.Anything).Return(nil)
payment := &mockMPPayment{}
controlTower.On("FetchPayment", payHash).Return(payment, nil).Once()
payment.On("GetSequenceNum").Return(uint64(1)).Once()
controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil)

controlTower.On("FailAttempt",
Expand Down Expand Up @@ -2478,6 +2489,7 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) {
controlTower.AssertExpectations(t)
payer.AssertExpectations(t)
missionControl.AssertExpectations(t)
payment.AssertExpectations(t)
}

// TestSendToRouteTempFailure validates a temporary failure will cause the
Expand Down Expand Up @@ -2532,6 +2544,9 @@ func TestSendToRouteTempFailure(t *testing.T) {

// Register mockers with the expected method calls.
controlTower.On("InitPayment", payHash, mock.Anything).Return(nil)
payment := &mockMPPayment{}
controlTower.On("FetchPayment", payHash).Return(payment, nil).Twice()
payment.On("GetSequenceNum").Return(uint64(1)).Once()
controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil)
controlTower.On("FailAttempt",
payHash, mock.Anything, mock.Anything,
Expand All @@ -2546,10 +2561,6 @@ func TestSendToRouteTempFailure(t *testing.T) {
mock.Anything, mock.Anything, mock.Anything,
).Return(tempErr)

// Mock the control tower to return the mocked payment.
payment := &mockMPPayment{}
controlTower.On("FetchPayment", payHash).Return(payment, nil).Once()

// Mock the payment to return nil failure reason.
payment.On("TerminalInfo").Return(nil, nil).Once()

Expand Down
Loading