Skip to content

Commit

Permalink
Add fee calculation helpers (#537)
Browse files Browse the repository at this point in the history
### TL;DR
- Introduces a new fee calculation system with currency handling in picodollars (12 decimal precision).
- Adds some interfaces for how we can get the current rates from the system. Hardcoded for now, but we'll move to something onchain.

## Notes

Handling money in integers feels like the right way to do it, even if we have to convert from PicoDollars to MicroDollars (6 decimals) when we go onchain.

We will have to be careful about that conversion to make sure no one can get free messaging by abusing the rounding.

### What changed?
- Added a new `currency` package with `PicoDollar` type for precise financial calculations
- Created a fee calculation system with components for message, storage, and congestion fees
- Implemented a rates fetcher interface with a fixed-rate implementation
- Added currency conversion utilities between dollars, picodollars, and microdollars
- Included comprehensive test coverage for currency conversions

### How to test?
1. Run the currency package tests:
```go
go test ./pkg/currency
```
2. Verify currency conversions:
   - Test dollar to picodollar conversion
   - Validate string representation
   - Check microdollar conversion
3. Test fee calculations with the fixed rate fetcher implementation

### Why make this change?
To provide a precise and reliable fee calculation system that:
- Handles currency with high precision (12 decimal places)
- Supports different fee components (message, storage, congestion)
- Allows for future rate adjustments through the rates fetcher interface
- Prevents floating-point arithmetic errors in financial calculations

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Introduced a high-precision currency module that supports conversions between standard monetary units and finer subdivisions.
  - Launched a fee calculation module that computes base and congestion fees based on message parameters and dynamic rate fetching, with robust error handling.

- **Tests**
  - Added comprehensive tests to validate currency conversions and fee calculations, including checks for overflow and invalid input scenarios.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
neekolas authored Feb 26, 2025
1 parent ae46bab commit e296a86
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 0 deletions.
42 changes: 42 additions & 0 deletions pkg/currency/currency.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package currency

import (
"errors"
"fmt"
"math"
)

// Picodollar is a type to represent currency with 12 decimal precision
type PicoDollar int64

const (
PicoDollarsPerDollar = 1e12
)

// FromDollars converts a dollar amount (as a float) to Picodollars
// This should mostly be used for testing, and real usage should be done purely in PicoDollars
func FromDollars(dollars float64) (PicoDollar, error) {
if math.IsNaN(dollars) || math.IsInf(dollars, 0) {
return 0, errors.New("invalid dollar amount: must be a finite number")
}

picodollars := dollars * PicoDollarsPerDollar
if (picodollars < 0 && dollars > 0) || (picodollars > 0 && dollars < 0) {
return 0, errors.New("overflow: dollar amount too large")
}
return PicoDollar(picodollars), nil
}

// ToDollars converts PicoDollars to a dollar amount (as a float)
func (p PicoDollar) toDollarsTestOnly() float64 {
return float64(p) / PicoDollarsPerDollar
}

// ToMicroDollars converts PicoDollars to MicroDollars (1e6 units per dollar)
func (p PicoDollar) ToMicroDollars() int64 {
return int64(p / 1e6)
}

func (p PicoDollar) String() string {
return fmt.Sprintf("%.12f", p.toDollarsTestOnly())
}
28 changes: 28 additions & 0 deletions pkg/currency/currency_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package currency

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestConversion(t *testing.T) {
initial, err := FromDollars(1.25)
require.NoError(t, err)
require.Equal(t, PicoDollar(1250000000000), initial)

converted := initial.toDollarsTestOnly()
require.Equal(t, 1.25, converted)
}

func TestString(t *testing.T) {
initial, err := FromDollars(1.25)
require.NoError(t, err)
require.Equal(t, "1.250000000000", initial.String())
}

func TestToMicroDollars(t *testing.T) {
initial, err := FromDollars(1.25)
require.NoError(t, err)
require.Equal(t, int64(1250000), initial.ToMicroDollars())
}
77 changes: 77 additions & 0 deletions pkg/fees/calculator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package fees

import (
"fmt"
"time"

"github.com/xmtp/xmtpd/pkg/currency"
)

type FeeCalculator struct {
ratesFetcher IRatesFetcher
}

func NewFeeCalculator(ratesFetcher IRatesFetcher) *FeeCalculator {
return &FeeCalculator{ratesFetcher: ratesFetcher}
}

func (c *FeeCalculator) CalculateBaseFee(
messageTime time.Time,
messageSize int64,
storageDurationDays int64,
) (currency.PicoDollar, error) {
if messageSize <= 0 {
return 0, fmt.Errorf("messageSize must be greater than 0, got %d", messageSize)
}
if storageDurationDays <= 0 {
return 0, fmt.Errorf(
"storageDurationDays must be greater than 0, got %d",
storageDurationDays,
)
}

rates, err := c.ratesFetcher.GetRates(messageTime)
if err != nil {
return 0, err
}

// Calculate storage fee components separately to check for overflow
storageFeePerByte := rates.StorageFee * currency.PicoDollar(messageSize)
if storageFeePerByte/currency.PicoDollar(messageSize) != rates.StorageFee {
return 0, fmt.Errorf("storage fee calculation overflow")
}

totalStorageFee := storageFeePerByte * currency.PicoDollar(storageDurationDays)
if totalStorageFee/currency.PicoDollar(storageDurationDays) != storageFeePerByte {
return 0, fmt.Errorf("storage fee calculation overflow")
}

return rates.MessageFee + totalStorageFee, nil
}

func (c *FeeCalculator) CalculateCongestionFee(
messageTime time.Time,
congestionUnits int64,
) (currency.PicoDollar, error) {
if congestionUnits < 0 || congestionUnits > 100 {
return 0, fmt.Errorf(
"congestionPercent must be between 0 and 100, got %d",
congestionUnits,
)
}

if congestionUnits == 0 {
return 0, nil
}

rates, err := c.ratesFetcher.GetRates(messageTime)
if err != nil {
return 0, err
}

result := rates.CongestionFee * currency.PicoDollar(congestionUnits)
if result/currency.PicoDollar(congestionUnits) != rates.CongestionFee {
return 0, fmt.Errorf("congestion fee calculation overflow")
}
return result, nil
}
85 changes: 85 additions & 0 deletions pkg/fees/calculator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package fees

import (
"math"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/currency"
)

const (
RATE_MESSAGE_FEE = 100
RATE_STORAGE_FEE = 50
RATE_CONGESTION_FEE = 200
)

func setupCalculator() *FeeCalculator {
rates := &Rates{
MessageFee: RATE_MESSAGE_FEE,
StorageFee: RATE_STORAGE_FEE,
CongestionFee: RATE_CONGESTION_FEE,
}
ratesFetcher := NewFixedRatesFetcher(rates)
return NewFeeCalculator(ratesFetcher)
}

func TestCalculateBaseFee(t *testing.T) {
calculator := setupCalculator()

messageTime := time.Now()
messageSize := int64(100)
storageDurationDays := int64(1)

baseFee, err := calculator.CalculateBaseFee(messageTime, messageSize, storageDurationDays)
require.NoError(t, err)

expectedFee := RATE_MESSAGE_FEE + (RATE_STORAGE_FEE * messageSize * storageDurationDays)
require.Equal(t, currency.PicoDollar(expectedFee), baseFee)
}

func TestCalculateCongestionFee(t *testing.T) {
calculator := setupCalculator()

messageTime := time.Now()
congestionPercent := int64(50)

congestionFee, err := calculator.CalculateCongestionFee(messageTime, congestionPercent)
require.NoError(t, err)

expectedFee := RATE_CONGESTION_FEE * congestionPercent
require.Equal(t, currency.PicoDollar(expectedFee), congestionFee)
}

func TestOverflow(t *testing.T) {
calculator := setupCalculator()

messageTime := time.Now()

// Test overflow in CalculateBaseFee
messageSize := math.MaxInt64
storageDurationDays := math.MaxInt64

_, err := calculator.CalculateBaseFee(
messageTime,
int64(messageSize),
int64(storageDurationDays),
)
require.Error(t, err)
require.Contains(t, err.Error(), "storage fee calculation overflow")
}

func TestInvalidCongestionPercent(t *testing.T) {
calculator := setupCalculator()

messageTime := time.Now()

_, err := calculator.CalculateCongestionFee(messageTime, int64(101))
require.Error(t, err)
require.Contains(t, err.Error(), "congestionPercent must be between 0 and 100")

_, err = calculator.CalculateCongestionFee(messageTime, int64(-1))
require.Error(t, err)
require.Contains(t, err.Error(), "congestionPercent must be between 0 and 100")
}
17 changes: 17 additions & 0 deletions pkg/fees/fixedRates.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package fees

import "time"

// A fixed fee schedule that doesn't rely on the blockchain.
// Used primarily for testing, and until we have fees onchain.
type FixedRatesFetcher struct {
rates *Rates
}

func NewFixedRatesFetcher(rates *Rates) *FixedRatesFetcher {
return &FixedRatesFetcher{rates: rates}
}

func (f *FixedRatesFetcher) GetRates(_messageTime time.Time) (*Rates, error) {
return f.rates, nil
}
33 changes: 33 additions & 0 deletions pkg/fees/interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package fees

import (
"time"

"github.com/xmtp/xmtpd/pkg/currency"
)

// Rates containt the cost for each fee component at a given message time.
// Values in the rates struct are denominated in USD PicoDollars
type Rates struct {
MessageFee currency.PicoDollar // The flat per-message fee
StorageFee currency.PicoDollar // The fee per byte-day of storage
CongestionFee currency.PicoDollar // The fee per unit of congestion
}

// The RatesFetcher is responsible for loading the rates for a given message time.
// This allows us to roll out new rates over time, and apply them to messages consistently.
type IRatesFetcher interface {
GetRates(messageTime time.Time) (*Rates, error)
}

type IFeeCalculator interface {
CalculateBaseFee(
messageTime time.Time,
messageSize int64,
storageDurationDays int64,
) (currency.PicoDollar, error)
CalculateCongestionFee(
messageTime time.Time,
congestionUnits int64,
) (currency.PicoDollar, error)
}

0 comments on commit e296a86

Please sign in to comment.