Skip to content

Make FeeRate in swap pallet type-safe #1771

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: devnet-ready
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pallets/subtensor/src/tests/staking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4878,8 +4878,8 @@ fn test_unstake_full_amount() {
// Check if balance has increased accordingly
let balance_after = SubtensorModule::get_coldkey_balance(&coldkey);
let actual_balance_increase = (balance_after - balance_before) as f64;
let fee_rate = pallet_subtensor_swap::FeeRate::<Test>::get(NetUid::from(netuid)) as f64
/ u16::MAX as f64;
let fee_rate =
pallet_subtensor_swap::FeeRate::<Test>::get(NetUid::from(netuid)).as_normalized_f64();
let expected_balance_increase = amount as f64 * (1. - fee_rate) / (1. + fee_rate);
assert_abs_diff_eq!(
actual_balance_increase,
Expand Down Expand Up @@ -4931,8 +4931,8 @@ fn test_swap_fees_tao_correctness() {
let netuid = add_dynamic_network(&owner_hotkey, &owner_coldkey);
SubtensorModule::add_balance_to_coldkey_account(&owner_coldkey, owner_balance_before);
SubtensorModule::add_balance_to_coldkey_account(&coldkey, user_balance_before);
let fee_rate = pallet_subtensor_swap::FeeRate::<Test>::get(NetUid::from(netuid)) as f64
/ u16::MAX as f64;
let fee_rate =
pallet_subtensor_swap::FeeRate::<Test>::get(NetUid::from(netuid)).as_normalized_f64();
pallet_subtensor_swap::EnabledUserLiquidity::<Test>::insert(NetUid::from(netuid), true);

// Forse-set alpha in and tao reserve to make price equal 0.25
Expand Down Expand Up @@ -5072,8 +5072,8 @@ fn test_default_min_stake_sufficiency() {
let netuid = add_dynamic_network(&owner_hotkey, &owner_coldkey);
SubtensorModule::add_balance_to_coldkey_account(&owner_coldkey, owner_balance_before);
SubtensorModule::add_balance_to_coldkey_account(&coldkey, user_balance_before);
let fee_rate = pallet_subtensor_swap::FeeRate::<Test>::get(NetUid::from(netuid)) as f64
/ u16::MAX as f64;
let fee_rate =
pallet_subtensor_swap::FeeRate::<Test>::get(NetUid::from(netuid)).as_normalized_f64();

// Set some extreme, but realistic TAO and Alpha reserves to minimize slippage
// 1% of TAO max supply
Expand Down
2 changes: 1 addition & 1 deletion pallets/swap/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ frame-system = { workspace = true }
log = { workspace = true }
safe-math = { workspace = true }
scale-info = { workspace = true }
serde = { workspace = true, optional = true }
serde = { workspace = true }
sp-arithmetic = { workspace = true }
sp-core = { workspace = true }
sp-io = { workspace = true }
Expand Down
3 changes: 2 additions & 1 deletion pallets/swap/src/benchmarking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use substrate_fixed::types::{I64F64, U64F64};
use subtensor_runtime_common::NetUid;

use crate::{
FeeRateT,
pallet::{
AlphaSqrtPrice, Call, Config, CurrentLiquidity, CurrentTick, EnabledUserLiquidity, Pallet,
Positions, SwapV3Initialized,
Expand All @@ -26,7 +27,7 @@ mod benchmarks {
#[benchmark]
fn set_fee_rate() {
let netuid = NetUid::from(1);
let rate: u16 = 100; // Some arbitrary fee rate value
let rate = FeeRateT::from(100); // Some arbitrary fee rate value

#[extrinsic_call]
set_fee_rate(RawOrigin::Root, netuid, rate);
Expand Down
102 changes: 102 additions & 0 deletions pallets/swap/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
#![cfg_attr(not(feature = "std"), no_std)]
use core::fmt::{self, Display, Formatter};
use core::ops::{Add, Mul};

use codec::{Compact, CompactAs, Decode, Encode, Error as CodecError, MaxEncodedLen};
use frame_support::pallet_prelude::*;
use safe_math::*;
use scale_info::TypeInfo;
use serde::{Deserialize, Serialize};
use substrate_fixed::types::U64F64;
use subtensor_macros::freeze_struct;
use subtensor_swap_interface::OrderType;

pub mod pallet;
Expand All @@ -17,3 +25,97 @@ pub mod benchmarking;
pub(crate) mod mock;

type SqrtPrice = U64F64;

#[freeze_struct("91109ca21993a3bf")]
#[repr(transparent)]
#[derive(
Deserialize,
Serialize,
Clone,
Copy,
Decode,
Default,
Encode,
Eq,
Hash,
MaxEncodedLen,
Ord,
PartialEq,
PartialOrd,
RuntimeDebug,
TypeInfo,
)]
#[serde(transparent)]
pub struct FeeRateT(u16);

impl FeeRateT {
pub fn as_normalized_f64(&self) -> f64 {
(self.0 as f64) / (u16::MAX as f64)
}

pub fn as_normalized_fixed(&self) -> U64F64 {
U64F64::saturating_from_num(self.0).safe_div(U64F64::from_num(u16::MAX))
}

pub fn as_f64(&self) -> f64 {
self.0 as f64
}

pub fn as_fixed(&self) -> U64F64 {
U64F64::saturating_from_num(self.0)
}
}

impl Display for FeeRateT {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Display::fmt(&self.as_normalized_f64(), f)
}
}

impl CompactAs for FeeRateT {
type As = u16;

fn encode_as(&self) -> &Self::As {
&self.0
}

fn decode_from(v: Self::As) -> Result<Self, CodecError> {
Ok(Self(v))
}
}

impl From<Compact<FeeRateT>> for FeeRateT {
fn from(c: Compact<FeeRateT>) -> Self {
c.0
}
}

impl From<FeeRateT> for u16 {
fn from(val: FeeRateT) -> Self {
val.0
}
}

impl From<u16> for FeeRateT {
fn from(value: u16) -> Self {
Self(value)
}
}

impl Add<FeeRateT> for FeeRateT {
type Output = Self;

#[allow(clippy::arithmetic_side_effects)]
fn add(self, rhs: Self) -> Self::Output {
(self.0 + rhs.0).into()
}
}

impl Mul<FeeRateT> for FeeRateT {
type Output = Self;

#[allow(clippy::arithmetic_side_effects)]
fn mul(self, rhs: Self) -> Self::Output {
(self.0 * rhs.0).into()
}
}
4 changes: 2 additions & 2 deletions pallets/swap/src/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use sp_runtime::{
};
use subtensor_runtime_common::{BalanceOps, NetUid, SubnetInfo};

use crate::pallet::EnabledUserLiquidity;
use crate::{FeeRateT, pallet::EnabledUserLiquidity};

construct_runtime!(
pub enum Test {
Expand Down Expand Up @@ -77,7 +77,7 @@ impl system::Config for Test {

parameter_types! {
pub const SwapProtocolId: PalletId = PalletId(*b"ten/swap");
pub const MaxFeeRate: u16 = 10000; // 15.26%
pub static MaxFeeRate: FeeRateT = FeeRateT::from(10000); // 15.26%
pub const MaxPositions: u32 = 100;
pub const MinimumLiquidity: u64 = 1_000;
pub const MinimumReserves: NonZeroU64 = NonZeroU64::new(1).unwrap();
Expand Down
5 changes: 2 additions & 3 deletions pallets/swap/src/pallet/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ impl<T: Config> SwapStep<T> {
// in case if we hit the limit price or the edge price.
if recalculate_fee {
let u16_max = U64F64::saturating_from_num(u16::MAX);
let fee_rate = U64F64::saturating_from_num(FeeRate::<T>::get(self.netuid));
let fee_rate = FeeRate::<T>::get(self.netuid).as_fixed();
let delta_fixed = U64F64::saturating_from_num(self.delta_in);
self.fee = delta_fixed
.saturating_mul(fee_rate.safe_div(u16_max.saturating_sub(fee_rate)))
Expand Down Expand Up @@ -518,8 +518,7 @@ impl<T: Config> Pallet<T> {
///
/// Fee is provided by state ops as u16-normalized value.
fn calculate_fee_amount(netuid: NetUid, amount: u64) -> u64 {
let fee_rate = U64F64::saturating_from_num(FeeRate::<T>::get(netuid))
.safe_div(U64F64::saturating_from_num(u16::MAX));
let fee_rate = FeeRate::<T>::get(netuid).as_normalized_fixed();

U64F64::saturating_from_num(amount)
.saturating_mul(fee_rate)
Expand Down
17 changes: 11 additions & 6 deletions pallets/swap/src/pallet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use substrate_fixed::types::U64F64;
use subtensor_runtime_common::{BalanceOps, NetUid, SubnetInfo};

use crate::{
FeeRateT,
position::{Position, PositionId},
tick::{LayerLevel, Tick, TickIndex},
weights::WeightInfo,
Expand Down Expand Up @@ -46,7 +47,7 @@ mod pallet {

/// The maximum fee rate that can be set
#[pallet::constant]
type MaxFeeRate: Get<u16>;
type MaxFeeRate: Get<FeeRateT>;

/// The maximum number of positions a user can have
#[pallet::constant]
Expand All @@ -66,13 +67,13 @@ mod pallet {

/// Default fee rate if not set
#[pallet::type_value]
pub fn DefaultFeeRate() -> u16 {
196 // 0.3 %
pub fn DefaultFeeRate() -> FeeRateT {
196.into() // 0.3 %
}

/// The fee rate applied to swaps per subnet, normalized value between 0 and u16::MAX
#[pallet::storage]
pub type FeeRate<T> = StorageMap<_, Twox64Concat, NetUid, u16, ValueQuery, DefaultFeeRate>;
pub type FeeRate<T> = StorageMap<_, Twox64Concat, NetUid, FeeRateT, ValueQuery, DefaultFeeRate>;

// Global accrued fees in tao per subnet
#[pallet::storage]
Expand Down Expand Up @@ -143,7 +144,7 @@ mod pallet {
#[pallet::generate_deposit(pub(super) fn deposit_event)]
pub enum Event<T: Config> {
/// Event emitted when the fee rate has been updated for a subnet
FeeRateSet { netuid: NetUid, rate: u16 },
FeeRateSet { netuid: NetUid, rate: FeeRateT },

/// Event emitted when user liquidity operations are enabled for a subnet.
/// First enable even indicates a switch from V2 to V3 swap.
Expand Down Expand Up @@ -237,7 +238,11 @@ mod pallet {
/// Only callable by the admin origin
#[pallet::call_index(0)]
#[pallet::weight(<T as pallet::Config>::WeightInfo::set_fee_rate())]
pub fn set_fee_rate(origin: OriginFor<T>, netuid: NetUid, rate: u16) -> DispatchResult {
pub fn set_fee_rate(
origin: OriginFor<T>,
netuid: NetUid,
rate: FeeRateT,
) -> DispatchResult {
if ensure_root(origin.clone()).is_err() {
let account_id: T::AccountId = ensure_signed(origin)?;
ensure!(
Expand Down
16 changes: 8 additions & 8 deletions pallets/swap/src/pallet/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use substrate_fixed::types::U96F32;
use subtensor_runtime_common::NetUid;

use super::*;
use crate::{OrderType, SqrtPrice, mock::*};
use crate::{FeeRateT, OrderType, SqrtPrice, mock::*};

// this function is used to convert price (NON-SQRT price!) to TickIndex. it's only utility for
// testing, all the implementation logic is based on sqrt prices
Expand Down Expand Up @@ -80,7 +80,7 @@ mod dispatchables {
fn test_set_fee_rate() {
new_test_ext().execute_with(|| {
let netuid = NetUid::from(1);
let fee_rate = 500; // 0.76% fee
let fee_rate = FeeRateT::from(500); // 0.76% fee

assert_noop!(
Swap::set_fee_rate(RuntimeOrigin::signed(666), netuid, fee_rate),
Expand All @@ -92,7 +92,7 @@ mod dispatchables {
// Check that fee rate was set correctly
assert_eq!(FeeRate::<Test>::get(netuid), fee_rate);

let fee_rate = fee_rate * 2;
let fee_rate = fee_rate * 2.into();
assert_ok!(Swap::set_fee_rate(
RuntimeOrigin::signed(1),
netuid,
Expand All @@ -101,7 +101,7 @@ mod dispatchables {
assert_eq!(FeeRate::<Test>::get(netuid), fee_rate);

// Verify fee rate validation - should fail if too high
let too_high_fee = MaxFeeRate::get() + 1;
let too_high_fee = MaxFeeRate::get() + FeeRateT::from(1);
assert_noop!(
Swap::set_fee_rate(RuntimeOrigin::root(), netuid, too_high_fee),
Error::<Test>::FeeRateTooHigh
Expand Down Expand Up @@ -740,7 +740,7 @@ fn test_swap_basic() {
);

// Expected fee amount
let fee_rate = FeeRate::<Test>::get(netuid) as f64 / u16::MAX as f64;
let fee_rate = FeeRate::<Test>::get(netuid).as_normalized_f64();
let expected_fee = (liquidity as f64 * fee_rate) as u64;

// Global fees should be updated
Expand Down Expand Up @@ -1002,7 +1002,7 @@ fn test_swap_single_position() {
);

// Expected fee amount
let fee_rate = FeeRate::<Test>::get(netuid) as f64 / u16::MAX as f64;
let fee_rate = FeeRate::<Test>::get(netuid).as_normalized_f64();
let expected_fee =
(order_liquidity - order_liquidity / (1.0 + fee_rate)) as u64;

Expand Down Expand Up @@ -1455,7 +1455,7 @@ fn test_swap_fee_correctness() {
assert_eq!(position.tick_high, tick_high);

// Check that 50% of fees were credited to the position
let fee_rate = FeeRate::<Test>::get(NetUid::from(netuid)) as f64 / u16::MAX as f64;
let fee_rate = FeeRate::<Test>::get(NetUid::from(netuid)).as_normalized_f64();
let (actual_fee_tao, actual_fee_alpha) = position.collect_fees();
let expected_fee = (fee_rate * (liquidity / 10) as f64 * 0.5) as u64;

Expand Down Expand Up @@ -1701,7 +1701,7 @@ fn test_wrapping_fees() {
position.tick_high.try_to_sqrt_price().unwrap(),
);

let fee_rate = FeeRate::<Test>::get(netuid) as f64 / u16::MAX as f64;
let fee_rate = FeeRate::<Test>::get(netuid).as_normalized_f64();

log::trace!("fee_rate: {:.6}", fee_rate);
log::trace!("position.liquidity: {}", position.liquidity);
Expand Down
Loading