Skip to content

Commit 785c872

Browse files
Update APIs and finish implementation for stocks data
1 parent da3a889 commit 785c872

9 files changed

+369
-381
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.pytest_cache
22
__pycache__
3-
.DS_Store
3+
.DS_Store
4+
*.csv

cli/main.py

+80-66
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
Get quotes:
99
python main.py stocks historical quotes MSFT 20240101 20240131 --interval 3600000
1010
11-
Save output to a file:
12-
python main.py stocks historical eod-report AAPL 20240101 20240131 --output-file aapl_eod.csv
13-
1411
Stocks Snapshot Data:
1512
Get real-time quotes:
1613
python main.py stocks snapshot quotes AAPL
@@ -46,19 +43,8 @@
4643
stocks_app.add_typer(snapshot_app, name="snapshot")
4744
app.add_typer(options_app, name="options")
4845

49-
historical_data = ThetaDataStocksHistorical(enable_logging=True, use_df=True)
50-
snapshot_data = ThetaDataStocksSnapshot(enable_logging=True, use_df=True)
51-
52-
53-
def save_output(result: pd.DataFrame | dict | None, output_file: Optional[str]):
54-
if isinstance(result, pd.DataFrame):
55-
if output_file:
56-
result.to_csv(output_file, index=False)
57-
typer.echo(f"Data saved to {output_file}")
58-
else:
59-
typer.echo(result.to_string())
60-
else:
61-
typer.echo(result)
46+
historical_data = ThetaDataStocksHistorical()
47+
snapshot_data = ThetaDataStocksSnapshot()
6248

6349

6450
def with_spinner(func):
@@ -80,12 +66,15 @@ def wrapper(*args, **kwargs):
8066
# Historical commands
8167
@historical_app.command(name="eod-report")
8268
@with_spinner
83-
def eod_report(
84-
symbol: str, start_date: str, end_date: str, output_file: Optional[str] = None
85-
):
69+
def eod_report(symbol: str, start_date: str, end_date: str):
8670
"""Get end-of-day report for a given symbol and date range."""
87-
result = historical_data.get_eod_report(symbol, start_date, end_date)
88-
save_output(result, output_file)
71+
result = historical_data.get_eod_report(
72+
symbol, start_date, end_date, write_csv=True
73+
)
74+
if result is not None:
75+
typer.echo("Data retrieved successfully")
76+
else:
77+
typer.echo("Failed to retrieve data")
8978

9079

9180
@historical_app.command(name="quotes")
@@ -95,11 +84,15 @@ def historical_quotes(
9584
start_date: str,
9685
end_date: str,
9786
interval: str = "900000",
98-
output_file: Optional[str] = None,
9987
):
10088
"""Get historical quotes for a given symbol and date range."""
101-
result = historical_data.get_quotes(symbol, start_date, end_date, interval)
102-
save_output(result, output_file)
89+
result = historical_data.get_quotes(
90+
symbol, start_date, end_date, interval, write_csv=True
91+
)
92+
if result is not None:
93+
typer.echo("Data retrieved successfully")
94+
else:
95+
typer.echo("Failed to retrieve data")
10396

10497

10598
@historical_app.command(name="ohlc")
@@ -109,96 +102,117 @@ def historical_ohlc(
109102
start_date: str,
110103
end_date: str,
111104
interval: str = "900000",
112-
output_file: Optional[str] = None,
113105
):
114106
"""Get historical OHLC data for a given symbol and date range."""
115-
result = historical_data.get_ohlc(symbol, start_date, end_date, interval)
116-
save_output(result, output_file)
107+
result = historical_data.get_ohlc(
108+
symbol, start_date, end_date, interval, write_csv=True
109+
)
110+
if result is not None:
111+
typer.echo("Data retrieved successfully")
112+
else:
113+
typer.echo("Failed to retrieve data")
117114

118115

119116
@historical_app.command(name="trades")
120117
@with_spinner
121-
def historical_trades(
122-
symbol: str, start_date: str, end_date: str, output_file: Optional[str] = None
123-
):
118+
def historical_trades(symbol: str, start_date: str, end_date: str):
124119
"""Get historical trade data for a given symbol and date range."""
125-
result = historical_data.get_trades(symbol, start_date, end_date)
126-
save_output(result, output_file)
120+
result = historical_data.get_trades(symbol, start_date, end_date, write_csv=True)
121+
if result is not None:
122+
typer.echo("Data retrieved successfully")
123+
else:
124+
typer.echo("Failed to retrieve data")
127125

128126

129127
@historical_app.command(name="trade-quote")
130128
@with_spinner
131-
def trade_quote(
132-
symbol: str, start_date: str, end_date: str, output_file: Optional[str] = None
133-
):
129+
def trade_quote(symbol: str, start_date: str, end_date: str):
134130
"""Get historical trade and quote data for a given symbol and date range."""
135-
result = historical_data.get_trade_quote(symbol, start_date, end_date)
136-
save_output(result, output_file)
131+
result = historical_data.get_trade_quote(
132+
symbol, start_date, end_date, write_csv=True
133+
)
134+
if result is not None:
135+
typer.echo("Data retrieved successfully")
136+
else:
137+
typer.echo("Failed to retrieve data")
137138

138139

139140
@historical_app.command(name="splits")
140141
@with_spinner
141-
def splits(
142-
symbol: str, start_date: str, end_date: str, output_file: Optional[str] = None
143-
):
142+
def splits(symbol: str, start_date: str, end_date: str):
144143
"""Get stock split data for a given symbol and date range."""
145-
result = historical_data.get_splits(symbol, start_date, end_date)
146-
save_output(result, output_file)
144+
result = historical_data.get_splits(symbol, start_date, end_date, write_csv=True)
145+
if result is not None:
146+
typer.echo("Data retrieved successfully")
147+
else:
148+
typer.echo("Failed to retrieve data")
147149

148150

149151
@historical_app.command(name="dividends")
150152
@with_spinner
151-
def dividends(
152-
symbol: str, start_date: str, end_date: str, output_file: Optional[str] = None
153-
):
153+
def dividends(symbol: str, start_date: str, end_date: str):
154154
"""Get dividend data for a given symbol and date range."""
155-
result = historical_data.get_dividends(symbol, start_date, end_date)
156-
save_output(result, output_file)
155+
result = historical_data.get_dividends(symbol, start_date, end_date, write_csv=True)
156+
if result is not None:
157+
typer.echo("Data retrieved successfully")
158+
else:
159+
typer.echo("Failed to retrieve data")
157160

158161

159162
# Snapshot commands
160163
@snapshot_app.command(name="quotes")
161164
@with_spinner
162-
def snapshot_quotes(
163-
symbol: str, venue: Optional[str] = None, output_file: Optional[str] = None
164-
):
165+
def snapshot_quotes(symbol: str, venue: Optional[str] = None):
165166
"""Get real-time quotes for a given symbol."""
166-
result = snapshot_data.get_quotes(symbol, venue)
167-
save_output(result, output_file)
167+
result = snapshot_data.get_quotes(symbol, venue, write_csv=True)
168+
if result is not None:
169+
typer.echo("Data retrieved successfully")
170+
else:
171+
typer.echo("Failed to retrieve data")
168172

169173

170174
@snapshot_app.command(name="bulk-quotes")
171175
@with_spinner
172-
def bulk_quotes(
173-
symbols: List[str], venue: Optional[str] = None, output_file: Optional[str] = None
174-
):
176+
def bulk_quotes(symbols: List[str], venue: Optional[str] = None):
175177
"""Get real-time quotes for multiple symbols."""
176-
result = snapshot_data.get_bulk_quotes(symbols, venue)
177-
save_output(result, output_file)
178+
result = snapshot_data.get_bulk_quotes(symbols, venue, write_csv=True)
179+
if result is not None:
180+
typer.echo("Data retrieved successfully")
181+
else:
182+
typer.echo("Failed to retrieve data")
178183

179184

180185
@snapshot_app.command(name="ohlc")
181186
@with_spinner
182-
def snapshot_ohlc(symbol: str, output_file: Optional[str] = None):
187+
def snapshot_ohlc(symbol: str):
183188
"""Get real-time OHLC data for a given symbol."""
184-
result = snapshot_data.get_ohlc(symbol)
185-
save_output(result, output_file)
189+
result = snapshot_data.get_ohlc(symbol, write_csv=True)
190+
if result is not None:
191+
typer.echo("Data retrieved successfully")
192+
else:
193+
typer.echo("Failed to retrieve data")
186194

187195

188196
@snapshot_app.command(name="bulk-ohlc")
189197
@with_spinner
190-
def bulk_ohlc(symbols: List[str], output_file: Optional[str] = None):
198+
def bulk_ohlc(symbols: List[str]):
191199
"""Get real-time OHLC data for multiple symbols."""
192-
result = snapshot_data.get_bulk_ohlc(symbols)
193-
save_output(result, output_file)
200+
result = snapshot_data.get_bulk_ohlc(symbols, write_csv=True)
201+
if result is not None:
202+
typer.echo("Data retrieved successfully")
203+
else:
204+
typer.echo("Failed to retrieve data")
194205

195206

196207
@snapshot_app.command(name="trades")
197208
@with_spinner
198-
def snapshot_trades(symbol: str, output_file: Optional[str] = None):
209+
def snapshot_trades(symbol: str):
199210
"""Get real-time trade data for a given symbol."""
200-
result = snapshot_data.get_trades(symbol)
201-
save_output(result, output_file)
211+
result = snapshot_data.get_trades(symbol, write_csv=True)
212+
if result is not None:
213+
typer.echo("Data retrieved successfully")
214+
else:
215+
typer.echo("Failed to retrieve data")
202216

203217

204218
if __name__ == "__main__":

examples/stocks_historical_examples.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from src.stocks_historical import ThetaDataStocksHistorical
99

10-
historical_data = ThetaDataStocksHistorical(enable_logging=True, use_df=True)
10+
historical_data = ThetaDataStocksHistorical(log_level="DEBUG")
1111

1212

1313
def example_runner(func):
@@ -87,10 +87,10 @@ def intel_dividends_example():
8787
# Toggle example cases
8888
if __name__ == "__main__":
8989
run_examples = {
90-
"apple_eod_example": False,
91-
"microsoft_quotes_example": False,
92-
"google_ohlc_example": False,
93-
"tesla_trades_example": False,
90+
"apple_eod_example": True,
91+
"microsoft_quotes_example": True,
92+
"google_ohlc_example": True,
93+
"tesla_trades_example": True,
9494
"amazon_trade_quote_example": True,
9595
"nvidia_splits_example": True,
9696
"intel_dividends_example": True,

examples/stocks_snapshot_examples.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -8,37 +8,37 @@
88

99
from src.stocks import ThetaDataStocksSnapshot
1010

11-
snapshot_data = ThetaDataStocksSnapshot(enable_logging=True, use_df=True)
11+
snapshot_data = ThetaDataStocksSnapshot(log_level="INFO", output_dir="./output")
1212

1313

1414
@example_runner
1515
def apple_quotes_example():
16-
return snapshot_data.get_quotes("AAPL")
16+
return snapshot_data.get_quotes("AAPL", write_csv=True)
1717

1818

1919
@example_runner
2020
def microsoft_quotes_nqb_example():
21-
return snapshot_data.get_quotes("MSFT", venue="nqb")
21+
return snapshot_data.get_quotes("MSFT", venue="nqb", write_csv=True)
2222

2323

2424
@example_runner
2525
def bulk_quotes_example():
26-
return snapshot_data.get_bulk_quotes(["GOOGL", "AMZN", "TSLA"])
26+
return snapshot_data.get_bulk_quotes(["GOOGL", "AMZN", "TSLA"], write_csv=True)
2727

2828

2929
@example_runner
3030
def nvidia_ohlc_example():
31-
return snapshot_data.get_ohlc("NVDA")
31+
return snapshot_data.get_ohlc("NVDA", write_csv=True)
3232

3333

3434
@example_runner
3535
def bulk_ohlc_example():
36-
return snapshot_data.get_bulk_ohlc(["INTC", "AMD", "QCOM"])
36+
return snapshot_data.get_bulk_ohlc(["INTC", "AMD", "QCOM"], write_csv=True)
3737

3838

3939
@example_runner
4040
def tesla_trades_example():
41-
return snapshot_data.get_trades("TSLA")
41+
return snapshot_data.get_trades("TSLA", write_csv=True)
4242

4343

4444
# Toggle example cases

0 commit comments

Comments
 (0)