Skip to content

Commit a54ca2b

Browse files
authored
Merge pull request vnpy#3 from noranhe/main
[Add] 类型声明
2 parents bfcaa91 + 81ee463 commit a54ca2b

File tree

1 file changed

+27
-27
lines changed

1 file changed

+27
-27
lines changed

vnpy_postgresql/postgresql_database.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from vnpy.trader.setting import SETTINGS
2626

2727

28-
db = PeeweePostgresqlDatabase(
28+
db: PeeweePostgresqlDatabase = PeeweePostgresqlDatabase(
2929
database=SETTINGS["database.database"],
3030
user=SETTINGS["database.user"],
3131
password=SETTINGS["database.password"],
@@ -38,7 +38,7 @@
3838
class DbBarData(Model):
3939
"""K线数据表映射对象"""
4040

41-
id = AutoField()
41+
id: AutoField = AutoField()
4242

4343
symbol: str = CharField()
4444
exchange: str = CharField()
@@ -54,14 +54,14 @@ class DbBarData(Model):
5454
close_price: float = FloatField()
5555

5656
class Meta:
57-
database = db
58-
indexes = ((("symbol", "exchange", "interval", "datetime"), True),)
57+
database: PeeweePostgresqlDatabase = db
58+
indexes: tuple = ((("symbol", "exchange", "interval", "datetime"), True),)
5959

6060

6161
class DbTickData(Model):
6262
"""TICK数据表映射对象"""
6363

64-
id = AutoField()
64+
id: AutoField = AutoField()
6565

6666
symbol: str = CharField()
6767
exchange: str = CharField()
@@ -108,14 +108,14 @@ class DbTickData(Model):
108108
localtime: datetime = DateTimeField(null=True)
109109

110110
class Meta:
111-
database = db
112-
indexes = ((("symbol", "exchange", "datetime"), True),)
111+
database: PeeweePostgresqlDatabase = db
112+
indexes: tuple = ((("symbol", "exchange", "datetime"), True),)
113113

114114

115115
class DbBarOverview(Model):
116116
"""K线汇总数据表映射对象"""
117117

118-
id = AutoField()
118+
id: AutoField = AutoField()
119119

120120
symbol: str = CharField()
121121
exchange: str = CharField()
@@ -125,8 +125,8 @@ class DbBarOverview(Model):
125125
end: datetime = DateTimeField()
126126

127127
class Meta:
128-
database = db
129-
indexes = ((("symbol", "exchange", "interval"), True),)
128+
database: PeeweePostgresqlDatabase = db
129+
indexes: tuple = ((("symbol", "exchange", "interval"), True),)
130130

131131

132132
class DbTickOverview(Model):
@@ -150,25 +150,25 @@ class PostgresqlDatabase(BaseDatabase):
150150

151151
def __init__(self) -> None:
152152
""""""
153-
self.db = db
153+
self.db: PeeweePostgresqlDatabase = db
154154
self.db.connect()
155155
self.db.create_tables([DbBarData, DbTickData, DbBarOverview, DbTickOverview])
156156

157157
def save_bar_data(self, bars: List[BarData], stream: bool = False) -> bool:
158158
"""保存K线数据"""
159159
# 读取主键参数
160-
bar = bars[0]
161-
symbol = bar.symbol
162-
exchange = bar.exchange
163-
interval = bar.interval
160+
bar: BarData = bars[0]
161+
symbol: str = bar.symbol
162+
exchange: Exchange = bar.exchange
163+
interval: Interval = bar.interval
164164

165165
# 将BarData数据转换为字典,并调整时区
166-
data = []
166+
data: list = []
167167

168168
for bar in bars:
169169
bar.datetime = convert_tz(bar.datetime)
170170

171-
d = bar.__dict__
171+
d: dict = bar.__dict__
172172
d["exchange"] = d["exchange"].value
173173
d["interval"] = d["interval"].value
174174
d.pop("gateway_name")
@@ -229,12 +229,12 @@ def save_tick_data(self, ticks: List[TickData], stream: bool = False) -> bool:
229229
exchange: Exchange = tick.exchange
230230

231231
# 将TickData数据转换为字典,并调整时区
232-
data = []
232+
data: list = []
233233

234234
for tick in ticks:
235235
tick.datetime = convert_tz(tick.datetime)
236236

237-
d = tick.__dict__
237+
d: dict = tick.__dict__
238238
d["exchange"] = d["exchange"].value
239239
d.pop("gateway_name")
240240
d.pop("vt_symbol")
@@ -303,7 +303,7 @@ def load_bar_data(
303303

304304
bars: List[BarData] = []
305305
for db_bar in s:
306-
bar = BarData(
306+
bar: BarData = BarData(
307307
symbol=db_bar.symbol,
308308
exchange=Exchange(db_bar.exchange),
309309
datetime=datetime.fromtimestamp(db_bar.datetime.timestamp(), DB_TZ),
@@ -340,7 +340,7 @@ def load_tick_data(
340340

341341
ticks: List[TickData] = []
342342
for db_tick in s:
343-
tick = TickData(
343+
tick: TickData = TickData(
344344
symbol=db_tick.symbol,
345345
exchange=Exchange(db_tick.exchange),
346346
datetime=datetime.fromtimestamp(db_tick.datetime.timestamp(), DB_TZ),
@@ -395,7 +395,7 @@ def delete_bar_data(
395395
& (DbBarData.exchange == exchange.value)
396396
& (DbBarData.interval == interval.value)
397397
)
398-
count = d.execute()
398+
count: int = d.execute()
399399

400400
# 删除K线汇总数据
401401
d2: ModelDelete = DbBarOverview.delete().where(
@@ -416,7 +416,7 @@ def delete_tick_data(
416416
(DbTickData.symbol == symbol)
417417
& (DbTickData.exchange == exchange.value)
418418
)
419-
count = d.execute()
419+
count: int = d.execute()
420420

421421
# 删除Tick汇总数据
422422
d2: ModelDelete = DbTickOverview.delete().where(
@@ -430,13 +430,13 @@ def delete_tick_data(
430430
def get_bar_overview(self) -> List[BarOverview]:
431431
"""查询数据库中的K线汇总信息"""
432432
# 如果已有K线,但缺失汇总信息,则执行初始化
433-
data_count = DbBarData.select().count()
434-
overview_count = DbBarOverview.select().count()
433+
data_count: int = DbBarData.select().count()
434+
overview_count: int = DbBarOverview.select().count()
435435
if data_count and not overview_count:
436436
self.init_bar_overview()
437437

438438
s: ModelSelect = DbBarOverview.select()
439-
overviews = []
439+
overviews: List[BarOverview] = []
440440
for overview in s:
441441
overview.exchange = Exchange(overview.exchange)
442442
overview.interval = Interval(overview.interval)
@@ -468,7 +468,7 @@ def init_bar_overview(self) -> None:
468468
)
469469

470470
for data in s:
471-
overview = DbBarOverview()
471+
overview: DbBarOverview = DbBarOverview()
472472
overview.symbol = data.symbol
473473
overview.exchange = data.exchange
474474
overview.interval = data.interval

0 commit comments

Comments
 (0)