Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
AminHP committed Mar 17, 2023
1 parent 2dbd96d commit 917ef62
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
23 changes: 12 additions & 11 deletions gym_mtsim/envs/mt_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import gym
from gym import spaces
from gym.utils import seeding
import sys

from ..simulator import MtSimulator, OrderType


Expand All @@ -30,6 +30,7 @@ def __init__(
fee: Union[float, Callable[[str], float]]=0.0005,
symbol_max_orders: int=1, multiprocessing_processes: Optional[int]=None
) -> None:

# validations
assert len(original_simulator.symbols_data) > 0, "no data available"
assert len(original_simulator.symbols_info) > 0, "no data available"
Expand Down Expand Up @@ -67,17 +68,18 @@ def __init__(

# spaces
self.action_space = spaces.Box(
low=-1e10, high=1e10, dtype=np.float64,
low=-1e2, high=1e2, dtype=np.float64,
shape=(len(self.trading_symbols) * (self.symbol_max_orders + 2),)
) # symbol -> [close_order_i(logit), hold(logit), volume]

INF = 1e10
self.observation_space = spaces.Dict({
'balance': spaces.Box(low=-sys.float_info.max, high=sys.float_info.max, shape=(1,), dtype=np.float64),
'equity': spaces.Box(low=-sys.float_info.max, high=sys.float_info.max, shape=(1,), dtype=np.float64),
'margin': spaces.Box(low=-sys.float_info.max, high=sys.float_info.max, shape=(1,), dtype=np.float64),
'features': spaces.Box(low=-sys.float_info.max, high=sys.float_info.max, shape=self.features_shape, dtype=np.float64),
'balance': spaces.Box(low=-INF, high=INF, shape=(1,), dtype=np.float64),
'equity': spaces.Box(low=-INF, high=INF, shape=(1,), dtype=np.float64),
'margin': spaces.Box(low=-INF, high=INF, shape=(1,), dtype=np.float64),
'features': spaces.Box(low=-INF, high=INF, shape=self.features_shape, dtype=np.float64),
'orders': spaces.Box(
low=-sys.float_info.max, high=sys.float_info.max, dtype=np.float64,
low=-INF, high=INF, dtype=np.float64,
shape=(len(self.trading_symbols), self.symbol_max_orders, 3)
) # symbol, order_i -> [entry_price, volume, profit]
})
Expand Down Expand Up @@ -107,16 +109,16 @@ def reset(self) -> Dict[str, np.ndarray]:

def step(self, action: np.ndarray) -> Tuple[Dict[str, np.ndarray], float, bool, Dict[str, Any]]:
orders_info, closed_orders_info = self._apply_action(action)

self._current_tick += 1

if self._current_tick == self._end_tick:
self._done = True

dt = self.time_points[self._current_tick] - self.time_points[self._current_tick - 1]
self.simulator.tick(dt)

step_reward = self._calculate_reward()

info = self._create_info(
orders=orders_info, closed_orders=closed_orders_info, step_reward=step_reward
)
Expand Down Expand Up @@ -233,9 +235,8 @@ def _get_observation(self) -> Dict[str, np.ndarray]:
def _calculate_reward(self) -> float:
prev_equity = self.history[-1]['equity']
current_equity = self.simulator.equity
step_reward = current_equity - prev_equity
step_reward = current_equity - prev_equity
return step_reward



def _create_info(self, **kwargs: Any) -> Dict[str, Any]:
Expand Down
4 changes: 2 additions & 2 deletions gym_mtsim/simulator/mt_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import pickle
from datetime import datetime, timedelta
import sys

import numpy as np
import pandas as pd

Expand Down Expand Up @@ -47,7 +47,7 @@ def free_margin(self) -> float:
def margin_level(self) -> float:
margin = round(self.margin, 6)
if margin == 0.:
return sys.float_info.max
return float('inf')
return self.equity / margin


Expand Down

0 comments on commit 917ef62

Please sign in to comment.