diff --git a/payments/db/interface.go b/payments/db/interface.go index 6edaa7f45b5..3d3ee4a4187 100644 --- a/payments/db/interface.go +++ b/payments/db/interface.go @@ -110,6 +110,9 @@ type PaymentControl interface { // DBMPPayment is an interface that represents the payment state during a // payment lifecycle. type DBMPPayment interface { + // GetSequenceNum returns the payment's unique sequence number. + GetSequenceNum() uint64 + // GetState returns the current state of the payment. GetState() *MPPaymentState diff --git a/payments/db/payment.go b/payments/db/payment.go index ddceedfb0f0..d045497935a 100644 --- a/payments/db/payment.go +++ b/payments/db/payment.go @@ -359,6 +359,11 @@ type MPPayment struct { State *MPPaymentState } +// GetSequenceNum returns the payment's unique sequence number. +func (m *MPPayment) GetSequenceNum() uint64 { + return m.SequenceNum +} + // Terminated returns a bool to specify whether the payment is in a terminal // state. func (m *MPPayment) Terminated() bool { diff --git a/routing/mock_test.go b/routing/mock_test.go index 472f1261623..0cafaa55375 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -268,6 +268,7 @@ type failPaymentArgs struct { } type testPayment struct { + sequence uint64 info paymentsdb.PaymentCreationInfo attempts []paymentsdb.HTLCAttempt } @@ -283,6 +284,7 @@ type mockControlTowerOld struct { failAttempt chan failAttemptArgs failPayment chan failPaymentArgs fetchInFlight chan struct{} + sequence uint64 sync.Mutex } @@ -320,9 +322,11 @@ func (m *mockControlTowerOld) InitPayment(_ context.Context, return paymentsdb.ErrPaymentInFlight } + m.sequence++ delete(m.failed, phash) m.payments[phash] = &testPayment{ - info: *c, + sequence: m.sequence, + info: *c, } return nil @@ -531,7 +535,8 @@ func (m *mockControlTowerOld) fetchPayment(phash lntypes.Hash) ( } mp := &paymentsdb.MPPayment{ - Info: &p.info, + SequenceNum: p.sequence, + Info: &p.info, } reason, ok := m.failed[phash] @@ -834,6 +839,11 @@ type mockMPPayment struct { var _ paymentsdb.DBMPPayment = (*mockMPPayment)(nil) +func (m *mockMPPayment) GetSequenceNum() uint64 { + args := m.Called() + return args.Get(0).(uint64) +} + func (m *mockMPPayment) GetState() *paymentsdb.MPPaymentState { args := m.Called() return args.Get(0).(*paymentsdb.MPPaymentState) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 488df5b796e..ed9ec49f9bf 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -20,9 +20,15 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) -// ErrPaymentLifecycleExiting is used when waiting for htlc attempt result, but -// the payment lifecycle is exiting . -var ErrPaymentLifecycleExiting = errors.New("payment lifecycle exiting") +var ( + // ErrPaymentLifecycleExiting is used when waiting for htlc attempt + // result, but the payment lifecycle is exiting . + ErrPaymentLifecycleExiting = errors.New("payment lifecycle exiting") + + // ErrPaymentLifecycleStale is returned when a payment lifecycle detects + // that its payment hash now points to a newer payment sequence number. + ErrPaymentLifecycleStale = errors.New("payment lifecycle stale") +) // switchResult is the result sent back from the switch after processing the // HTLC. @@ -41,6 +47,7 @@ type paymentLifecycle struct { router *ChannelRouter feeLimit lnwire.MilliSatoshi identifier lntypes.Hash + sequenceNum uint64 paySession PaymentSession shardTracker shards.ShardTracker currentHeight int32 @@ -214,6 +221,10 @@ func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte, return [32]byte{}, nil, err } + if err := p.checkPaymentSequenceNum(payment); err != nil { + return [32]byte{}, nil, err + } + // Get the payment status. status := payment.GetStatus() @@ -359,6 +370,27 @@ lifecycle: return [32]byte{}, nil, *failure } +// checkPaymentSequenceNum binds the lifecycle to its first observed payment +// sequence number, then returns an error if later reloads observe a different +// sequence number for the same payment hash. +func (p *paymentLifecycle) checkPaymentSequenceNum( + payment paymentsdb.DBMPPayment) error { + + sequenceNum := payment.GetSequenceNum() + if p.sequenceNum == 0 { + p.sequenceNum = sequenceNum + return nil + } + + if sequenceNum == p.sequenceNum { + return nil + } + + return fmt.Errorf("%w: payment %v sequence number changed from %d "+ + "to %d", ErrPaymentLifecycleStale, p.identifier, + p.sequenceNum, sequenceNum) +} + // checkContext checks whether the payment context has been canceled. // Cancellation occurs manually or if the context times out. func (p *paymentLifecycle) checkContext(ctx context.Context) error { @@ -1173,6 +1205,9 @@ func (p *paymentLifecycle) reloadPayment( if err != nil { return nil, nil, err } + if err := p.checkPaymentSequenceNum(payment); err != nil { + return nil, nil, err + } ps := payment.GetState() remainingFees := p.calcFeeBudget(ps.FeesPaid) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 564942d327e..e20c4ad1569 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -149,6 +149,7 @@ func setupTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { htlcs := []paymentsdb.HTLCAttempt{} m.payment.On("InFlightHTLCs").Return(htlcs).Once() + m.payment.On("GetSequenceNum").Return(uint64(1)) return p, m } @@ -801,6 +802,29 @@ func TestResumePaymentFailOnFetchPayment(t *testing.T) { require.Zero(t, m.collectResultsCount) } +// TestResumePaymentFailOnStaleSequenceNum checks that a lifecycle exits before +// making more attempts when its payment hash has been recycled by a newer +// InitPayment. +func TestResumePaymentFailOnStaleSequenceNum(t *testing.T) { + t.Parallel() + + p, m := setupTestPaymentLifecycle(t) + + m.payment.On("GetStatus").Return(paymentsdb.StatusInFlight).Once() + newPayment := &mockMPPayment{} + m.control.On("FetchPayment", p.identifier).Return( + newPayment, nil, + ).Once() + newPayment.On("GetSequenceNum").Return(uint64(2)).Once() + defer newPayment.AssertExpectations(t) + + sendPaymentAndAssertError( + t, t.Context(), p, ErrPaymentLifecycleStale, + ) + + require.Zero(t, m.collectResultsCount) +} + // TestResumePaymentFailOnTimeout checks that when timeout is reached, the // payment is failed. // diff --git a/routing/router.go b/routing/router.go index 37aeef22361..af9e15c74fb 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1145,12 +1145,42 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, FirstHopCustomRecords: firstHopCustomRecords, } - err := r.cfg.Control.InitPayment(ctx, paymentIdentifier, info) + var dbPayment paymentsdb.DBMPPayment + err := r.cfg.Control.InitPayment( + ctx, paymentIdentifier, info, + ) switch { + case err == nil: + payment, err := r.cfg.Control.FetchPayment( + ctx, paymentIdentifier, + ) + if err != nil { + return nil, err + } + + dbPayment = payment + // If this is an MPP attempt and the hash is already registered with // the database, we can go on to launch the shard. case mpp != nil && errors.Is(err, paymentsdb.ErrPaymentInFlight): + payment, err := r.cfg.Control.FetchPayment( + ctx, paymentIdentifier, + ) + if err != nil { + return nil, err + } + + dbPayment = payment + case mpp != nil && errors.Is(err, paymentsdb.ErrPaymentExists): + payment, err := r.cfg.Control.FetchPayment( + ctx, paymentIdentifier, + ) + if err != nil { + return nil, err + } + + dbPayment = payment // Any other error is not tolerated. case err != nil: @@ -1176,6 +1206,9 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, r, 0, paymentIdentifier, nil, shardTracker, 0, firstHopCustomRecords, ) + if err := p.checkPaymentSequenceNum(dbPayment); err != nil { + return nil, err + } // Allow the traffic shaper to add custom records to the outgoing HTLC // and also adjust the amount if needed. diff --git a/routing/router_test.go b/routing/router_test.go index e14ac199f83..9f953ad00ea 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -2220,6 +2220,9 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { // Register mockers with the expected method calls. controlTower.On("InitPayment", payHash, mock.Anything).Return(nil) + payment := &mockMPPayment{} + controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() + payment.On("GetSequenceNum").Return(uint64(1)).Once() controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) controlTower.On("SettleAttempt", payHash, mock.Anything, mock.Anything, @@ -2253,6 +2256,7 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { controlTower.AssertExpectations(t) payer.AssertExpectations(t) missionControl.AssertExpectations(t) + payment.AssertExpectations(t) } // TestSendToRouteSkipTempErrNonMPP checks that an error is return when @@ -2366,6 +2370,9 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { // Register mockers with the expected method calls. controlTower.On("InitPayment", payHash, mock.Anything).Return(nil) + payment := &mockMPPayment{} + controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() + payment.On("GetSequenceNum").Return(uint64(1)).Once() controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) controlTower.On("FailAttempt", payHash, mock.Anything, mock.Anything, @@ -2392,6 +2399,7 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { controlTower.AssertExpectations(t) payer.AssertExpectations(t) missionControl.AssertExpectations(t) + payment.AssertExpectations(t) } // TestSendToRouteSkipTempErrPermanentFailure validates a permanent failure @@ -2446,6 +2454,9 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { // Register mockers with the expected method calls. controlTower.On("InitPayment", payHash, mock.Anything).Return(nil) + payment := &mockMPPayment{} + controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() + payment.On("GetSequenceNum").Return(uint64(1)).Once() controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) controlTower.On("FailAttempt", @@ -2478,6 +2489,7 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { controlTower.AssertExpectations(t) payer.AssertExpectations(t) missionControl.AssertExpectations(t) + payment.AssertExpectations(t) } // TestSendToRouteTempFailure validates a temporary failure will cause the @@ -2532,6 +2544,9 @@ func TestSendToRouteTempFailure(t *testing.T) { // Register mockers with the expected method calls. controlTower.On("InitPayment", payHash, mock.Anything).Return(nil) + payment := &mockMPPayment{} + controlTower.On("FetchPayment", payHash).Return(payment, nil).Twice() + payment.On("GetSequenceNum").Return(uint64(1)).Once() controlTower.On("RegisterAttempt", payHash, mock.Anything).Return(nil) controlTower.On("FailAttempt", payHash, mock.Anything, mock.Anything, @@ -2546,10 +2561,6 @@ func TestSendToRouteTempFailure(t *testing.T) { mock.Anything, mock.Anything, mock.Anything, ).Return(tempErr) - // Mock the control tower to return the mocked payment. - payment := &mockMPPayment{} - controlTower.On("FetchPayment", payHash).Return(payment, nil).Once() - // Mock the payment to return nil failure reason. payment.On("TerminalInfo").Return(nil, nil).Once()