diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..fa34a52f --- /dev/null +++ b/Makefile @@ -0,0 +1,22 @@ +# The default "help" goal nicely prints all the available goals based on the funny looking ## comments. +# Source: https://marmelab.com/blog/2016/02/29/auto-documented-makefile.html +.DEFAULT_GOAL := help +.PHONY: help +help: ## Display this help + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the SDK and its dependencies using poetry + poetry install + +.PHONY: lint +lint: ## Run the linters + poetry run black --check alpaca/ tests/ + +.PHONY: generate +generate: ## Generate the documentation + ./tools/scripts/generate-docs.sh + +.PHONY: test +test: ## Run the unit tests + poetry run pytest diff --git a/alpaca/data/historical/__init__.py b/alpaca/data/historical/__init__.py index c86900b8..ff74300a 100644 --- a/alpaca/data/historical/__init__.py +++ b/alpaca/data/historical/__init__.py @@ -1,7 +1,13 @@ -from alpaca.data.historical.stock import StockHistoricalDataClient from alpaca.data.historical.crypto import CryptoHistoricalDataClient +from alpaca.data.historical.news import NewsClient +from alpaca.data.historical.option import OptionHistoricalDataClient +from alpaca.data.historical.screener import ScreenerClient +from alpaca.data.historical.stock import StockHistoricalDataClient __all__ = [ - "StockHistoricalDataClient", "CryptoHistoricalDataClient", + "StockHistoricalDataClient", + "NewsClient", + "OptionHistoricalDataClient", + "ScreenerClient", ] diff --git a/alpaca/data/historical/crypto.py b/alpaca/data/historical/crypto.py index 67223e12..cae67c74 100644 --- a/alpaca/data/historical/crypto.py +++ b/alpaca/data/historical/crypto.py @@ -1,29 +1,28 @@ from collections import defaultdict -from typing import Union, Optional, List, Dict +from typing import Dict, List, Optional, Union from alpaca.common.constants import DATA_V2_MAX_LIMIT -from alpaca.common.types import RawData from alpaca.common.enums import BaseURL from alpaca.common.rest import RESTClient -from alpaca.common.types import Credentials -from alpaca.data import Snapshot, Bar +from alpaca.common.types import Credentials, RawData +from alpaca.data import Bar, Snapshot +from alpaca.data.enums import CryptoFeed +from alpaca.data.historical.stock import DataExtensionType from alpaca.data.historical.utils import ( - parse_obj_as_symbol_dict, - format_latest_data_response, format_dataset_response, + format_latest_data_response, + parse_obj_as_symbol_dict, ) -from alpaca.data.models import BarSet, QuoteSet, TradeSet, Orderbook, Trade, Quote -from alpaca.data.historical.stock import DataExtensionType +from alpaca.data.models import BarSet, Orderbook, Quote, Trade, TradeSet from alpaca.data.requests import ( CryptoBarsRequest, - CryptoTradesRequest, - CryptoLatestTradeRequest, + CryptoLatestBarRequest, + CryptoLatestOrderbookRequest, CryptoLatestQuoteRequest, + CryptoLatestTradeRequest, CryptoSnapshotRequest, - CryptoLatestOrderbookRequest, - CryptoLatestBarRequest, + CryptoTradesRequest, ) -from alpaca.data.enums import CryptoFeed class CryptoHistoricalDataClient(RESTClient): diff --git a/alpaca/data/historical/screener.py b/alpaca/data/historical/screener.py index b39de158..50175528 100644 --- a/alpaca/data/historical/screener.py +++ b/alpaca/data/historical/screener.py @@ -1,14 +1,10 @@ from typing import Optional, Union -from alpaca.common.rest import RESTClient - from alpaca.common.enums import BaseURL - -from alpaca.data.requests import MarketMoversRequest, MostActivesRequest - -from alpaca.data.models.screener import MostActives, Movers - +from alpaca.common.rest import RESTClient from alpaca.common.types import RawData +from alpaca.data.models.screener import MostActives, Movers +from alpaca.data.requests import MarketMoversRequest, MostActivesRequest class ScreenerClient(RESTClient): diff --git a/alpaca/data/historical/stock.py b/alpaca/data/historical/stock.py index cf75329c..7ccb5a2a 100644 --- a/alpaca/data/historical/stock.py +++ b/alpaca/data/historical/stock.py @@ -1,29 +1,28 @@ from collections import defaultdict from enum import Enum -from typing import List, Optional, Union, Type, Dict +from typing import Dict, List, Optional, Type, Union +from alpaca.common.constants import DATA_V2_MAX_LIMIT from alpaca.common.enums import BaseURL -from alpaca.common.rest import RESTClient, HTTPResult +from alpaca.common.rest import HTTPResult, RESTClient from alpaca.common.types import RawData -from alpaca.data import Quote, Trade, Snapshot, Bar +from alpaca.data import Bar, Quote, Snapshot, Trade from alpaca.data.historical.utils import ( - parse_obj_as_symbol_dict, - format_latest_data_response, format_dataset_response, + format_latest_data_response, format_snapshot_data, + parse_obj_as_symbol_dict, ) - from alpaca.data.models import BarSet, QuoteSet, TradeSet from alpaca.data.requests import ( StockBarsRequest, - StockQuotesRequest, - StockTradesRequest, - StockLatestTradeRequest, + StockLatestBarRequest, StockLatestQuoteRequest, + StockLatestTradeRequest, + StockQuotesRequest, StockSnapshotRequest, - StockLatestBarRequest, + StockTradesRequest, ) -from alpaca.common.constants import DATA_V2_MAX_LIMIT class DataExtensionType(Enum): diff --git a/alpaca/data/historical/utils.py b/alpaca/data/historical/utils.py index 3c01e433..6b238e0c 100644 --- a/alpaca/data/historical/utils.py +++ b/alpaca/data/historical/utils.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Type, Dict +from typing import Dict, Type from alpaca.common import HTTPResult, RawData diff --git a/alpaca/data/live/__init__.py b/alpaca/data/live/__init__.py index 07c1fd86..1f105e5b 100644 --- a/alpaca/data/live/__init__.py +++ b/alpaca/data/live/__init__.py @@ -1,7 +1,11 @@ -from alpaca.data.live.stock import StockDataStream from alpaca.data.live.crypto import CryptoDataStream +from alpaca.data.live.news import NewsDataStream +from alpaca.data.live.option import OptionDataStream +from alpaca.data.live.stock import StockDataStream __all__ = [ - "StockDataStream", "CryptoDataStream", + "NewsDataStream", + "OptionDataStream", + "StockDataStream", ] diff --git a/alpaca/data/live/crypto.py b/alpaca/data/live/crypto.py index cf95608f..5b06bd5b 100644 --- a/alpaca/data/live/crypto.py +++ b/alpaca/data/live/crypto.py @@ -1,14 +1,17 @@ -from alpaca.common.websocket import BaseStream -from typing import Optional, Dict +from typing import Awaitable, Callable, Dict, Optional, Union + from alpaca.common.enums import BaseURL from alpaca.data.enums import CryptoFeed +from alpaca.data.live.websocket import DataStream +from alpaca.data.models.bars import Bar +from alpaca.data.models.orderbooks import Orderbook +from alpaca.data.models.quotes import Quote +from alpaca.data.models.trades import Trade -class CryptoDataStream(BaseStream): +class CryptoDataStream(DataStream): """ A WebSocket client for streaming live crypto data. - - See BaseStream for more information on implementation and the methods available. """ def __init__( @@ -27,6 +30,7 @@ def __init__( api_key (str): Alpaca API key. secret_key (str): Alpaca API secret key. raw_data (bool, optional): Whether to return wrapped data or raw API data. Defaults to False. + feed (CryptoFeed, optional): Which crypto feed to use. Defaults to US. websocket_params (Optional[Dict], optional): Any parameters for configuring websocket connection. Defaults to None. url_override (Optional[str]): If specified allows you to override the base url the client points to for proxy/testing. Defaults to None. @@ -42,3 +46,123 @@ def __init__( raw_data=raw_data, websocket_params=websocket_params, ) + + def subscribe_trades( + self, handler: Callable[[Union[Trade, Dict]], Awaitable[None]], *symbols: str + ) -> None: + """Subscribe to trades. + + Args: + handler (Callable[[Union[Trade, Dict]], Awaitable[None]): The coroutine callback + function to handle the incoming data. + *symbols: List of ticker symbols to subscribe to. "*" for everything. + """ + self._subscribe(handler, symbols, self._handlers["trades"]) + + def subscribe_quotes( + self, handler: Callable[[Union[Quote, Dict]], Awaitable[None]], *symbols: str + ) -> None: + """Subscribe to quotes + + Args: + handler (Callable[[Union[Quote, Dict]], Awaitable[None]]): The coroutine callback + function to handle the incoming data. + *symbols: List of ticker symbols to subscribe to. "*" for everything. + """ + self._subscribe(handler, symbols, self._handlers["quotes"]) + + def subscribe_bars( + self, handler: Callable[[Union[Quote, Dict]], Awaitable[None]], *symbols: str + ) -> None: + """Subscribe to minute bars + + Args: + handler (Callable[[Union[Quote, Dict]], Awaitable[None]]): The coroutine callback + function to handle the incoming data. + *symbols: List of ticker symbols to subscribe to. "*" for everything. + """ + self._subscribe(handler, symbols, self._handlers["bars"]) + + def subscribe_updated_bars( + self, handler: Callable[[Union[Bar, Dict]], Awaitable[None]], *symbols: str + ) -> None: + """Subscribe to updated minute bars + + Args: + handler (Callable[[Union[Bar, Dict]], Awaitable[None]]): The coroutine callback + function to handle the incoming data. + *symbols: List of ticker symbols to subscribe to. "*" for everything. + """ + self._subscribe(handler, symbols, self._handlers["updatedBars"]) + + def subscribe_daily_bars( + self, handler: Callable[[Union[Bar, Dict]], Awaitable[None]], *symbols: str + ) -> None: + """Subscribe to daily bars + + Args: + handler (Callable[[Union[Bar, Dict]], Awaitable[None]]): The coroutine callback + function to handle the incoming data. + *symbols: List of ticker symbols to subscribe to. "*" for everything. + """ + self._subscribe(handler, symbols, self._handlers["dailyBars"]) + + def subscribe_orderbooks( + self, handler: Callable[[Union[Orderbook, Dict]], Awaitable[None]], *symbols + ) -> None: + """Subscribe to orderbooks + + Args: + handler (Callable[[Union[Bar, Dict]], Awaitable[None]]): The coroutine callback + function to handle the incoming data. + *symbols: List of ticker symbols to subscribe to. "*" for everything. + """ + self._subscribe(handler, symbols, self._handlers["orderbooks"]) + + def unsubscribe_trades(self, *symbols: str) -> None: + """Unsubscribe from trades + + Args: + *symbols (str): List of ticker symbols to unsubscribe from. "*" for everything. + """ + self._unsubscribe("trades", symbols) + + def unsubscribe_quotes(self, *symbols: str) -> None: + """Unsubscribe from quotes + + Args: + *symbols (str): List of ticker symbols to unsubscribe from. "*" for everything. + """ + self._unsubscribe("quotes", symbols) + + def unsubscribe_bars(self, *symbols: str) -> None: + """Unsubscribe from minute bars + + Args: + *symbols (str): List of ticker symbols to unsubscribe from. "*" for everything. + """ + self._unsubscribe("bars", symbols) + + def unsubscribe_updated_bars(self, *symbols: str) -> None: + """Unsubscribe from updated bars + + Args: + *symbols (str): List of ticker symbols to unsubscribe from. "*" for everything. + """ + self._unsubscribe("updatedBars", symbols) + + def unsubscribe_daily_bars(self, *symbols: str) -> None: + """Unsubscribe from daily bars + + Args: + *symbols (str): List of ticker symbols to unsubscribe from. "*" for everything. + """ + self._unsubscribe("dailyBars", symbols) + + def unsubscribe_orderbooks(self, *symbols: str) -> None: + """Unsubscribe from orderbooks + + Args: + *symbols (str): List of ticker symbols to unsubscribe from. "*" for everything. + """ + self._unsubscribe("orderbooks", symbols) diff --git a/alpaca/data/live/news.py b/alpaca/data/live/news.py index 38de7c5c..99f01093 100644 --- a/alpaca/data/live/news.py +++ b/alpaca/data/live/news.py @@ -1,10 +1,11 @@ -from typing import Optional, Dict +from typing import Awaitable, Callable, Dict, Optional, Union from alpaca.common.enums import BaseURL -from alpaca.common.websocket import BaseStream +from alpaca.data.live.websocket import DataStream +from alpaca.data.models.news import News -class NewsDataStream(BaseStream): +class NewsDataStream(DataStream): """ A WebSocket client for streaming news. """ @@ -39,3 +40,23 @@ def __init__( raw_data=raw_data, websocket_params=websocket_params, ) + + def subscribe_news( + self, handler: Callable[[Union[News, Dict]], Awaitable[None]], *symbols: str + ) -> None: + """Subscribe to news + + Args: + handler (Callable[[Union[News, Dict]], Awaitable[None]]): The coroutine callback + function to handle the incoming data. + *symbols: List of ticker symbols to subscribe to. "*" for everything. + """ + self._subscribe(handler, symbols, self._handlers["news"]) + + def unsubscribe_news(self, *symbols: str) -> None: + """Unsubscribe from news + + Args: + *symbols (str): List of ticker symbols to unsubscribe from. "*" for everything. + """ + self._unsubscribe("news", symbols) diff --git a/alpaca/data/live/option.py b/alpaca/data/live/option.py index 48a943d0..48407608 100644 --- a/alpaca/data/live/option.py +++ b/alpaca/data/live/option.py @@ -1,15 +1,15 @@ -from typing import Dict, Optional +from typing import Awaitable, Callable, Dict, Optional, Union from alpaca.common.enums import BaseURL -from alpaca.common.websocket import BaseStream from alpaca.data.enums import OptionsFeed +from alpaca.data.live.websocket import DataStream +from alpaca.data.models.quotes import Quote +from alpaca.data.models.trades import Trade -class OptionDataStream(BaseStream): +class OptionDataStream(DataStream): """ A WebSocket client for streaming live option data. - - See BaseStream for more information on implementation and the methods available. """ def __init__( @@ -28,10 +28,12 @@ def __init__( api_key (str): Alpaca API key. secret_key (str): Alpaca API secret key. raw_data (bool): Whether to return wrapped data or raw API data. Defaults to False. - feed (OptionsFeed): The source feed of the data. `opra` or `indicative`. Defaults to `indicative` - websocket_params (Optional[Dict], optional): Any parameters for configuring websocket connection. Defaults to None. + feed (OptionsFeed): The source feed of the data. `opra` or `indicative`. + Defaults to `indicative`. `opra` requires a subscription. + websocket_params (Optional[Dict], optional): Any parameters for configuring websocket + connection. Defaults to None. url_override (Optional[str]): If specified allows you to override the base url the client - points to for proxy/testing. Defaults to None. + points to for proxy/testing. Defaults to None. """ super().__init__( endpoint=( @@ -44,3 +46,43 @@ def __init__( raw_data=raw_data, websocket_params=websocket_params, ) + + def subscribe_trades( + self, handler: Callable[[Union[Trade, Dict]], Awaitable[None]], *symbols: str + ) -> None: + """Subscribe to trades. + + Args: + handler (Callable[[Union[Trade, Dict]], Awaitable[None]]): The coroutine callback + function to handle the incoming data. + *symbols: List of ticker symbols to subscribe to. "*" for everything. + """ + self._subscribe(handler, symbols, self._handlers["trades"]) + + def subscribe_quotes( + self, handler: Callable[[Union[Quote, Dict]], Awaitable[None]], *symbols: str + ) -> None: + """Subscribe to quotes + + Args: + handler (Callable[[Union[Quote, Dict]], Awaitable[None]]): The coroutine callback + function to handle the incoming data. + *symbols: List of ticker symbols to subscribe to. "*" for everything. + """ + self._subscribe(handler, symbols, self._handlers["quotes"]) + + def unsubscribe_trades(self, *symbols: str) -> None: + """Unsubscribe from trades + + Args: + *symbols (str): List of ticker symbols to unsubscribe from. "*" for everything. + """ + self._unsubscribe("trades", symbols) + + def unsubscribe_quotes(self, *symbols: str) -> None: + """Unsubscribe from quotes + + Args: + *symbols (str): List of ticker symbols to unsubscribe from. "*" for everything. + """ + self._unsubscribe("quotes", symbols) diff --git a/alpaca/data/live/stock.py b/alpaca/data/live/stock.py index 5f00a11b..052d8f6a 100644 --- a/alpaca/data/live/stock.py +++ b/alpaca/data/live/stock.py @@ -1,16 +1,17 @@ -from typing import Optional, Dict +import asyncio +from typing import Awaitable, Callable, Dict, Optional, Union from alpaca.common.enums import BaseURL -from alpaca.common.websocket import BaseStream from alpaca.data.enums import DataFeed +from alpaca.data.live.websocket import DataStream +from alpaca.data.models.bars import Bar +from alpaca.data.models.quotes import Quote +from alpaca.data.models.trades import Trade, TradeCancel, TradeCorrection, TradingStatus -class StockDataStream(BaseStream): +class StockDataStream(DataStream): """ - A WebSocket client for streaming live stock data via IEX or SIP depending on your market data - subscription. - - See BaseStream for more information on implementation and the methods available. + A WebSocket client for streaming live stock data. """ def __init__( @@ -29,16 +30,17 @@ def __init__( api_key (str): Alpaca API key. secret_key (str): Alpaca API secret key. raw_data (bool, optional): Whether to return wrapped data or raw API data. Defaults to False. - feed (DataFeed, optional): Which market data feed to use; IEX or SIP. Defaults to IEX. + feed (DataFeed, optional): Which market data feed to use; IEX or SIP. + Defaults to IEX. SIP requires a subscription. websocket_params (Optional[Dict], optional): Any parameters for configuring websocket connection. Defaults to None. url_override (Optional[str]): If specified allows you to override the base url the client - points to for proxy/testing. Defaults to None. + points to for proxy/testing. Defaults to None. Raises: ValueError: Only IEX or SIP market data feeds are supported """ - if feed == DataFeed.OTC: - raise ValueError("OTC not supported for live data feeds") + if feed not in (DataFeed.IEX, DataFeed.SIP): + raise ValueError("only IEX and SIP feeds ar supported") super().__init__( endpoint=( @@ -51,3 +53,147 @@ def __init__( raw_data=raw_data, websocket_params=websocket_params, ) + + def subscribe_trades( + self, handler: Callable[[Union[Trade, Dict]], Awaitable[None]], *symbols: str + ) -> None: + """Subscribe to trades. + + Args: + handler (Callable[[Union[Trade, Dict]], Awaitable[None]]): The coroutine callback + function to handle the incoming data. + *symbols: List of ticker symbols to subscribe to. "*" for everything. + """ + self._subscribe(handler, symbols, self._handlers["trades"]) + + def subscribe_quotes( + self, handler: Callable[[Union[Quote, Dict]], Awaitable[None]], *symbols: str + ) -> None: + """Subscribe to quotes + + Args: + handler (Callable[[Union[Trade, Dict]], Awaitable[None]]): The coroutine callback + function to handle the incoming data. + *symbols: List of ticker symbols to subscribe to. "*" for everything. + """ + self._subscribe(handler, symbols, self._handlers["quotes"]) + + def subscribe_bars( + self, handler: Callable[[Union[Bar, Dict]], Awaitable[None]], *symbols: str + ) -> None: + """Subscribe to minute bars + + Args: + handler (Callable[[Union[Trade, Dict]], Awaitable[None]]): The coroutine callback + function to handle the incoming data. + *symbols: List of ticker symbols to subscribe to. "*" for everything. + """ + self._subscribe(handler, symbols, self._handlers["bars"]) + + def subscribe_updated_bars( + self, handler: Callable[[Union[Bar, Dict]], Awaitable[None]], *symbols: str + ) -> None: + """Subscribe to updated minute bars + + Args: + handler (Callable[[Union[Bar, Dict]], Awaitable[None]]): The coroutine callback + function to handle the incoming data. + *symbols: List of ticker symbols to subscribe to. "*" for everything. + """ + self._subscribe(handler, symbols, self._handlers["updatedBars"]) + + def subscribe_daily_bars( + self, handler: Callable[[Union[Bar, Dict]], Awaitable[None]], *symbols: str + ) -> None: + """Subscribe to daily bars + + Args: + handler (Callable[[Union[Bar, Dict]], Awaitable[None]]): The coroutine callback + function to handle the incoming data. + *symbols: List of ticker symbols to subscribe to. "*" for everything. + """ + self._subscribe(handler, symbols, self._handlers["dailyBars"]) + + def subscribe_trading_statuses( + self, handler: Callable[[Union[TradingStatus, Dict]], Awaitable[None]], *symbols + ) -> None: + """Subscribe to trading statuses (halts, resumes) + + Args: + handler (Callable[[Union[TradingStatus, Dict]], Awaitable[None]]): The coroutine callback + function to handle the incoming data. + *symbols: List of ticker symbols to subscribe to. "*" for everything. + """ + self._subscribe(handler, symbols, self._handlers["statuses"]) + + def register_trade_corrections( + self, handler: Callable[[Union[TradeCorrection, Dict]], Awaitable[None]] + ) -> None: + """Register a trade correction handler. You can only subscribe to trade corrections by + subscribing to the underlying trades. + + Args: + handler (Callable[[Union[TradeCorrection, Dict]]): The coroutine callback + function to handle the incoming data. + """ + self._handlers["corrections"] = {"*": handler} + + def register_trade_cancels( + self, handler: Callable[[Union[TradeCancel, Dict]], Awaitable[None]] + ) -> None: + """Register a trade cancel handler. You can only subscribe to trade cancels by + subscribing to the underlying trades. + + Args: + handler (Callable[[Union[TradeCancel, Dict]], Awaitable[None]]): The coroutine callback + function to handle the incoming data. + """ + self._handlers["cancelErrors"] = {"*": handler} + + def unsubscribe_trades(self, *symbols: str) -> None: + """Unsubscribe from trades + + Args: + *symbols (str): List of ticker symbols to unsubscribe from. "*" for everything. + """ + self._unsubscribe("trades", symbols) + + def unsubscribe_quotes(self, *symbols: str) -> None: + """Unsubscribe from quotes + + Args: + *symbols (str): List of ticker symbols to unsubscribe from. "*" for everything. + """ + self._unsubscribe("quotes", symbols) + + def unsubscribe_bars(self, *symbols: str) -> None: + """Unsubscribe from minute bars + + Args: + *symbols (str): List of ticker symbols to unsubscribe from. "*" for everything. + """ + self._unsubscribe("bars", symbols) + + def unsubscribe_updated_bars(self, *symbols: str) -> None: + """Unsubscribe from updated bars + + Args: + *symbols (str): List of ticker symbols to unsubscribe from. "*" for everything. + """ + self._unsubscribe("updatedBars", symbols) + + def unsubscribe_daily_bars(self, *symbols: str) -> None: + """Unsubscribe from daily bars + + Args: + *symbols (str): List of ticker symbols to unsubscribe from. "*" for everything. + """ + self._unsubscribe("dailyBars", symbols) + + def unsubscribe_trading_statuses(self, *symbols: str) -> None: + """Unsubscribe from trading statuses + + Args: + *symbols (str): List of ticker symbols to unsubscribe from. "*" for everything. + """ + self._unsubscribe("statuses", symbols) diff --git a/alpaca/common/websocket.py b/alpaca/data/live/websocket.py similarity index 55% rename from alpaca/common/websocket.py rename to alpaca/data/live/websocket.py index db76444a..c8966abe 100644 --- a/alpaca/common/websocket.py +++ b/alpaca/data/live/websocket.py @@ -2,7 +2,7 @@ import logging import queue from collections import defaultdict -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import msgpack import websockets @@ -10,14 +10,23 @@ from alpaca import __version__ from alpaca.common.types import RawData -from alpaca.data.models import Bar, Quote, Trade +from alpaca.data.models import ( + Bar, + News, + Orderbook, + Quote, + Trade, + TradeCancel, + TradeCorrection, + TradingStatus, +) log = logging.getLogger(__name__) -class BaseStream: +class DataStream: """ - A base class for extracting out common functionality for websockets + A base class for extracting out common functionality for data websockets """ def __init__( @@ -28,7 +37,7 @@ def __init__( raw_data: bool = False, websocket_params: Optional[Dict] = None, ) -> None: - """_summary_ + """Creates a new DataStream instance. Args: endpoint (str): The websocket endpoint to connect to @@ -48,9 +57,15 @@ def __init__( self._handlers = { "trades": {}, "quotes": {}, + "orderbooks": {}, "bars": {}, "updatedBars": {}, "dailyBars": {}, + "statuses": {}, + "lulds": {}, + "news": {}, + "corrections": {}, + "cancelErrors": {}, } self._name = "data" self._should_run = True @@ -78,6 +93,7 @@ async def _connect(self) -> None: "User-Agent": "APCA-PY/" + __version__, } + log.info(f"connecting to {self._endpoint}") self._ws = await websockets.connect( self._endpoint, extra_headers=extra_headers, @@ -89,7 +105,7 @@ async def _connect(self) -> None: raise ValueError("connected message not received") async def _auth(self) -> None: - """Authenicates with API keys after a successful connection is established. + """Authenticates with API keys after a successful connection is established. Raises: ValueError: Raised if authentication is unsuccessful @@ -116,7 +132,7 @@ async def _start_ws(self) -> None: """ await self._connect() await self._auth() - log.info(f"connected to: {self._endpoint}") + log.info(f"connected to {self._endpoint}") async def close(self) -> None: """Closes the websocket connection.""" @@ -151,79 +167,98 @@ async def _consume(self) -> None: # to break the loop when needed pass - def _cast(self, msg_type: str, msg: Dict) -> Union[BaseModel, RawData]: + def _cast(self, msg: Dict) -> Union[BaseModel, RawData]: """Parses data from websocket message if raw_data is False, otherwise - returns raw websocket message + returns the raw websocket message. Args: - msg_type (str): The type of data contained in messaged. ('t' for trade, 'q' for quote, etc) msg (Dict): The message containing market data Returns: - Union[BaseModel, RawData]: The raw or parsed live data + Union[BaseModel, RawData]: The raw or parsed message """ - result = msg - if not self._raw_data: - if "t" in msg: - msg["t"] = msg["t"].to_datetime() - - if "S" not in msg: - return msg - - if msg_type == "t": - result = Trade(msg["S"], msg) - - elif msg_type == "q": - result = Quote(msg["S"], msg) - - elif msg_type in ("b", "u", "d"): - result = Bar(msg["S"], msg) - - return result + if self._raw_data: + return msg + msg_type = msg.get("T") + if "t" in msg: + msg["t"] = msg["t"].to_datetime() + if msg_type == "n": + msg["created_at"] = msg["created_at"].to_datetime() + msg["updated_at"] = msg["updated_at"].to_datetime() + return News(msg) + if "S" not in msg: + return msg + if msg_type == "t": + return Trade(msg["S"], msg) + if msg_type == "q": + return Quote(msg["S"], msg) + if msg_type == "o": + return Orderbook(msg["S"], msg) + if msg_type in ("b", "u", "d"): + return Bar(msg["S"], msg) + if msg_type == "s": + return TradingStatus(msg["S"], msg) + if msg_type == "c": + return TradeCorrection(msg["S"], msg) + if msg_type == "x": + return TradeCancel(msg["S"], msg) + return msg async def _dispatch(self, msg: Dict) -> None: - """Distributes message from websocket connection to appropriate handler + """Distributes the message from websocket connection to the appropriate handler. Args: msg (Dict): The message from the websocket connection """ msg_type = msg.get("T") - symbol = msg.get("S") - if msg_type == "t": - handler = self._handlers["trades"].get( - symbol, self._handlers["trades"].get("*", None) - ) - if handler: - await handler(self._cast(msg_type, msg)) - elif msg_type == "q": - handler = self._handlers["quotes"].get( - symbol, self._handlers["quotes"].get("*", None) - ) - if handler: - await handler(self._cast(msg_type, msg)) - elif msg_type == "b": - handler = self._handlers["bars"].get( - symbol, self._handlers["bars"].get("*", None) - ) - if handler: - await handler(self._cast(msg_type, msg)) - elif msg_type == "u": - handler = self._handlers["updatedBars"].get( - symbol, self._handlers["updatedBars"].get("*", None) - ) - if handler: - await handler(self._cast(msg_type, msg)) - elif msg_type == "d": - handler = self._handlers["dailyBars"].get( - symbol, self._handlers["dailyBars"].get("*", None) - ) - if handler: - await handler(self._cast(msg_type, msg)) - elif msg_type == "subscription": - sub = [f"{k}: {msg.get(k, [])}" for k in self._handlers] + if msg_type == "subscription": + sub = [f"{k}: {msg.get(k, [])}" for k in self._handlers if msg.get(k)] log.info(f'subscribed to {", ".join(sub)}') - elif msg_type == "error": + return + + if msg_type == "error": log.error(f'error: {msg.get("msg")} ({msg.get("code")})') + return + + if msg_type == "n": + symbols = msg.get("symbols", "*") + star_handler_called = False + handlers_to_call = [] + news = self._cast(msg) + for symbol in set(symbols): + if symbol in self._handlers["news"]: + handler = self._handlers["news"].get(symbol) + elif not star_handler_called: + handler = self._handlers["news"].get("*") + star_handler_called = True + else: + handler = None + if handler: + handlers_to_call.append(handler(news)) + if handlers_to_call: + await asyncio.gather(*handlers_to_call) + return + + channel_types = { + "t": "trades", + "q": "quotes", + "o": "orderbooks", + "b": "bars", + "u": "updatedBars", + "d": "dailyBars", + "s": "statuses", + "l": "lulds", + "n": "news", + "c": "corrections", + "x": "cancelErrors", + } + channel = channel_types.get(msg_type) + if not channel: + return + symbol = msg.get("S") + handler = self._handlers[channel].get(symbol, self._handlers[channel].get("*")) + if handler: + await handler(self._cast(msg)) def _subscribe( self, handler: Callable, symbols: Tuple[str], handlers: Dict @@ -239,10 +274,11 @@ def _subscribe( for symbol in symbols: handlers[symbol] = handler if self._running: - asyncio.run_coroutine_threadsafe(self._subscribe_all(), self._loop).result() + asyncio.run_coroutine_threadsafe( + self._send_subscribe_msg(), self._loop + ).result() - async def _subscribe_all(self) -> None: - """Subscribes to live data""" + async def _send_subscribe_msg(self) -> None: msg = defaultdict(list) for k, v in self._handlers.items(): if k not in ("cancelErrors", "corrections") and v: @@ -256,29 +292,21 @@ async def _subscribe_all(self) -> None: ) await self._ws.send(frames) - async def _unsubscribe( - self, trades=(), quotes=(), bars=(), updated_bars=(), daily_bars=() - ) -> None: - """Unsubscribes from data for symbols specified by the data type - we want to subscribe from. + def _unsubscribe(self, channel: str, symbols: List[str]) -> None: + if self._running: + asyncio.run_coroutine_threadsafe( + self._send_unsubscribe_msg(channel, symbols), self._loop + ).result() + for symbol in symbols: + del self._handlers[channel][symbol] - Args: - trades (tuple, optional): All symbols to unsubscribe trade data for. Defaults to (). - quotes (tuple, optional): All symbols to unsubscribe quotes data for. Defaults to (). - bars (tuple, optional): All symbols to unsubscribe minute bar data for. Defaults to (). - updated_bars (tuple, optional): All symbols to unsubscribe updated bar data for. Defaults to (). - daily_bars (tuple, optional): All symbols to unsubscribe daily bar data for. Defaults to (). - """ - if trades or quotes or bars or updated_bars or daily_bars: + async def _send_unsubscribe_msg(self, channel: str, symbols: List[str]) -> None: + if symbols: await self._ws.send( msgpack.packb( { "action": "unsubscribe", - "trades": trades, - "quotes": quotes, - "bars": bars, - "updatedBars": updated_bars, - "dailyBars": daily_bars, + channel: symbols, } ) ) @@ -312,7 +340,7 @@ async def _run_forever(self) -> None: if not self._running: log.info("starting {} websocket connection".format(self._name)) await self._start_ws() - await self._subscribe_all() + await self._send_subscribe_msg() self._running = True await self._consume() except websockets.WebSocketException as wse: @@ -331,116 +359,6 @@ async def _run_forever(self) -> None: finally: await asyncio.sleep(0) - def subscribe_trades(self, handler: Callable, *symbols) -> None: - """Subscribe to trade data for symbol inputs - - Args: - handler (Callable): The coroutine callback function to handle live trade data - *symbols: Variable string arguments for ticker identifiers to be subscribed to. - """ - self._subscribe(handler, symbols, self._handlers["trades"]) - - def subscribe_quotes(self, handler: Callable, *symbols) -> None: - """Subscribe to quote data for symbol inputs - - Args: - handler (Callable): The coroutine callback function to handle live quote data - *symbols: Variable string arguments for ticker identifiers to be subscribed to. - """ - self._subscribe(handler, symbols, self._handlers["quotes"]) - - def subscribe_bars(self, handler: Callable, *symbols) -> None: - """Subscribe to minute bar data for symbol inputs - - Args: - handler (Callable): The coroutine callback function to handle live minute bar data - *symbols: Variable string arguments for ticker identifiers to be subscribed to. - """ - self._subscribe(handler, symbols, self._handlers["bars"]) - - def subscribe_updated_bars(self, handler: Callable, *symbols) -> None: - """Subscribe to updated bar data for symbol inputs - - Args: - handler (Callable): The coroutine callback function to handle live updated bar data - *symbols: Variable string arguments for ticker identifiers to be subscribed to. - """ - self._subscribe(handler, symbols, self._handlers["updatedBars"]) - - def subscribe_daily_bars(self, handler: Callable, *symbols) -> None: - """Subscribe to daily bar data for symbol inputs - - Args: - handler (Callable): The coroutine callback function to handle live daily bar data - *symbols: Variable string arguments for ticker identifiers to be subscribed to. - """ - self._subscribe(handler, symbols, self._handlers["dailyBars"]) - - def unsubscribe_trades(self, *symbols) -> None: - """Unsubscribe from trade data for symbol inputs - - Args: - *symbols: Variable string arguments for ticker identifiers to be unsubscribed from. - """ - if self._running: - asyncio.run_coroutine_threadsafe( - self._unsubscribe(trades=symbols), self._loop - ).result() - for symbol in symbols: - del self._handlers["trades"][symbol] - - def unsubscribe_quotes(self, *symbols) -> None: - """Unsubscribe from quote data for symbol inputs - - Args: - *symbols: Variable string arguments for ticker identifiers to be unsubscribed from. - """ - if self._running: - asyncio.run_coroutine_threadsafe( - self._unsubscribe(quotes=symbols), self._loop - ).result() - for symbol in symbols: - del self._handlers["quotes"][symbol] - - def unsubscribe_bars(self, *symbols) -> None: - """Unsubscribe from minute bar data for symbol inputs - - Args: - *symbols: Variable string arguments for ticker identifiers to be unsubscribed from. - """ - if self._running: - asyncio.run_coroutine_threadsafe( - self._unsubscribe(bars=symbols), self._loop - ).result() - for symbol in symbols: - del self._handlers["bars"][symbol] - - def unsubscribe_updated_bars(self, *symbols) -> None: - """Unsubscribe from updated bar data for symbol inputs - - Args: - *symbols: Variable string arguments for ticker identifiers to be unsubscribed from. - """ - if self._running: - asyncio.get_event_loop().run_until_complete( - self._unsubscribe(updated_bars=symbols) - ) - for symbol in symbols: - del self._handlers["updatedBars"][symbol] - - def unsubscribe_daily_bars(self, *symbols) -> None: - """Unsubscribe from daily bar data for symbol inputs - - Args: - *symbols: Variable string arguments for ticker identifiers to be unsubscribed from. - """ - if self._running: - asyncio.run_coroutine_threadsafe( - self._unsubscribe(daily_bars=symbols), self._loop - ).result() - for symbol in symbols: - del self._handlers["dailyBars"][symbol] - def run(self) -> None: """Starts up the websocket connection's event loop""" try: diff --git a/alpaca/data/mappings.py b/alpaca/data/mappings.py index 2284f202..6ba92b0b 100644 --- a/alpaca/data/mappings.py +++ b/alpaca/data/mappings.py @@ -47,9 +47,38 @@ "t": "timestamp", "b": "bids", "a": "asks", + "r": "reset", } -ORDERBOOK_QUOTE_MAPPING: Dict[str, str] = { +TRADING_STATUS_MAPPING: Dict[str, str] = { + "t": "timestamp", + "sc": "status_code", + "sm": "status_message", + "rc": "reason_code", + "rm": "reason_message", + "z": "tape", +} + +TRADE_CANCEL_MAPPING: Dict[str, str] = { + "t": "timestamp", "p": "price", "s": "size", + "x": "exchange", + "i": "id", + "a": "action", + "z": "tape", +} + +TRADE_CORRECTION_MAPPING: Dict[str, str] = { + "t": "timestamp", + "x": "exchange", + "oi": "original_id", + "op": "original_price", + "os": "original_size", + "oc": "original_conditions", + "ci": "corrected_id", + "cp": "corrected_price", + "cs": "corrected_size", + "cc": "corrected_conditions", + "z": "tape", } diff --git a/alpaca/data/models/__init__.py b/alpaca/data/models/__init__.py index 7b1bb559..0db35cc9 100644 --- a/alpaca/data/models/__init__.py +++ b/alpaca/data/models/__init__.py @@ -1,6 +1,7 @@ from alpaca.data.models.bars import * +from alpaca.data.models.orderbooks import * +from alpaca.data.models.news import * from alpaca.data.models.quotes import * from alpaca.data.models.trades import * from alpaca.data.models.snapshots import * from alpaca.data.models.orderbooks import * -from alpaca.data.models.news import * diff --git a/alpaca/data/models/bars.py b/alpaca/data/models/bars.py index 85dbba9b..76d67b92 100644 --- a/alpaca/data/models/bars.py +++ b/alpaca/data/models/bars.py @@ -1,8 +1,6 @@ from datetime import datetime from typing import Dict, List, Optional -from pydantic import ConfigDict - from alpaca.common.models import ValidateBaseModel as BaseModel from alpaca.common.types import RawData from alpaca.data.mappings import BAR_MAPPING diff --git a/alpaca/data/models/base.py b/alpaca/data/models/base.py index 675f92df..c48aae0e 100644 --- a/alpaca/data/models/base.py +++ b/alpaca/data/models/base.py @@ -1,10 +1,8 @@ import itertools -import pprint from typing import Any, Dict, List import pandas as pd from pandas import DataFrame -from pydantic import ConfigDict from alpaca.common.models import ValidateBaseModel as BaseModel diff --git a/alpaca/data/models/news.py b/alpaca/data/models/news.py index 91d9ba8c..f49f12de 100644 --- a/alpaca/data/models/news.py +++ b/alpaca/data/models/news.py @@ -1,7 +1,5 @@ from datetime import datetime -from typing import Optional, List - -from pydantic import ConfigDict +from typing import List, Optional from alpaca.common.models import ValidateBaseModel as BaseModel from alpaca.common.types import RawData diff --git a/alpaca/data/models/orderbooks.py b/alpaca/data/models/orderbooks.py index 77c6a1c8..d928eabc 100644 --- a/alpaca/data/models/orderbooks.py +++ b/alpaca/data/models/orderbooks.py @@ -1,9 +1,10 @@ from datetime import datetime -from typing import Dict, List +from typing import List + +from pydantic import Field, TypeAdapter -from alpaca.common.types import RawData from alpaca.common.models import ValidateBaseModel as BaseModel -from pydantic import ConfigDict, TypeAdapter, Field +from alpaca.common.types import RawData from alpaca.data.mappings import ORDERBOOK_MAPPING @@ -23,12 +24,16 @@ class Orderbook(BaseModel): timestamp (datetime): The time of submission of the orderbook. bids (List[OrderbookQuote]): The list of bid quotes for the orderbook asks (List[OrderbookQuote]): The list of ask quotes for the orderbook + reset (bool): if true, the orderbook message contains the whole server side orderbook. + This indicates to the client that they should reset their orderbook. + Typically sent as the first message after subscription. """ symbol: str timestamp: datetime bids: List[OrderbookQuote] asks: List[OrderbookQuote] + reset: bool = False def __init__(self, symbol: str, raw_data: RawData) -> None: """Instantiates an Orderbook. diff --git a/alpaca/data/models/quotes.py b/alpaca/data/models/quotes.py index a884581c..78a87f11 100644 --- a/alpaca/data/models/quotes.py +++ b/alpaca/data/models/quotes.py @@ -1,8 +1,6 @@ from datetime import datetime from typing import Dict, List, Optional, Union -from pydantic import ConfigDict - from alpaca.common.models import ValidateBaseModel as BaseModel from alpaca.common.types import RawData from alpaca.data.enums import Exchange @@ -16,24 +14,24 @@ class Quote(BaseModel): Attributes: symbol (str): The ticker identifier for the security whose data forms the quote. timestamp (datetime): The time of submission of the quote. - ask_exchange (Optional[str, Exchange]): The exchange the quote ask originates. Defaults to None. - ask_price (float): The asking price of the quote. - ask_size (float): The size of the quote ask. - bid_exchange (Optional[str, Exchange]): The exchange the quote bid originates. Defaults to None. bid_price (float): The bidding price of the quote. bid_size (float): The size of the quote bid. + bid_exchange (Optional[str, Exchange]): The exchange the quote bid originates. Defaults to None. + ask_price (float): The asking price of the quote. + ask_size (float): The size of the quote ask. + ask_exchange (Optional[str, Exchange]): The exchange the quote ask originates. Defaults to None. conditions (Optional[Union[List[str], str]]): The quote conditions. Defaults to None. tape (Optional[str]): The quote tape. Defaults to None. """ symbol: str timestamp: datetime - ask_exchange: Optional[Union[str, Exchange]] = None - ask_price: float - ask_size: float - bid_exchange: Optional[Union[str, Exchange]] = None bid_price: float bid_size: float + bid_exchange: Optional[Union[str, Exchange]] = None + ask_price: float + ask_size: float + ask_exchange: Optional[Union[str, Exchange]] = None conditions: Optional[Union[List[str], str]] = None tape: Optional[str] = None diff --git a/alpaca/data/models/screener.py b/alpaca/data/models/screener.py index 2379d3eb..1a9143d5 100644 --- a/alpaca/data/models/screener.py +++ b/alpaca/data/models/screener.py @@ -1,9 +1,7 @@ from datetime import datetime from typing import List -from pydantic import ConfigDict from alpaca.common.models import ValidateBaseModel as BaseModel - from alpaca.data.enums import MarketType diff --git a/alpaca/data/models/snapshots.py b/alpaca/data/models/snapshots.py index 22d8dd23..3f1ef664 100644 --- a/alpaca/data/models/snapshots.py +++ b/alpaca/data/models/snapshots.py @@ -1,7 +1,5 @@ from typing import Dict, Optional -from pydantic import ConfigDict - from alpaca.common.models import ValidateBaseModel as BaseModel from alpaca.common.types import RawData from alpaca.data.mappings import SNAPSHOT_MAPPING diff --git a/alpaca/data/models/trades.py b/alpaca/data/models/trades.py index 472dfdfc..e44accc1 100644 --- a/alpaca/data/models/trades.py +++ b/alpaca/data/models/trades.py @@ -1,12 +1,15 @@ from datetime import datetime from typing import Dict, List, Optional, Union -from pydantic import ConfigDict - from alpaca.common.models import ValidateBaseModel as BaseModel from alpaca.common.types import RawData from alpaca.data.enums import Exchange -from alpaca.data.mappings import TRADE_MAPPING +from alpaca.data.mappings import ( + TRADE_CANCEL_MAPPING, + TRADE_CORRECTION_MAPPING, + TRADE_MAPPING, + TRADING_STATUS_MAPPING, +) from alpaca.data.models.base import BaseDataSet, TimeSeriesMixin @@ -34,7 +37,7 @@ class Trade(BaseModel): tape: Optional[str] = None def __init__(self, symbol: str, raw_data: RawData) -> None: - """Instantiates a Trade history object + """Instantiates a Trade object Args: symbol (str): The security identifier for the trade that occurred. @@ -51,7 +54,7 @@ def __init__(self, symbol: str, raw_data: RawData) -> None: class TradeSet(BaseDataSet, TimeSeriesMixin): - """A collection of Trade history objects. + """A collection of Trade objects. Attributes: data (Dict[str, List[Trade]]]): The collection of Trades keyed by symbol. @@ -74,3 +77,128 @@ def __init__(self, raw_data: RawData) -> None: ] super().__init__(data=parsed_trades) + + +class TradingStatus(BaseModel): + """Trading status update of a security, for example if a symbol gets halted. + + Attributes: + symbol (str): The ticker identifier. + timestamp (datetime): The time of trading status. + status_code (str): The tape-dependent code of the status. + status_message (str): The status message. + reason_code (str): The tape-dependent code of the halt reason. + reason_message (str): The reason message. + tape (Optional[str]): The tape (A, B or C). + """ + + symbol: str + timestamp: datetime + status_code: str + status_message: str + reason_code: str + reason_message: str + tape: str + + def __init__(self, symbol: str, raw_data: RawData) -> None: + """Instantiates a Trading status object + + Args: + symbol (str): The security identifier + raw_data (RawData): The raw data as received by API. + """ + + mapped = { + TRADING_STATUS_MAPPING.get(key): val + for key, val in raw_data.items() + if key in TRADING_STATUS_MAPPING + } + + super().__init__(symbol=symbol, **mapped) + + +class TradeCancel(BaseModel): + """Cancel of a previous trade. + + Attributes: + symbol (str): The ticker identifier. + timestamp (datetime): The timestamp. + exchange (Exchange): The exchange. + price (float): The price of the canceled trade. + size (float): The size of the canceled trade. + id (Optional[int]): The original ID of the canceled trade. + action (Optional[str]): The cancel action ("C" for cancel, "E" for error) + tape (str): The trade tape. Defaults to None. + """ + + symbol: str + timestamp: datetime + exchange: Exchange + price: float + size: float + id: Optional[int] = None + action: Optional[str] = None + tape: str + + def __init__(self, symbol: str, raw_data: RawData) -> None: + """Instantiates a trade cancel object + + Args: + symbol (str): The security identifier + raw_data (RawData): The raw data as received by API. + """ + + mapped = { + TRADE_CANCEL_MAPPING.get(key): val + for key, val in raw_data.items() + if key in TRADE_CANCEL_MAPPING + } + + super().__init__(symbol=symbol, **mapped) + + +class TradeCorrection(BaseModel): + """Correction of a previous trade. + + Attributes: + symbol (str): The ticker identifier. + timestamp (datetime): The timestamp. + exchange (Exchange): The exchange. + original_id (Optional[int]): The original ID of the corrected trade. + original_price (float): The original price of the corrected trade. + original_size (float): The original size of the corrected trade. + original_conditions (List[str]): The original conditions of the corrected trade. + corrected_id (Optional[int]): The corrected ID of the corrected trade. + corrected_price (float): The corrected price of the corrected trade. + corrected_size (float): The corrected size of the corrected trade. + corrected_conditions (List[str]): The corrected conditions of the corrected trade. + tape (str): The trade tape. Defaults to None. + """ + + symbol: str + timestamp: datetime + exchange: Exchange + original_id: Optional[int] = None + original_price: float + original_size: float + original_conditions: List[str] + corrected_id: Optional[int] = None + corrected_price: float + corrected_size: float + corrected_conditions: List[str] + tape: str + + def __init__(self, symbol: str, raw_data: RawData) -> None: + """Instantiates a trade correction object + + Args: + symbol (str): The security identifier + raw_data (RawData): The raw data as received by API. + """ + mapped = { + TRADE_CORRECTION_MAPPING.get(key): val + for key, val in raw_data.items() + if key in TRADE_CORRECTION_MAPPING + } + + super().__init__(symbol=symbol, **mapped) diff --git a/docs/conf.py b/docs/conf.py index 21d2a741..811383d7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,6 +19,7 @@ # NameError: Field "model_fields" conflicts with member {} of protected namespace "model_". # ref. https://github.com/pydantic/pydantic/discussions/7763#discussioncomment-8417097 import alpaca.data.models.screener # noqa # pylint: disable=unused-import +import alpaca.data.models.news # noqa # pylint: disable=unused-import # -- Project information ----------------------------------------------------- diff --git a/poetry.lock b/poetry.lock index 29891be0..33e91654 100644 --- a/poetry.lock +++ b/poetry.lock @@ -886,8 +886,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -1122,6 +1122,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.23.7" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest_asyncio-0.23.7-py3-none-any.whl", hash = "sha256:009b48127fbe44518a547bddd25611551b0e43ccdbf1e67d12479f569832c20b"}, + {file = "pytest_asyncio-0.23.7.tar.gz", hash = "sha256:5f5c72948f4c49e7db4f29f2521d4031f1c27f86e57b046126654083d4770268"}, +] + +[package.dependencies] +pytest = ">=7.0.0,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -1172,6 +1190,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -1836,4 +1855,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8.0" -content-hash = "1963f19043db1c48507c97e27fb2b0e8ef5d8b8b0009950a6d92eb341be45eba" +content-hash = "48eb14e16adad001835c9aa42396beffddb5461cbbeb2e8727608062d322be5f" diff --git a/pyproject.toml b/pyproject.toml index d55b259c..b9a13fb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ sseclient-py = "^1.7.2" [tool.poetry.dev-dependencies] pytest = "^7.1" +pytest-asyncio = "^0.23.7" requests-mock = "^1.9.3" black = "^24.3.0" isort = "^5.10.1" diff --git a/tests/conftest.py b/tests/conftest.py index f7368d10..33f2bbcb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,8 @@ from alpaca.data.historical.screener import ScreenerClient from alpaca.trading.client import TradingClient +pytest_plugins = ("pytest_asyncio",) + @pytest.fixture def reqmock() -> Iterator[Mocker]: diff --git a/tests/data/test_historical_crypto_data.py b/tests/data/test_historical_crypto_data.py index 71c5d58e..5e85d8af 100644 --- a/tests/data/test_historical_crypto_data.py +++ b/tests/data/test_historical_crypto_data.py @@ -2,22 +2,18 @@ from datetime import datetime, timezone from typing import Dict -from alpaca.data import Quote, Trade, Bar +from alpaca.data import Bar, Quote, Trade from alpaca.data.historical.crypto import CryptoHistoricalDataClient +from alpaca.data.models import BarSet, Snapshot, TradeSet from alpaca.data.requests import ( CryptoBarsRequest, - CryptoTradesRequest, - CryptoLatestTradeRequest, + CryptoLatestBarRequest, CryptoLatestQuoteRequest, + CryptoLatestTradeRequest, CryptoSnapshotRequest, - CryptoLatestBarRequest, + CryptoTradesRequest, ) from alpaca.data.timeframe import TimeFrame -from alpaca.data.models import ( - BarSet, - Snapshot, - TradeSet, -) def test_get_crypto_bars(reqmock, crypto_client: CryptoHistoricalDataClient): diff --git a/tests/data/test_screener.py b/tests/data/test_screener.py index 2eda2bde..5ee45050 100644 --- a/tests/data/test_screener.py +++ b/tests/data/test_screener.py @@ -1,10 +1,7 @@ +from alpaca.common.enums import BaseURL from alpaca.data.historical.screener import ScreenerClient - -from alpaca.data.requests import MarketMoversRequest, MostActivesRequest - from alpaca.data.models.screener import MostActives, Movers - -from alpaca.common.enums import BaseURL +from alpaca.data.requests import MarketMoversRequest, MostActivesRequest def test_get_market_movers(reqmock, screener_client: ScreenerClient): diff --git a/tests/data/test_websockets.py b/tests/data/test_websockets.py index 2338eda1..40e196a7 100644 --- a/tests/data/test_websockets.py +++ b/tests/data/test_websockets.py @@ -1,22 +1,29 @@ from datetime import datetime + import pytest from msgpack.ext import Timestamp +from pytz import utc -from alpaca.common.websocket import BaseStream from alpaca.data.enums import Exchange from alpaca.data.models import Bar, Trade, News +from alpaca.data.live.websocket import DataStream +from alpaca.data.models import Bar, Trade +from alpaca.data.models.news import News +from alpaca.data.models.orderbooks import Orderbook, OrderbookQuote +from alpaca.data.models.quotes import Quote +from alpaca.data.models.trades import TradeCancel, TradingStatus @pytest.fixture -def ws_client() -> BaseStream: +def ws_client() -> DataStream: """Socket client fixture with pydantic models as output.""" - return BaseStream("endpoint", "key-id", "secret-key") + return DataStream("endpoint", "key-id", "secret-key") @pytest.fixture -def raw_ws_client() -> BaseStream: +def raw_ws_client() -> DataStream: """Socket client fixture with raw data output.""" - return BaseStream("endpoint", "key-id", "secret-key", raw_data=True) + return DataStream("endpoint", "key-id", "secret-key", raw_data=True) @pytest.fixture @@ -25,119 +32,203 @@ def timestamp() -> Timestamp: return Timestamp(seconds=10, nanoseconds=10) -def test_cast(ws_client: BaseStream, raw_ws_client: BaseStream, timestamp: Timestamp): - """Test the value error in case there's a different timestamp type.""" - # Bar - bar_msg_type = "b" - bar_msg_dict = { - "S": "AAPL", - "o": 177.94, - "c": 178.005, - "h": 178.005, - "l": 177.94, - "v": 8547, - "t": timestamp, - "n": 66, - "vw": 177.987562, - } - - bar_cast_msg = ws_client._cast(bar_msg_type, bar_msg_dict) - - assert type(bar_cast_msg) == Bar - - assert bar_cast_msg.symbol == "AAPL" - assert bar_cast_msg.high == 178.005 - - # Trade - trade_msg_type = "t" - trade__msg_dict = { - "T": "t", - "S": "AAPL", - "i": 6142, - "x": "V", - "p": 177.79, - "s": 90, - "c": ["@", "I"], - "z": "C", - "t": timestamp, - } - - trade_cast_msg = ws_client._cast(trade_msg_type, trade__msg_dict) - - assert type(trade_cast_msg) == Trade - - assert trade_cast_msg.symbol == "AAPL" - assert trade_cast_msg.price == 177.79 - assert trade_cast_msg.exchange == Exchange.V +def test_cast(ws_client: DataStream, raw_ws_client: DataStream, timestamp: Timestamp): + bar = ws_client._cast( + { + "T": "b", + "S": "AAPL", + "o": 177.94, + "c": 178.005, + "h": 178.005, + "l": 177.94, + "v": 8547, + "t": timestamp, + "n": 66, + "vw": 177.987562, + }, + ) + assert type(bar) == Bar + assert bar.symbol == "AAPL" + assert bar.high == 178.005 + + trade = ws_client._cast( + { + "T": "t", + "S": "AAPL", + "i": 6142, + "x": "V", + "p": 177.79, + "s": 90, + "c": ["@", "I"], + "z": "C", + "t": timestamp, + }, + ) + assert type(trade) == Trade + assert trade.symbol == "AAPL" + assert trade.price == 177.79 + assert trade.exchange == Exchange.V + + quote = ws_client._cast( + { + "T": "q", + "S": "SPIP", + "bx": "V", + "bp": 25.41, + "bs": 35, + "ax": "V", + "ap": 25.43, + "as": 35, + "c": ["R"], + "z": "B", + "t": timestamp, + }, + ) + assert type(quote) == Quote + assert quote.symbol == "SPIP" + assert quote.bid_price == 25.41 + assert quote.ask_size == 35 + assert quote.conditions == ["R"] + + orderbook = ws_client._cast( + { + "T": "o", + "S": "BTC/USD", + "t": timestamp, + "b": [{"p": 65128.1, "s": 1.6542}], + "a": [{"p": 65128.1, "s": 1.6542}], + }, + ) + assert type(orderbook) == Orderbook + assert orderbook.symbol == "BTC/USD" + assert orderbook.bids == [OrderbookQuote(p=65128.1, s=1.6542)] + + trading_status = ws_client._cast( + { + "T": "s", + "S": "STRR", + "t": timestamp, + "sc": "T", + "sm": "Trading Resumption", + "rc": "C11", + "rm": "", + "z": "C", + }, + ) + assert type(trading_status) == TradingStatus + assert trading_status.status_code == "T" + + cancel = ws_client._cast( + { + "T": "x", + "S": "DJT", + "i": 4868, + "x": "D", + "p": 36.18, + "s": 31800, + "a": "C", + "z": "C", + "t": timestamp, + }, + ) + assert type(cancel) == TradeCancel + assert cancel.id == 4868 + assert cancel.exchange == "D" + assert cancel.price == 36.18 + + created_at = datetime(2024, 6, 17, 14, 11, 0, tzinfo=utc) + news = ws_client._cast( + { + "T": "n", + "id": 39358670, + "headline": "Broadcom shares are trading higher. The company last week reported better-than-expected Q2 financial results, issued strong revenue guidance and announced a 10-for-1 forward split.", + "summary": "", + "author": "Benzinga Newsdesk", + "created_at": Timestamp.from_datetime(created_at), + "updated_at": Timestamp.from_datetime(created_at), + "url": "https://www.benzinga.com/wiim/24/06/39358670/broadcom-shares-are-trading-higher-the-company-last-week-reported-better-than-expected-q2-financial", + "content": "", + "symbols": ["AVGO"], + "source": "benzinga", + }, + ) + assert type(news) == News + assert news.id == 39358670 + assert news.symbols == ["AVGO"] + assert news.created_at == created_at # Raw Client + raw_bar = raw_ws_client._cast( + { + "T": "b", + "S": "AAPL", + "o": 177.94, + "c": 178.005, + "h": 178.005, + "l": 177.94, + "v": 8547, + "t": timestamp, + "n": 66, + "vw": 177.987562, + }, + ) + assert type(raw_bar) == dict + assert raw_bar["S"] == "AAPL" + assert raw_bar["h"] == 178.005 - # Bar - raw_bar_msg_type = "b" - raw_bar_msg_dict = { - "S": "AAPL", - "o": 177.94, - "c": 178.005, - "h": 178.005, - "l": 177.94, - "v": 8547, - "t": timestamp, - "n": 66, - "vw": 177.987562, - } - raw_bar_cast_msg = raw_ws_client._cast(raw_bar_msg_type, raw_bar_msg_dict) - - assert type(raw_bar_cast_msg) == dict - - assert raw_bar_cast_msg["S"] == "AAPL" - assert raw_bar_cast_msg["h"] == 178.005 - - # Trade - raw_trade_msg_type = "t" - raw_trade_msg_dict = { - "T": "t", - "S": "AAPL", - "i": 6142, - "x": "V", - "p": 177.79, - "s": 90, - "c": ["@", "I"], - "z": "C", - "t": timestamp, - } +@pytest.mark.asyncio +async def test_dispatch(ws_client: DataStream, timestamp: Timestamp): + articles_a, articles_b, articles_star = [], [], [] + + async def handler_a(d): + articles_a.append(d) - raw_trade_cast_msg = raw_ws_client._cast(raw_trade_msg_type, raw_trade_msg_dict) + async def handler_b(d): + articles_b.append(d) - assert type(raw_trade_cast_msg) == dict + async def handler_star(d): + articles_star.append(d) - assert raw_trade_cast_msg["S"] == "AAPL" - assert raw_trade_cast_msg["p"] == 177.79 - assert raw_trade_cast_msg["x"] == "V" + ws_client._subscribe(handler_a, ("A",), ws_client._handlers["news"]) + ws_client._subscribe(handler_b, ("B",), ws_client._handlers["news"]) + ws_client._subscribe(handler_star, ("*",), ws_client._handlers["news"]) - # News - raw_news_msg_type = "n" - raw_news_msg_dict = { + msg_a = { "T": "n", - "id": 24918784, - "headline": "Corsair Reports Purchase Of Majority Ownership In iDisplay, No Terms Disclosed", - "summary": "Corsair Gaming, Inc. (NASDAQ:CRSR) (“Corsair”), a leading global provider and innovator of high-performance gear for gamers and content creators, today announced that it acquired a 51% stake in iDisplay", - "author": "Benzinga Newsdesk", + "author": "benzinga", + "headline": "a", + "id": 1, + "summary": "a", + "content": "", + "url": "url", + "source": "benzinga", + "symbols": ["A"], "created_at": timestamp, "updated_at": timestamp, - "url": "https://www.benzinga.com/m-a/22/01/24918784/corsair-reports-purchase-of-majority-ownership-in-idisplay-no-terms-disclosed", - "content": '\u003cp\u003eCorsair Gaming, Inc. (NASDAQ:\u003ca class="ticker" href="https://www.benzinga.com/stock/CRSR#NASDAQ"\u003eCRSR\u003c/a\u003e) (\u0026ldquo;Corsair\u0026rdquo;), a leading global ...', - "symbols": ["CRSR"], - "source": "benzinga", } - - raw_news_cast_msg = raw_ws_client._cast(raw_news_msg_type, raw_news_msg_dict) - - assert type(raw_news_cast_msg) == dict - - assert "CRSR" in raw_news_cast_msg["symbols"] - assert raw_news_cast_msg["source"] == "benzinga" - assert ( - raw_news_cast_msg["headline"] - == "Corsair Reports Purchase Of Majority Ownership In iDisplay, No Terms Disclosed" - ) + await ws_client._dispatch(msg_a.copy()) + assert len(articles_a) == 1 + assert len(articles_b) == 0 + assert len(articles_star) == 0 + assert type(articles_a[0]) == News + assert articles_a[0].summary == "a" + + msg_b = msg_a.copy() + msg_b["headline"] = "b" + msg_b["symbols"] = ["B", "C"] + await ws_client._dispatch(msg_b.copy()) + assert len(articles_a) == 1 + assert len(articles_b) == 1 + assert len(articles_star) == 1 + assert articles_b[0].headline == "b" + assert articles_star[0].headline == "b" + + msg_c = msg_a.copy() + msg_c["headline"] = "c" + msg_c["symbols"] = ["C"] + await ws_client._dispatch(msg_c.copy()) + assert len(articles_a) == 1 + assert len(articles_b) == 1 + assert len(articles_star) == 2 + assert articles_star[1].headline == "c"