diff --git a/pkg/currency/currency.go b/pkg/currency/currency.go new file mode 100644 index 00000000..9951e180 --- /dev/null +++ b/pkg/currency/currency.go @@ -0,0 +1,42 @@ +package currency + +import ( + "errors" + "fmt" + "math" +) + +// Picodollar is a type to represent currency with 12 decimal precision +type PicoDollar int64 + +const ( + PicoDollarsPerDollar = 1e12 +) + +// FromDollars converts a dollar amount (as a float) to Picodollars +// This should mostly be used for testing, and real usage should be done purely in PicoDollars +func FromDollars(dollars float64) (PicoDollar, error) { + if math.IsNaN(dollars) || math.IsInf(dollars, 0) { + return 0, errors.New("invalid dollar amount: must be a finite number") + } + + picodollars := dollars * PicoDollarsPerDollar + if (picodollars < 0 && dollars > 0) || (picodollars > 0 && dollars < 0) { + return 0, errors.New("overflow: dollar amount too large") + } + return PicoDollar(picodollars), nil +} + +// ToDollars converts PicoDollars to a dollar amount (as a float) +func (p PicoDollar) toDollarsTestOnly() float64 { + return float64(p) / PicoDollarsPerDollar +} + +// ToMicroDollars converts PicoDollars to MicroDollars (1e6 units per dollar) +func (p PicoDollar) ToMicroDollars() int64 { + return int64(p / 1e6) +} + +func (p PicoDollar) String() string { + return fmt.Sprintf("%.12f", p.toDollarsTestOnly()) +} diff --git a/pkg/currency/currency_test.go b/pkg/currency/currency_test.go new file mode 100644 index 00000000..cc2ef592 --- /dev/null +++ b/pkg/currency/currency_test.go @@ -0,0 +1,28 @@ +package currency + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConversion(t *testing.T) { + initial, err := FromDollars(1.25) + require.NoError(t, err) + require.Equal(t, PicoDollar(1250000000000), initial) + + converted := initial.toDollarsTestOnly() + require.Equal(t, 1.25, converted) +} + +func TestString(t *testing.T) { + initial, err := FromDollars(1.25) + require.NoError(t, err) + require.Equal(t, "1.250000000000", initial.String()) +} + +func TestToMicroDollars(t *testing.T) { + initial, err := FromDollars(1.25) + require.NoError(t, err) + require.Equal(t, int64(1250000), initial.ToMicroDollars()) +} diff --git a/pkg/fees/calculator.go b/pkg/fees/calculator.go new file mode 100644 index 00000000..c3f35eea --- /dev/null +++ b/pkg/fees/calculator.go @@ -0,0 +1,77 @@ +package fees + +import ( + "fmt" + "time" + + "github.com/xmtp/xmtpd/pkg/currency" +) + +type FeeCalculator struct { + ratesFetcher IRatesFetcher +} + +func NewFeeCalculator(ratesFetcher IRatesFetcher) *FeeCalculator { + return &FeeCalculator{ratesFetcher: ratesFetcher} +} + +func (c *FeeCalculator) CalculateBaseFee( + messageTime time.Time, + messageSize int64, + storageDurationDays int64, +) (currency.PicoDollar, error) { + if messageSize <= 0 { + return 0, fmt.Errorf("messageSize must be greater than 0, got %d", messageSize) + } + if storageDurationDays <= 0 { + return 0, fmt.Errorf( + "storageDurationDays must be greater than 0, got %d", + storageDurationDays, + ) + } + + rates, err := c.ratesFetcher.GetRates(messageTime) + if err != nil { + return 0, err + } + + // Calculate storage fee components separately to check for overflow + storageFeePerByte := rates.StorageFee * currency.PicoDollar(messageSize) + if storageFeePerByte/currency.PicoDollar(messageSize) != rates.StorageFee { + return 0, fmt.Errorf("storage fee calculation overflow") + } + + totalStorageFee := storageFeePerByte * currency.PicoDollar(storageDurationDays) + if totalStorageFee/currency.PicoDollar(storageDurationDays) != storageFeePerByte { + return 0, fmt.Errorf("storage fee calculation overflow") + } + + return rates.MessageFee + totalStorageFee, nil +} + +func (c *FeeCalculator) CalculateCongestionFee( + messageTime time.Time, + congestionUnits int64, +) (currency.PicoDollar, error) { + if congestionUnits < 0 || congestionUnits > 100 { + return 0, fmt.Errorf( + "congestionPercent must be between 0 and 100, got %d", + congestionUnits, + ) + } + + if congestionUnits == 0 { + return 0, nil + } + + rates, err := c.ratesFetcher.GetRates(messageTime) + if err != nil { + return 0, err + } + + result := rates.CongestionFee * currency.PicoDollar(congestionUnits) + if result/currency.PicoDollar(congestionUnits) != rates.CongestionFee { + return 0, fmt.Errorf("congestion fee calculation overflow") + } + return result, nil +} diff --git a/pkg/fees/calculator_test.go b/pkg/fees/calculator_test.go new file mode 100644 index 00000000..60f41a91 --- /dev/null +++ b/pkg/fees/calculator_test.go @@ -0,0 +1,85 @@ +package fees + +import ( + "math" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/currency" +) + +const ( + RATE_MESSAGE_FEE = 100 + RATE_STORAGE_FEE = 50 + RATE_CONGESTION_FEE = 200 +) + +func setupCalculator() *FeeCalculator { + rates := &Rates{ + MessageFee: RATE_MESSAGE_FEE, + StorageFee: RATE_STORAGE_FEE, + CongestionFee: RATE_CONGESTION_FEE, + } + ratesFetcher := NewFixedRatesFetcher(rates) + return NewFeeCalculator(ratesFetcher) +} + +func TestCalculateBaseFee(t *testing.T) { + calculator := setupCalculator() + + messageTime := time.Now() + messageSize := int64(100) + storageDurationDays := int64(1) + + baseFee, err := calculator.CalculateBaseFee(messageTime, messageSize, storageDurationDays) + require.NoError(t, err) + + expectedFee := RATE_MESSAGE_FEE + (RATE_STORAGE_FEE * messageSize * storageDurationDays) + require.Equal(t, currency.PicoDollar(expectedFee), baseFee) +} + +func TestCalculateCongestionFee(t *testing.T) { + calculator := setupCalculator() + + messageTime := time.Now() + congestionPercent := int64(50) + + congestionFee, err := calculator.CalculateCongestionFee(messageTime, congestionPercent) + require.NoError(t, err) + + expectedFee := RATE_CONGESTION_FEE * congestionPercent + require.Equal(t, currency.PicoDollar(expectedFee), congestionFee) +} + +func TestOverflow(t *testing.T) { + calculator := setupCalculator() + + messageTime := time.Now() + + // Test overflow in CalculateBaseFee + messageSize := math.MaxInt64 + storageDurationDays := math.MaxInt64 + + _, err := calculator.CalculateBaseFee( + messageTime, + int64(messageSize), + int64(storageDurationDays), + ) + require.Error(t, err) + require.Contains(t, err.Error(), "storage fee calculation overflow") +} + +func TestInvalidCongestionPercent(t *testing.T) { + calculator := setupCalculator() + + messageTime := time.Now() + + _, err := calculator.CalculateCongestionFee(messageTime, int64(101)) + require.Error(t, err) + require.Contains(t, err.Error(), "congestionPercent must be between 0 and 100") + + _, err = calculator.CalculateCongestionFee(messageTime, int64(-1)) + require.Error(t, err) + require.Contains(t, err.Error(), "congestionPercent must be between 0 and 100") +} diff --git a/pkg/fees/fixedRates.go b/pkg/fees/fixedRates.go new file mode 100644 index 00000000..8596ef74 --- /dev/null +++ b/pkg/fees/fixedRates.go @@ -0,0 +1,17 @@ +package fees + +import "time" + +// A fixed fee schedule that doesn't rely on the blockchain. +// Used primarily for testing, and until we have fees onchain. +type FixedRatesFetcher struct { + rates *Rates +} + +func NewFixedRatesFetcher(rates *Rates) *FixedRatesFetcher { + return &FixedRatesFetcher{rates: rates} +} + +func (f *FixedRatesFetcher) GetRates(_messageTime time.Time) (*Rates, error) { + return f.rates, nil +} diff --git a/pkg/fees/interface.go b/pkg/fees/interface.go new file mode 100644 index 00000000..92e4a2d2 --- /dev/null +++ b/pkg/fees/interface.go @@ -0,0 +1,33 @@ +package fees + +import ( + "time" + + "github.com/xmtp/xmtpd/pkg/currency" +) + +// Rates containt the cost for each fee component at a given message time. +// Values in the rates struct are denominated in USD PicoDollars +type Rates struct { + MessageFee currency.PicoDollar // The flat per-message fee + StorageFee currency.PicoDollar // The fee per byte-day of storage + CongestionFee currency.PicoDollar // The fee per unit of congestion +} + +// The RatesFetcher is responsible for loading the rates for a given message time. +// This allows us to roll out new rates over time, and apply them to messages consistently. +type IRatesFetcher interface { + GetRates(messageTime time.Time) (*Rates, error) +} + +type IFeeCalculator interface { + CalculateBaseFee( + messageTime time.Time, + messageSize int64, + storageDurationDays int64, + ) (currency.PicoDollar, error) + CalculateCongestionFee( + messageTime time.Time, + congestionUnits int64, + ) (currency.PicoDollar, error) +}