Skip to content

Commit bfcaa91

Browse files
authored
Merge pull request vnpy#5 from Edanflame/stream
添加tickoverview并增加流式写入参数
2 parents b019c71 + 31daedc commit bfcaa91

File tree

1 file changed

+73
-3
lines changed

1 file changed

+73
-3
lines changed

vnpy_postgresql/postgresql_database.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from vnpy.trader.database import (
1919
BaseDatabase,
2020
BarOverview,
21+
TickOverview,
2122
DB_TZ,
2223
convert_tz
2324
)
@@ -128,16 +129,32 @@ class Meta:
128129
indexes = ((("symbol", "exchange", "interval"), True),)
129130

130131

132+
class DbTickOverview(Model):
133+
"""Tick汇总数据表映射对象"""
134+
135+
id: AutoField = AutoField()
136+
137+
symbol: str = CharField()
138+
exchange: str = CharField()
139+
count: int = IntegerField()
140+
start: datetime = DateTimeField()
141+
end: datetime = DateTimeField()
142+
143+
class Meta:
144+
database: PeeweePostgresqlDatabase = db
145+
indexes: tuple = ((("symbol", "exchange"), True),)
146+
147+
131148
class PostgresqlDatabase(BaseDatabase):
132149
"""PostgreSQL数据库接口"""
133150

134151
def __init__(self) -> None:
135152
""""""
136153
self.db = db
137154
self.db.connect()
138-
self.db.create_tables([DbBarData, DbTickData, DbBarOverview])
155+
self.db.create_tables([DbBarData, DbTickData, DbBarOverview, DbTickOverview])
139156

140-
def save_bar_data(self, bars: List[BarData]) -> bool:
157+
def save_bar_data(self, bars: List[BarData], stream: bool = False) -> bool:
141158
"""保存K线数据"""
142159
# 读取主键参数
143160
bar = bars[0]
@@ -186,6 +203,9 @@ def save_bar_data(self, bars: List[BarData]) -> bool:
186203
overview.start = bars[0].datetime
187204
overview.end = bars[-1].datetime
188205
overview.count = len(bars)
206+
elif stream:
207+
overview.end = bars[-1].datetime
208+
overview.count += len(bars)
189209
else:
190210
overview.start = min(bars[0].datetime, overview.start)
191211
overview.end = max(bars[-1].datetime, overview.end)
@@ -201,8 +221,13 @@ def save_bar_data(self, bars: List[BarData]) -> bool:
201221

202222
return True
203223

204-
def save_tick_data(self, ticks: List[TickData]) -> bool:
224+
def save_tick_data(self, ticks: List[TickData], stream: bool = False) -> bool:
205225
"""保存TICK数据"""
226+
# 读取主键参数
227+
tick: TickData = ticks[0]
228+
symbol: str = tick.symbol
229+
exchange: Exchange = tick.exchange
230+
206231
# 将TickData数据转换为字典,并调整时区
207232
data = []
208233

@@ -227,6 +252,34 @@ def save_tick_data(self, ticks: List[TickData]) -> bool:
227252
),
228253
).execute()
229254

255+
# 更新Tick汇总数据
256+
overview: DbTickOverview = DbTickOverview.get_or_none(
257+
DbTickOverview.symbol == symbol,
258+
DbTickOverview.exchange == exchange.value,
259+
)
260+
261+
if not overview:
262+
overview: DbTickOverview = DbTickOverview()
263+
overview.symbol = symbol
264+
overview.exchange = exchange.value
265+
overview.start = ticks[0].datetime
266+
overview.end = ticks[-1].datetime
267+
overview.count = len(ticks)
268+
elif stream:
269+
overview.end = ticks[-1].datetime
270+
overview.count += len(ticks)
271+
else:
272+
overview.start = min(ticks[0].datetime, overview.start)
273+
overview.end = max(ticks[-1].datetime, overview.end)
274+
275+
s: ModelSelect = DbTickData.select().where(
276+
(DbTickData.symbol == symbol)
277+
& (DbTickData.exchange == exchange.value)
278+
)
279+
overview.count = s.count()
280+
281+
overview.save()
282+
230283
return True
231284

232285
def load_bar_data(
@@ -364,6 +417,14 @@ def delete_tick_data(
364417
& (DbTickData.exchange == exchange.value)
365418
)
366419
count = d.execute()
420+
421+
# 删除Tick汇总数据
422+
d2: ModelDelete = DbTickOverview.delete().where(
423+
(DbTickOverview.symbol == symbol)
424+
& (DbTickOverview.exchange == exchange.value)
425+
)
426+
d2.execute()
427+
367428
return count
368429

369430
def get_bar_overview(self) -> List[BarOverview]:
@@ -382,6 +443,15 @@ def get_bar_overview(self) -> List[BarOverview]:
382443
overviews.append(overview)
383444
return overviews
384445

446+
def get_tick_overview(self) -> List[TickOverview]:
447+
"""查询数据库中的Tick汇总信息"""
448+
s: ModelSelect = DbTickOverview.select()
449+
overviews: list = []
450+
for overview in s:
451+
overview.exchange = Exchange(overview.exchange)
452+
overviews.append(overview)
453+
return overviews
454+
385455
def init_bar_overview(self) -> None:
386456
"""初始化数据库中的K线汇总信息"""
387457
s: ModelSelect = (

0 commit comments

Comments
 (0)