-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
282 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
package currency | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"math" | ||
) | ||
|
||
// Picodollar is a type to represent currency with 12 decimal precision | ||
type PicoDollar int64 | ||
|
||
const ( | ||
PicoDollarsPerDollar = 1e12 | ||
) | ||
|
||
// FromDollars converts a dollar amount (as a float) to Picodollars | ||
// This should mostly be used for testing, and real usage should be done purely in PicoDollars | ||
func FromDollars(dollars float64) (PicoDollar, error) { | ||
if math.IsNaN(dollars) || math.IsInf(dollars, 0) { | ||
return 0, errors.New("invalid dollar amount: must be a finite number") | ||
} | ||
|
||
picodollars := dollars * PicoDollarsPerDollar | ||
if (picodollars < 0 && dollars > 0) || (picodollars > 0 && dollars < 0) { | ||
return 0, errors.New("overflow: dollar amount too large") | ||
} | ||
return PicoDollar(picodollars), nil | ||
} | ||
|
||
// ToDollars converts PicoDollars to a dollar amount (as a float) | ||
func (p PicoDollar) toDollarsTestOnly() float64 { | ||
return float64(p) / PicoDollarsPerDollar | ||
} | ||
|
||
// ToMicroDollars converts PicoDollars to MicroDollars (1e6 units per dollar) | ||
func (p PicoDollar) ToMicroDollars() int64 { | ||
return int64(p / 1e6) | ||
} | ||
|
||
func (p PicoDollar) String() string { | ||
return fmt.Sprintf("%.12f", p.toDollarsTestOnly()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package currency | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestConversion(t *testing.T) { | ||
initial, err := FromDollars(1.25) | ||
require.NoError(t, err) | ||
require.Equal(t, PicoDollar(1250000000000), initial) | ||
|
||
converted := initial.toDollarsTestOnly() | ||
require.Equal(t, 1.25, converted) | ||
} | ||
|
||
func TestString(t *testing.T) { | ||
initial, err := FromDollars(1.25) | ||
require.NoError(t, err) | ||
require.Equal(t, "1.250000000000", initial.String()) | ||
} | ||
|
||
func TestToMicroDollars(t *testing.T) { | ||
initial, err := FromDollars(1.25) | ||
require.NoError(t, err) | ||
require.Equal(t, int64(1250000), initial.ToMicroDollars()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
package fees | ||
|
||
import ( | ||
"fmt" | ||
"time" | ||
|
||
"github.com/xmtp/xmtpd/pkg/currency" | ||
) | ||
|
||
type FeeCalculator struct { | ||
ratesFetcher IRatesFetcher | ||
} | ||
|
||
func NewFeeCalculator(ratesFetcher IRatesFetcher) *FeeCalculator { | ||
return &FeeCalculator{ratesFetcher: ratesFetcher} | ||
} | ||
|
||
func (c *FeeCalculator) CalculateBaseFee( | ||
messageTime time.Time, | ||
messageSize int64, | ||
storageDurationDays int64, | ||
) (currency.PicoDollar, error) { | ||
if messageSize <= 0 { | ||
return 0, fmt.Errorf("messageSize must be greater than 0, got %d", messageSize) | ||
} | ||
if storageDurationDays <= 0 { | ||
return 0, fmt.Errorf( | ||
"storageDurationDays must be greater than 0, got %d", | ||
storageDurationDays, | ||
) | ||
} | ||
|
||
rates, err := c.ratesFetcher.GetRates(messageTime) | ||
if err != nil { | ||
return 0, err | ||
} | ||
|
||
// Calculate storage fee components separately to check for overflow | ||
storageFeePerByte := rates.StorageFee * currency.PicoDollar(messageSize) | ||
if storageFeePerByte/currency.PicoDollar(messageSize) != rates.StorageFee { | ||
return 0, fmt.Errorf("storage fee calculation overflow") | ||
} | ||
|
||
totalStorageFee := storageFeePerByte * currency.PicoDollar(storageDurationDays) | ||
if totalStorageFee/currency.PicoDollar(storageDurationDays) != storageFeePerByte { | ||
return 0, fmt.Errorf("storage fee calculation overflow") | ||
} | ||
|
||
return rates.MessageFee + totalStorageFee, nil | ||
} | ||
|
||
func (c *FeeCalculator) CalculateCongestionFee( | ||
messageTime time.Time, | ||
congestionUnits int64, | ||
) (currency.PicoDollar, error) { | ||
if congestionUnits < 0 || congestionUnits > 100 { | ||
return 0, fmt.Errorf( | ||
"congestionPercent must be between 0 and 100, got %d", | ||
congestionUnits, | ||
) | ||
} | ||
|
||
if congestionUnits == 0 { | ||
return 0, nil | ||
} | ||
|
||
rates, err := c.ratesFetcher.GetRates(messageTime) | ||
if err != nil { | ||
return 0, err | ||
} | ||
|
||
result := rates.CongestionFee * currency.PicoDollar(congestionUnits) | ||
if result/currency.PicoDollar(congestionUnits) != rates.CongestionFee { | ||
return 0, fmt.Errorf("congestion fee calculation overflow") | ||
} | ||
return result, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
package fees | ||
|
||
import ( | ||
"math" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
"github.com/xmtp/xmtpd/pkg/currency" | ||
) | ||
|
||
const ( | ||
RATE_MESSAGE_FEE = 100 | ||
RATE_STORAGE_FEE = 50 | ||
RATE_CONGESTION_FEE = 200 | ||
) | ||
|
||
func setupCalculator() *FeeCalculator { | ||
rates := &Rates{ | ||
MessageFee: RATE_MESSAGE_FEE, | ||
StorageFee: RATE_STORAGE_FEE, | ||
CongestionFee: RATE_CONGESTION_FEE, | ||
} | ||
ratesFetcher := NewFixedRatesFetcher(rates) | ||
return NewFeeCalculator(ratesFetcher) | ||
} | ||
|
||
func TestCalculateBaseFee(t *testing.T) { | ||
calculator := setupCalculator() | ||
|
||
messageTime := time.Now() | ||
messageSize := int64(100) | ||
storageDurationDays := int64(1) | ||
|
||
baseFee, err := calculator.CalculateBaseFee(messageTime, messageSize, storageDurationDays) | ||
require.NoError(t, err) | ||
|
||
expectedFee := RATE_MESSAGE_FEE + (RATE_STORAGE_FEE * messageSize * storageDurationDays) | ||
require.Equal(t, currency.PicoDollar(expectedFee), baseFee) | ||
} | ||
|
||
func TestCalculateCongestionFee(t *testing.T) { | ||
calculator := setupCalculator() | ||
|
||
messageTime := time.Now() | ||
congestionPercent := int64(50) | ||
|
||
congestionFee, err := calculator.CalculateCongestionFee(messageTime, congestionPercent) | ||
require.NoError(t, err) | ||
|
||
expectedFee := RATE_CONGESTION_FEE * congestionPercent | ||
require.Equal(t, currency.PicoDollar(expectedFee), congestionFee) | ||
} | ||
|
||
func TestOverflow(t *testing.T) { | ||
calculator := setupCalculator() | ||
|
||
messageTime := time.Now() | ||
|
||
// Test overflow in CalculateBaseFee | ||
messageSize := math.MaxInt64 | ||
storageDurationDays := math.MaxInt64 | ||
|
||
_, err := calculator.CalculateBaseFee( | ||
messageTime, | ||
int64(messageSize), | ||
int64(storageDurationDays), | ||
) | ||
require.Error(t, err) | ||
require.Contains(t, err.Error(), "storage fee calculation overflow") | ||
} | ||
|
||
func TestInvalidCongestionPercent(t *testing.T) { | ||
calculator := setupCalculator() | ||
|
||
messageTime := time.Now() | ||
|
||
_, err := calculator.CalculateCongestionFee(messageTime, int64(101)) | ||
require.Error(t, err) | ||
require.Contains(t, err.Error(), "congestionPercent must be between 0 and 100") | ||
|
||
_, err = calculator.CalculateCongestionFee(messageTime, int64(-1)) | ||
require.Error(t, err) | ||
require.Contains(t, err.Error(), "congestionPercent must be between 0 and 100") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
package fees | ||
|
||
import "time" | ||
|
||
// A fixed fee schedule that doesn't rely on the blockchain. | ||
// Used primarily for testing, and until we have fees onchain. | ||
type FixedRatesFetcher struct { | ||
rates *Rates | ||
} | ||
|
||
func NewFixedRatesFetcher(rates *Rates) *FixedRatesFetcher { | ||
return &FixedRatesFetcher{rates: rates} | ||
} | ||
|
||
func (f *FixedRatesFetcher) GetRates(_messageTime time.Time) (*Rates, error) { | ||
return f.rates, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package fees | ||
|
||
import ( | ||
"time" | ||
|
||
"github.com/xmtp/xmtpd/pkg/currency" | ||
) | ||
|
||
// Rates containt the cost for each fee component at a given message time. | ||
// Values in the rates struct are denominated in USD PicoDollars | ||
type Rates struct { | ||
MessageFee currency.PicoDollar // The flat per-message fee | ||
StorageFee currency.PicoDollar // The fee per byte-day of storage | ||
CongestionFee currency.PicoDollar // The fee per unit of congestion | ||
} | ||
|
||
// The RatesFetcher is responsible for loading the rates for a given message time. | ||
// This allows us to roll out new rates over time, and apply them to messages consistently. | ||
type IRatesFetcher interface { | ||
GetRates(messageTime time.Time) (*Rates, error) | ||
} | ||
|
||
type IFeeCalculator interface { | ||
CalculateBaseFee( | ||
messageTime time.Time, | ||
messageSize int64, | ||
storageDurationDays int64, | ||
) (currency.PicoDollar, error) | ||
CalculateCongestionFee( | ||
messageTime time.Time, | ||
congestionUnits int64, | ||
) (currency.PicoDollar, error) | ||
} |