Skip to content

Commit

Permalink
Add fee calculation helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas committed Feb 26, 2025
1 parent 8d666c7 commit 77b36c3
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 0 deletions.
42 changes: 42 additions & 0 deletions pkg/currency/currency.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package currency

import (
"errors"
"fmt"
"math"
)

// Picodollar is a type to represent currency with 12 decimal precision
type PicoDollar int64

const (
PicoDollarsPerDollar = 1e12
)

// FromDollars converts a dollar amount (as a float) to Picodollars
// This should mostly be used for testing, and real usage should be done purely in PicoDollars
func FromDollars(dollars float64) (PicoDollar, error) {
if math.IsNaN(dollars) || math.IsInf(dollars, 0) {
return 0, errors.New("invalid dollar amount: must be a finite number")
}

picodollars := dollars * PicoDollarsPerDollar
if (picodollars < 0 && dollars > 0) || (picodollars > 0 && dollars < 0) {
return 0, errors.New("overflow: dollar amount too large")
}
return PicoDollar(picodollars), nil
}

// ToDollars converts PicoDollars to a dollar amount (as a float)
func (p PicoDollar) toDollarsTestOnly() float64 {
return float64(p) / PicoDollarsPerDollar
}

// ToMicroDollars converts PicoDollars to MicroDollars (1e6 units per dollar)
func (p PicoDollar) ToMicroDollars() int64 {
return int64(p / 1e6)
}

func (p PicoDollar) String() string {
return fmt.Sprintf("%.12f", p.toDollarsTestOnly())
}
28 changes: 28 additions & 0 deletions pkg/currency/currency_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package currency

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestConversion(t *testing.T) {
initial, err := FromDollars(1.25)
require.NoError(t, err)
require.Equal(t, PicoDollar(1250000000000), initial)

converted := initial.toDollarsTestOnly()
require.Equal(t, 1.25, converted)
}

func TestString(t *testing.T) {
initial, err := FromDollars(1.25)
require.NoError(t, err)
require.Equal(t, "1.250000000000", initial.String())
}

func TestToMicroDollars(t *testing.T) {
initial, err := FromDollars(1.25)
require.NoError(t, err)
require.Equal(t, int64(1250000), initial.ToMicroDollars())
}
77 changes: 77 additions & 0 deletions pkg/fees/calculator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package fees

import (
"fmt"
"time"

"github.com/xmtp/xmtpd/pkg/currency"
)

type FeeCalculator struct {
ratesFetcher IRatesFetcher
}

func NewFeeCalculator(ratesFetcher IRatesFetcher) *FeeCalculator {
return &FeeCalculator{ratesFetcher: ratesFetcher}
}

func (c *FeeCalculator) CalculateBaseFee(
messageTime time.Time,
messageSize int64,
storageDurationDays int64,
) (currency.PicoDollar, error) {
if messageSize <= 0 {
return 0, fmt.Errorf("messageSize must be greater than 0, got %d", messageSize)
}
if storageDurationDays <= 0 {
return 0, fmt.Errorf(
"storageDurationDays must be greater than 0, got %d",
storageDurationDays,
)
}

rates, err := c.ratesFetcher.GetRates(messageTime)
if err != nil {
return 0, err
}

// Calculate storage fee components separately to check for overflow
storageFeePerByte := rates.StorageFee * currency.PicoDollar(messageSize)
if storageFeePerByte/currency.PicoDollar(messageSize) != rates.StorageFee {
return 0, fmt.Errorf("storage fee calculation overflow")
}

totalStorageFee := storageFeePerByte * currency.PicoDollar(storageDurationDays)
if totalStorageFee/currency.PicoDollar(storageDurationDays) != storageFeePerByte {
return 0, fmt.Errorf("storage fee calculation overflow")
}

return rates.MessageFee + totalStorageFee, nil
}

func (c *FeeCalculator) CalculateCongestionFee(
messageTime time.Time,
congestionUnits int64,
) (currency.PicoDollar, error) {
if congestionUnits < 0 || congestionUnits > 100 {
return 0, fmt.Errorf(
"congestionPercent must be between 0 and 100, got %d",
congestionUnits,
)
}

if congestionUnits == 0 {
return 0, nil
}

rates, err := c.ratesFetcher.GetRates(messageTime)
if err != nil {
return 0, err
}

result := rates.CongestionFee * currency.PicoDollar(congestionUnits)
if result/currency.PicoDollar(congestionUnits) != rates.CongestionFee {
return 0, fmt.Errorf("congestion fee calculation overflow")
}
return result, nil
}
85 changes: 85 additions & 0 deletions pkg/fees/calculator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package fees

import (
"math"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/currency"
)

const (
RATE_MESSAGE_FEE = 100
RATE_STORAGE_FEE = 50
RATE_CONGESTION_FEE = 200
)

func setupCalculator() *FeeCalculator {
rates := &Rates{
MessageFee: RATE_MESSAGE_FEE,
StorageFee: RATE_STORAGE_FEE,
CongestionFee: RATE_CONGESTION_FEE,
}
ratesFetcher := NewFixedRatesFetcher(rates)
return NewFeeCalculator(ratesFetcher)
}

func TestCalculateBaseFee(t *testing.T) {
calculator := setupCalculator()

messageTime := time.Now()
messageSize := int64(100)
storageDurationDays := int64(1)

baseFee, err := calculator.CalculateBaseFee(messageTime, messageSize, storageDurationDays)
require.NoError(t, err)

expectedFee := RATE_MESSAGE_FEE + (RATE_STORAGE_FEE * messageSize * storageDurationDays)
require.Equal(t, currency.PicoDollar(expectedFee), baseFee)
}

func TestCalculateCongestionFee(t *testing.T) {
calculator := setupCalculator()

messageTime := time.Now()
congestionPercent := int64(50)

congestionFee, err := calculator.CalculateCongestionFee(messageTime, congestionPercent)
require.NoError(t, err)

expectedFee := RATE_CONGESTION_FEE * congestionPercent
require.Equal(t, currency.PicoDollar(expectedFee), congestionFee)
}

func TestOverflow(t *testing.T) {
calculator := setupCalculator()

messageTime := time.Now()

// Test overflow in CalculateBaseFee
messageSize := math.MaxInt64
storageDurationDays := math.MaxInt64

_, err := calculator.CalculateBaseFee(
messageTime,
int64(messageSize),
int64(storageDurationDays),
)
require.Error(t, err)
require.Contains(t, err.Error(), "storage fee calculation overflow")
}

func TestInvalidCongestionPercent(t *testing.T) {
calculator := setupCalculator()

messageTime := time.Now()

_, err := calculator.CalculateCongestionFee(messageTime, int64(101))
require.Error(t, err)
require.Contains(t, err.Error(), "congestionPercent must be between 0 and 100")

_, err = calculator.CalculateCongestionFee(messageTime, int64(-1))
require.Error(t, err)
require.Contains(t, err.Error(), "congestionPercent must be between 0 and 100")
}
17 changes: 17 additions & 0 deletions pkg/fees/fixedRates.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package fees

import "time"

// A fixed fee schedule that doesn't rely on the blockchain.
// Used primarily for testing, and until we have fees onchain.
type FixedRatesFetcher struct {
rates *Rates
}

func NewFixedRatesFetcher(rates *Rates) *FixedRatesFetcher {
return &FixedRatesFetcher{rates: rates}
}

func (f *FixedRatesFetcher) GetRates(_messageTime time.Time) (*Rates, error) {
return f.rates, nil
}
33 changes: 33 additions & 0 deletions pkg/fees/interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package fees

import (
"time"

"github.com/xmtp/xmtpd/pkg/currency"
)

// Rates containt the cost for each fee component at a given message time.
// Values in the rates struct are denominated in USD PicoDollars
type Rates struct {
MessageFee currency.PicoDollar // The flat per-message fee
StorageFee currency.PicoDollar // The fee per byte-day of storage
CongestionFee currency.PicoDollar // The fee per unit of congestion
}

// The RatesFetcher is responsible for loading the rates for a given message time.
// This allows us to roll out new rates over time, and apply them to messages consistently.
type IRatesFetcher interface {
GetRates(messageTime time.Time) (*Rates, error)
}

type IFeeCalculator interface {
CalculateBaseFee(
messageTime time.Time,
messageSize int64,
storageDurationDays int64,
) (currency.PicoDollar, error)
CalculateCongestionFee(
messageTime time.Time,
congestionUnits int64,
) (currency.PicoDollar, error)
}

0 comments on commit 77b36c3

Please sign in to comment.