Skip to content
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions utils/math/safe_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ package math

import (
"errors"
"math/big"

"golang.org/x/exp/constraints"
)

var (
ErrOverflow = errors.New("overflow")
ErrUnderflow = errors.New("underflow")
ErrOverflow = errors.New("overflow")
ErrUnderflow = errors.New("underflow")
errDivideByZero = errors.New("divide by zero")

// Deprecated: Add64 is deprecated. Use Add[uint64] instead.
Add64 = Add[uint64]
Expand Down Expand Up @@ -58,3 +60,41 @@ func Mul[T constraints.Unsigned](a, b T) (T, error) {
func AbsDiff[T constraints.Unsigned](a, b T) T {
return max(a, b) - min(a, b)
}

// MulDiv computes (a * b) / c with full precision using big.Int arithmetic.
// The result is rounded to the nearest integer.
// Returns errDivideByZero if c is zero, or ErrOverflow if the result exceeds uint64.
func MulDiv(a, b, c uint64) (uint64, error) {
if c == 0 {
return 0, errDivideByZero
}

bigA := new(big.Int).SetUint64(a)
bigB := new(big.Int).SetUint64(b)
bigC := new(big.Int).SetUint64(c)

result := new(big.Int).Mul(bigA, bigB)
result = divRound(result, bigC)

if !result.IsUint64() {
return 0, ErrOverflow
}
return result.Uint64(), nil
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
bigA := new(big.Int).SetUint64(a)
bigB := new(big.Int).SetUint64(b)
bigC := new(big.Int).SetUint64(c)
result := new(big.Int).Mul(bigA, bigB)
result = divRound(result, bigC)
if !result.IsUint64() {
return 0, ErrOverflow
}
return result.Uint64(), nil
hi, lo := bits.Mul64(a, b)
if c <= hi {
return 0, ErrOverflow
}
quo, rem := bits.Div64(hi, lo, c)
if rem < (1<<63) && 2*rem < c {
return quo, nil
}
if quo == math.MaxUint64 {
return 0, ErrOverflow
}
return quo + 1, nil

This doesn't require using big.Int as the bits package provides full-precision alternatives. The divRound() function becomes redundant too.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice suggestion! For now, I will update with your approach and we can still consider the way we calculate it as an open discussion if there are better ideas out there.

}

// divRound divides a by b and rounds to the nearest integer.
// Note: This function uses big.Int.DivMod, which has sign-dependent behavior.
func divRound(a, b *big.Int) *big.Int {
quotient := new(big.Int)
remainder := new(big.Int)

quotient.DivMod(a, b, remainder)

// if 2*remainder >= b → round up
doubleRem := new(big.Int).Mul(remainder, big.NewInt(2))
if doubleRem.Cmp(b) >= 0 {
quotient.Add(quotient, big.NewInt(1))
}

return quotient
}
108 changes: 108 additions & 0 deletions utils/math/safe_math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,111 @@ func TestAbsDiff(t *testing.T) {
require.Zero(AbsDiff(uint64(1), uint64(1)))
require.Zero(AbsDiff(uint64(0), uint64(0)))
}

func TestMulDiv(t *testing.T) {
tests := []struct {
name string
a uint64
b uint64
c uint64
want uint64
wantErr error
}{
{
name: "division by zero",
a: 100,
b: 5,
c: 0,
wantErr: errDivideByZero,
},
{
name: "a is zero",
a: 0,
b: 4,
c: 50,
want: 0,
},
{
name: "b is zero",
a: 250,
b: 0,
c: 50,
want: 0,
},
{
name: "basic case 1",
a: 100,
b: 3,
c: 10,
want: 30,
},
{
name: "basic case 2",
a: 250,
b: 4,
c: 50,
want: 20,
},
{
name: "precision",
a: 7,
b: 3,
c: 10,
want: 2,
},
{
name: "overflow",
a: maxUint64,
b: 10,
c: 2,
wantErr: ErrOverflow,
},
{
name: "round down",
a: 10,
b: 10,
c: 30,
want: 3,
},
{
name: "round up",
a: 20,
b: 10,
c: 30,
want: 7,
},
{
name: "large values without overflow",
a: 300_000_000_000,
b: 200_000_000_000,
c: 400_000_000_000,
want: 150_000_000_000,
},
{
name: "small a large c",
a: 5,
b: 3,
c: 10,
want: 2,
},
{
name: "maxUint64 * maxUint64 / maxUint64",
a: maxUint64,
b: maxUint64,
c: maxUint64,
want: maxUint64,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := MulDiv(tt.a, tt.b, tt.c)
if tt.wantErr != nil {
require.ErrorIs(t, err, tt.wantErr)
return
}
require.NoError(t, err)
require.Equal(t, tt.want, got)
})
}
}
8 changes: 8 additions & 0 deletions vms/platformvm/api/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ type PermissionlessValidator struct {
Staked []UTXO `json:"staked,omitempty"`
Signer *signer.ProofOfPossession `json:"signer,omitempty"`

// ACP-236
// The owner who can modify auto-renewed validator config, if applicable.
ConfigOwner *Owner `json:"configOwner,omitempty"`
// The validation cycle duration in seconds, if applicable.
Period *json.Uint64 `json:"period,omitempty"`
// Percentage of rewards to auto-compound, if applicable.
AutoCompoundRewardShares *json.Uint32 `json:"autoCompoundRewardShares,omitempty"`

// The delegators delegating to this validator
DelegatorCount *json.Uint64 `json:"delegatorCount,omitempty"`
DelegatorWeight *json.Uint64 `json:"delegatorWeight,omitempty"`
Expand Down
26 changes: 24 additions & 2 deletions vms/platformvm/block/builder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,12 @@ func buildBlock(
return nil, fmt.Errorf("could not find next staker to reward: %w", err)
}
if shouldReward {
rewardValidatorTx, err := NewRewardValidatorTx(builder.txExecutorBackend.Ctx, stakerTxID)
stakerTx, _, err := parentState.GetTx(stakerTxID)
Comment thread
yacovm marked this conversation as resolved.
if err != nil {
return nil, err
}

rewardValidatorTx, err := newRewardTxForStaker(builder.txExecutorBackend.Ctx, stakerTx, timestamp)
if err != nil {
return nil, fmt.Errorf("could not build tx to reward staker: %w", err)
}
Expand Down Expand Up @@ -608,7 +613,7 @@ func executeTx(
}

// getNextStakerToReward returns the next staker txID to remove from the staking
// set with a RewardValidatorTx rather than an AdvanceTimeTx. [chainTimestamp]
// set with a RewardValidatorTx/RewardAutoRenewedValidatorTx rather than an AdvanceTimeTx. [chainTimestamp]
// is the timestamp of the chain at the time this validator would be getting
// removed and is used to calculate [shouldReward].
// Returns:
Expand Down Expand Up @@ -650,3 +655,20 @@ func NewRewardValidatorTx(ctx *snow.Context, txID ids.ID) (*txs.Tx, error) {
}
return tx, tx.SyntacticVerify(ctx)
}

func newRewardTxForStaker(ctx *snow.Context, stakerTx *txs.Tx, timestamp time.Time) (*txs.Tx, error) {
if _, ok := stakerTx.Unsigned.(*txs.AddAutoRenewedValidatorTx); ok {
return newRewardAutoRenewedValidatorTx(ctx, stakerTx.ID(), uint64(timestamp.Unix()))
}

return NewRewardValidatorTx(ctx, stakerTx.ID())
}

func newRewardAutoRenewedValidatorTx(ctx *snow.Context, txID ids.ID, timestamp uint64) (*txs.Tx, error) {
utx := &txs.RewardAutoRenewedValidatorTx{TxID: txID, Timestamp: timestamp}
tx, err := txs.NewSigned(utx, txs.Codec, nil)
if err != nil {
return nil, err
}
return tx, tx.SyntacticVerify(ctx)
}
Loading
Loading