-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
### TL;DR - Introduces a new fee calculation system with currency handling in picodollars (12 decimal precision). - Adds some interfaces for how we can get the current rates from the system. Hardcoded for now, but we'll move to something onchain. ## Notes Handling money in integers feels like the right way to do it, even if we have to convert from PicoDollars to MicroDollars (6 decimals) when we go onchain. We will have to be careful about that conversion to make sure no one can get free messaging by abusing the rounding. ### What changed? - Added a new `currency` package with `PicoDollar` type for precise financial calculations - Created a fee calculation system with components for message, storage, and congestion fees - Implemented a rates fetcher interface with a fixed-rate implementation - Added currency conversion utilities between dollars, picodollars, and microdollars - Included comprehensive test coverage for currency conversions ### How to test? 1. Run the currency package tests: ```go go test ./pkg/currency ``` 2. Verify currency conversions: - Test dollar to picodollar conversion - Validate string representation - Check microdollar conversion 3. Test fee calculations with the fixed rate fetcher implementation ### Why make this change? To provide a precise and reliable fee calculation system that: - Handles currency with high precision (12 decimal places) - Supports different fee components (message, storage, congestion) - Allows for future rate adjustments through the rates fetcher interface - Prevents floating-point arithmetic errors in financial calculations <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a high-precision currency module that supports conversions between standard monetary units and finer subdivisions. - Launched a fee calculation module that computes base and congestion fees based on message parameters and dynamic rate fetching, with robust error handling. - **Tests** - Added comprehensive tests to validate currency conversions and fee calculations, including checks for overflow and invalid input scenarios. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
- Loading branch information
Showing
6 changed files
with
282 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
package currency | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"math" | ||
) | ||
|
||
// Picodollar is a type to represent currency with 12 decimal precision | ||
type PicoDollar int64 | ||
|
||
const ( | ||
PicoDollarsPerDollar = 1e12 | ||
) | ||
|
||
// FromDollars converts a dollar amount (as a float) to Picodollars | ||
// This should mostly be used for testing, and real usage should be done purely in PicoDollars | ||
func FromDollars(dollars float64) (PicoDollar, error) { | ||
if math.IsNaN(dollars) || math.IsInf(dollars, 0) { | ||
return 0, errors.New("invalid dollar amount: must be a finite number") | ||
} | ||
|
||
picodollars := dollars * PicoDollarsPerDollar | ||
if (picodollars < 0 && dollars > 0) || (picodollars > 0 && dollars < 0) { | ||
return 0, errors.New("overflow: dollar amount too large") | ||
} | ||
return PicoDollar(picodollars), nil | ||
} | ||
|
||
// ToDollars converts PicoDollars to a dollar amount (as a float) | ||
func (p PicoDollar) toDollarsTestOnly() float64 { | ||
return float64(p) / PicoDollarsPerDollar | ||
} | ||
|
||
// ToMicroDollars converts PicoDollars to MicroDollars (1e6 units per dollar) | ||
func (p PicoDollar) ToMicroDollars() int64 { | ||
return int64(p / 1e6) | ||
} | ||
|
||
func (p PicoDollar) String() string { | ||
return fmt.Sprintf("%.12f", p.toDollarsTestOnly()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package currency | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestConversion(t *testing.T) { | ||
initial, err := FromDollars(1.25) | ||
require.NoError(t, err) | ||
require.Equal(t, PicoDollar(1250000000000), initial) | ||
|
||
converted := initial.toDollarsTestOnly() | ||
require.Equal(t, 1.25, converted) | ||
} | ||
|
||
func TestString(t *testing.T) { | ||
initial, err := FromDollars(1.25) | ||
require.NoError(t, err) | ||
require.Equal(t, "1.250000000000", initial.String()) | ||
} | ||
|
||
func TestToMicroDollars(t *testing.T) { | ||
initial, err := FromDollars(1.25) | ||
require.NoError(t, err) | ||
require.Equal(t, int64(1250000), initial.ToMicroDollars()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
package fees | ||
|
||
import ( | ||
"fmt" | ||
"time" | ||
|
||
"github.com/xmtp/xmtpd/pkg/currency" | ||
) | ||
|
||
type FeeCalculator struct { | ||
ratesFetcher IRatesFetcher | ||
} | ||
|
||
func NewFeeCalculator(ratesFetcher IRatesFetcher) *FeeCalculator { | ||
return &FeeCalculator{ratesFetcher: ratesFetcher} | ||
} | ||
|
||
func (c *FeeCalculator) CalculateBaseFee( | ||
messageTime time.Time, | ||
messageSize int64, | ||
storageDurationDays int64, | ||
) (currency.PicoDollar, error) { | ||
if messageSize <= 0 { | ||
return 0, fmt.Errorf("messageSize must be greater than 0, got %d", messageSize) | ||
} | ||
if storageDurationDays <= 0 { | ||
return 0, fmt.Errorf( | ||
"storageDurationDays must be greater than 0, got %d", | ||
storageDurationDays, | ||
) | ||
} | ||
|
||
rates, err := c.ratesFetcher.GetRates(messageTime) | ||
if err != nil { | ||
return 0, err | ||
} | ||
|
||
// Calculate storage fee components separately to check for overflow | ||
storageFeePerByte := rates.StorageFee * currency.PicoDollar(messageSize) | ||
if storageFeePerByte/currency.PicoDollar(messageSize) != rates.StorageFee { | ||
return 0, fmt.Errorf("storage fee calculation overflow") | ||
} | ||
|
||
totalStorageFee := storageFeePerByte * currency.PicoDollar(storageDurationDays) | ||
if totalStorageFee/currency.PicoDollar(storageDurationDays) != storageFeePerByte { | ||
return 0, fmt.Errorf("storage fee calculation overflow") | ||
} | ||
|
||
return rates.MessageFee + totalStorageFee, nil | ||
} | ||
|
||
func (c *FeeCalculator) CalculateCongestionFee( | ||
messageTime time.Time, | ||
congestionUnits int64, | ||
) (currency.PicoDollar, error) { | ||
if congestionUnits < 0 || congestionUnits > 100 { | ||
return 0, fmt.Errorf( | ||
"congestionPercent must be between 0 and 100, got %d", | ||
congestionUnits, | ||
) | ||
} | ||
|
||
if congestionUnits == 0 { | ||
return 0, nil | ||
} | ||
|
||
rates, err := c.ratesFetcher.GetRates(messageTime) | ||
if err != nil { | ||
return 0, err | ||
} | ||
|
||
result := rates.CongestionFee * currency.PicoDollar(congestionUnits) | ||
if result/currency.PicoDollar(congestionUnits) != rates.CongestionFee { | ||
return 0, fmt.Errorf("congestion fee calculation overflow") | ||
} | ||
return result, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
package fees | ||
|
||
import ( | ||
"math" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
"github.com/xmtp/xmtpd/pkg/currency" | ||
) | ||
|
||
const ( | ||
RATE_MESSAGE_FEE = 100 | ||
RATE_STORAGE_FEE = 50 | ||
RATE_CONGESTION_FEE = 200 | ||
) | ||
|
||
func setupCalculator() *FeeCalculator { | ||
rates := &Rates{ | ||
MessageFee: RATE_MESSAGE_FEE, | ||
StorageFee: RATE_STORAGE_FEE, | ||
CongestionFee: RATE_CONGESTION_FEE, | ||
} | ||
ratesFetcher := NewFixedRatesFetcher(rates) | ||
return NewFeeCalculator(ratesFetcher) | ||
} | ||
|
||
func TestCalculateBaseFee(t *testing.T) { | ||
calculator := setupCalculator() | ||
|
||
messageTime := time.Now() | ||
messageSize := int64(100) | ||
storageDurationDays := int64(1) | ||
|
||
baseFee, err := calculator.CalculateBaseFee(messageTime, messageSize, storageDurationDays) | ||
require.NoError(t, err) | ||
|
||
expectedFee := RATE_MESSAGE_FEE + (RATE_STORAGE_FEE * messageSize * storageDurationDays) | ||
require.Equal(t, currency.PicoDollar(expectedFee), baseFee) | ||
} | ||
|
||
func TestCalculateCongestionFee(t *testing.T) { | ||
calculator := setupCalculator() | ||
|
||
messageTime := time.Now() | ||
congestionPercent := int64(50) | ||
|
||
congestionFee, err := calculator.CalculateCongestionFee(messageTime, congestionPercent) | ||
require.NoError(t, err) | ||
|
||
expectedFee := RATE_CONGESTION_FEE * congestionPercent | ||
require.Equal(t, currency.PicoDollar(expectedFee), congestionFee) | ||
} | ||
|
||
func TestOverflow(t *testing.T) { | ||
calculator := setupCalculator() | ||
|
||
messageTime := time.Now() | ||
|
||
// Test overflow in CalculateBaseFee | ||
messageSize := math.MaxInt64 | ||
storageDurationDays := math.MaxInt64 | ||
|
||
_, err := calculator.CalculateBaseFee( | ||
messageTime, | ||
int64(messageSize), | ||
int64(storageDurationDays), | ||
) | ||
require.Error(t, err) | ||
require.Contains(t, err.Error(), "storage fee calculation overflow") | ||
} | ||
|
||
func TestInvalidCongestionPercent(t *testing.T) { | ||
calculator := setupCalculator() | ||
|
||
messageTime := time.Now() | ||
|
||
_, err := calculator.CalculateCongestionFee(messageTime, int64(101)) | ||
require.Error(t, err) | ||
require.Contains(t, err.Error(), "congestionPercent must be between 0 and 100") | ||
|
||
_, err = calculator.CalculateCongestionFee(messageTime, int64(-1)) | ||
require.Error(t, err) | ||
require.Contains(t, err.Error(), "congestionPercent must be between 0 and 100") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
package fees | ||
|
||
import "time" | ||
|
||
// A fixed fee schedule that doesn't rely on the blockchain. | ||
// Used primarily for testing, and until we have fees onchain. | ||
type FixedRatesFetcher struct { | ||
rates *Rates | ||
} | ||
|
||
func NewFixedRatesFetcher(rates *Rates) *FixedRatesFetcher { | ||
return &FixedRatesFetcher{rates: rates} | ||
} | ||
|
||
func (f *FixedRatesFetcher) GetRates(_messageTime time.Time) (*Rates, error) { | ||
return f.rates, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package fees | ||
|
||
import ( | ||
"time" | ||
|
||
"github.com/xmtp/xmtpd/pkg/currency" | ||
) | ||
|
||
// Rates containt the cost for each fee component at a given message time. | ||
// Values in the rates struct are denominated in USD PicoDollars | ||
type Rates struct { | ||
MessageFee currency.PicoDollar // The flat per-message fee | ||
StorageFee currency.PicoDollar // The fee per byte-day of storage | ||
CongestionFee currency.PicoDollar // The fee per unit of congestion | ||
} | ||
|
||
// The RatesFetcher is responsible for loading the rates for a given message time. | ||
// This allows us to roll out new rates over time, and apply them to messages consistently. | ||
type IRatesFetcher interface { | ||
GetRates(messageTime time.Time) (*Rates, error) | ||
} | ||
|
||
type IFeeCalculator interface { | ||
CalculateBaseFee( | ||
messageTime time.Time, | ||
messageSize int64, | ||
storageDurationDays int64, | ||
) (currency.PicoDollar, error) | ||
CalculateCongestionFee( | ||
messageTime time.Time, | ||
congestionUnits int64, | ||
) (currency.PicoDollar, error) | ||
} |