
Commit

Add news to sentiment
virattt committed Jan 19, 2025
1 parent 8135a0f commit 76980fe
Showing 5 changed files with 200 additions and 37 deletions.
23 changes: 16 additions & 7 deletions src/agents/sentiment.py
@@ -5,7 +5,7 @@
import numpy as np
import json

from tools.api import get_insider_trades
from tools.api import get_insider_trades, get_company_news


##### Sentiment Agent #####
@@ -28,16 +28,25 @@ def sentiment_agent(state: AgentState):
limit=1000,
)

if not insider_trades:
progress.update_status("sentiment_agent", ticker, "Failed: No insider trades found")
continue

progress.update_status("sentiment_agent", ticker, "Analyzing trading patterns")

# Get the signals from the insider trades
transaction_shares = pd.Series([t.transaction_shares for t in insider_trades]).dropna()
bearish_condition = transaction_shares < 0
signals = np.where(bearish_condition, "bearish", "bullish").tolist()
insider_signals = np.where(transaction_shares < 0, "bearish", "bullish").tolist()

progress.update_status("sentiment_agent", ticker, "Fetching company news")

# Get the company news
company_news = get_company_news(ticker, end_date, limit=100)

# Get the sentiment from the company news
sentiment = pd.Series([n.sentiment for n in company_news]).dropna()
news_signals = np.where(sentiment == "negative", "bearish",
np.where(sentiment == "positive", "bullish", "neutral")).tolist()

progress.update_status("sentiment_agent", ticker, "Combining signals")
# Combine signals from both sources
signals = insider_signals + news_signals

# Determine overall signal
bullish_signals = signals.count("bullish")
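
For reference, a minimal standalone sketch of the tally that follows. The diff is truncated right after signals.count("bullish"), so everything past that line is an assumption about how the counts are resolved, not the commit's exact code:

# Hypothetical inputs standing in for the merged insider + news signals above.
insider_signals = ["bullish", "bearish", "bullish"]
news_signals = ["bullish", "neutral", "bearish"]

signals = insider_signals + news_signals

# Tally each side; "neutral" entries count toward neither.
bullish_signals = signals.count("bullish")
bearish_signals = signals.count("bearish")

# Assumed majority-vote resolution (the diff is cut off at this point).
if bullish_signals > bearish_signals:
    overall_signal = "bullish"
elif bearish_signals > bullish_signals:
    overall_signal = "bearish"
else:
    overall_signal = "neutral"

print(overall_signal)  # -> "bullish" for these inputs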
16 changes: 10 additions & 6 deletions src/backtester.py
@@ -9,6 +9,7 @@
from utils.analysts import ANALYST_ORDER
from main import run_hedge_fund
from tools.api import (
get_company_news,
get_price_data,
get_prices,
get_financial_metrics,
@@ -60,8 +61,11 @@ def prefetch_data(self):
# Fetch financial metrics
get_financial_metrics(ticker, self.end_date, limit=10)

# Fetch insider trades
get_insider_trades(ticker, self.end_date, limit=1000)
# Fetch insider trades for the entire period
get_insider_trades(ticker, self.end_date, start_date=self.start_date, limit=1000)

# Fetch company news for the entire period
get_company_news(ticker, self.end_date, start_date=self.start_date, limit=1000)

# Fetch common line items used by valuation agent
search_line_items(
@@ -331,7 +335,7 @@ def analyze_performance(self):
parser.add_argument(
"--tickers",
type=str,
required=True,
required=False,
help="Comma-separated list of stock ticker symbols (e.g., AAPL,MSFT,GOOGL)",
)
parser.add_argument(
@@ -343,7 +347,7 @@
parser.add_argument(
"--start-date",
type=str,
default=(datetime.now() - relativedelta(months=3)).strftime("%Y-%m-%d"),
default=(datetime.now() - relativedelta(months=12)).strftime("%Y-%m-%d"),
help="Start date in YYYY-MM-DD format",
)
parser.add_argument(
@@ -356,8 +360,8 @@
args = parser.parse_args()

# Parse tickers from comma-separated string
tickers = [ticker.strip() for ticker in args.tickers.split(",")]

# tickers = [ticker.strip() for ticker in args.tickers.split(",")]
tickers = ["AAPL"]
selected_analysts = None
choices = questionary.checkbox(
"Use the Space bar to select/unselect analysts.",
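
To illustrate the prefetch change above: one ranged call per ticker now warms the cache for the whole backtest window, so the per-day calls inside the backtest loop are served locally. A sketch with placeholder dates (the signatures match src/tools/api.py as changed below):

from tools.api import get_company_news, get_insider_trades

start_date, end_date = "2024-01-19", "2025-01-19"  # placeholder window

for ticker in ["AAPL"]:
    # Both helpers paginate back to start_date and store what they fetch in
    # the module-level cache, so repeated lookups later are cheap.
    get_insider_trades(ticker, end_date, start_date=start_date, limit=1000)
    get_company_news(ticker, end_date, start_date=start_date, limit=1000)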
58 changes: 50 additions & 8 deletions src/data/cache.py
@@ -6,38 +6,80 @@ def __init__(self):
self._financial_metrics_cache: dict[str, list[dict[str, any]]] = {}
self._line_items_cache: dict[str, list[dict[str, any]]] = {}
self._insider_trades_cache: dict[str, list[dict[str, any]]] = {}
self._company_news_cache: dict[str, list[dict[str, any]]] = {}

def _merge_data(self, existing: list[dict] | None, new_data: list[dict], key_field: str) -> list[dict]:
"""Merge existing and new data, avoiding duplicates based on a key field."""
if not existing:
return new_data

# Create a set of existing keys for O(1) lookup
existing_keys = {item[key_field] for item in existing}

# Only add items that don't exist yet
merged = existing.copy()
merged.extend([item for item in new_data if item[key_field] not in existing_keys])
return merged

def get_prices(self, ticker: str) -> list[dict[str, any]] | None:
"""Get cached price data if available."""
return self._prices_cache.get(ticker)

def set_prices(self, ticker: str, data: list[dict[str, any]]):
"""Cache price data."""
self._prices_cache[ticker] = data
"""Append new price data to cache."""
self._prices_cache[ticker] = self._merge_data(
self._prices_cache.get(ticker),
data,
key_field="time"
)

def get_financial_metrics(self, ticker: str) -> list[dict[str, any]]:
"""Get cached financial metrics if available."""
return self._financial_metrics_cache.get(ticker)

def set_financial_metrics(self, ticker: str, data: list[dict[str, any]]):
"""Cache financial metrics data."""
self._financial_metrics_cache[ticker] = data
"""Append new financial metrics to cache."""
self._financial_metrics_cache[ticker] = self._merge_data(
self._financial_metrics_cache.get(ticker),
data,
key_field="report_period"
)

def get_line_items(self, ticker: str) -> list[dict[str, any]] | None:
"""Get cached line items if available."""
return self._line_items_cache.get(ticker)

def set_line_items(self, ticker: str, data: list[dict[str, any]]):
"""Cache line items data."""
self._line_items_cache[ticker] = data
"""Append new line items to cache."""
self._line_items_cache[ticker] = self._merge_data(
self._line_items_cache.get(ticker),
data,
key_field="report_period"
)

def get_insider_trades(self, ticker: str) -> list[dict[str, any]] | None:
"""Get cached insider trades if available."""
return self._insider_trades_cache.get(ticker)

def set_insider_trades(self, ticker: str, data: list[dict[str, any]]):
"""Cache insider trades data."""
self._insider_trades_cache[ticker] = data
"""Append new insider trades to cache."""
self._insider_trades_cache[ticker] = self._merge_data(
self._insider_trades_cache.get(ticker),
data,
key_field="filing_date" # Could also use transaction_date if preferred
)

def get_company_news(self, ticker: str) -> list[dict[str, any]] | None:
"""Get cached company news if available."""
return self._company_news_cache.get(ticker)

def set_company_news(self, ticker: str, data: list[dict[str, any]]):
"""Append new company news to cache."""
self._company_news_cache[ticker] = self._merge_data(
self._company_news_cache.get(ticker),
data,
key_field="date"
)


# Global cache instance
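
To see what every set_* method now does, here is the merge logic replicated outside the class. This is a standalone sketch of _merge_data for illustration, not new repository code:

def merge_data(existing, new_data, key_field):
    # Mirror of Cache._merge_data: keep all existing rows, then append only
    # the new rows whose key_field value hasn't been seen before.
    if not existing:
        return new_data
    existing_keys = {item[key_field] for item in existing}
    merged = existing.copy()
    merged.extend(item for item in new_data if item[key_field] not in existing_keys)
    return merged

cached = [{"date": "2025-01-17", "title": "Old headline"}]
fresh = [
    {"date": "2025-01-17", "title": "Old headline"},  # duplicate key, dropped
    {"date": "2025-01-18", "title": "New headline"},  # new key, appended
]
print(merge_data(cached, fresh, "date"))  # two items, no duplicate

One caveat with keying news on "date": two distinct articles published on the same date would collide and be treated as duplicates; keying on "url" would be stricter.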
14 changes: 14 additions & 0 deletions src/data/models.py
@@ -100,6 +100,20 @@ class InsiderTradeResponse(BaseModel):
insider_trades: list[InsiderTrade]


class CompanyNews(BaseModel):
ticker: str
title: str
author: str
source: str
date: str
url: str
sentiment: str | None = None


class CompanyNewsResponse(BaseModel):
news: list[CompanyNews]


class Position(BaseModel):
cash: float = 0.0
shares: int = 0
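
A quick sketch of parsing an API payload with the new models. The payload below is hypothetical, and pydantic v2 is assumed (consistent with the model_dump calls elsewhere in this commit):

from data.models import CompanyNewsResponse

payload = {  # hypothetical response body
    "news": [
        {
            "ticker": "AAPL",
            "title": "Example headline",
            "author": "Jane Doe",
            "source": "Example Wire",
            "date": "2025-01-18",
            "url": "https://example.com/apple-story",
            "sentiment": "positive",
        }
    ]
}

response = CompanyNewsResponse(**payload)
item = response.news[0]
print(item.date, item.sentiment)  # sentiment is optional and defaults to None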
126 changes: 110 additions & 16 deletions src/tools/api.py
@@ -4,6 +4,8 @@

from data.cache import get_cache
from data.models import (
CompanyNews,
CompanyNewsResponse,
FinancialMetrics,
FinancialMetricsResponse,
Price,
@@ -32,7 +34,7 @@ def get_prices(ticker: str, start_date: str, end_date: str) -> list[Price]:
if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"):
headers["X-API-KEY"] = api_key

url = f"https://api.financialdatasets.ai/prices/" f"?ticker={ticker}" f"&interval=day" f"&interval_multiplier=1" f"&start_date={start_date}" f"&end_date={end_date}"
url = f"https://api.financialdatasets.ai/prices/?ticker={ticker}&interval=day&interval_multiplier=1&start_date={start_date}&end_date={end_date}"
response = requests.get(url, headers=headers)
if response.status_code != 200:
raise Exception(f"Error fetching data: {response.status_code} - {response.text}")
@@ -69,7 +71,7 @@ def get_financial_metrics(
if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"):
headers["X-API-KEY"] = api_key

url = f"https://api.financialdatasets.ai/financial-metrics/" f"?ticker={ticker}" f"&report_period_lte={end_date}" f"&limit={limit}" f"&period={period}"
url = f"https://api.financialdatasets.ai/financial-metrics/?ticker={ticker}&report_period_lte={end_date}&limit={limit}&period={period}"
response = requests.get(url, headers=headers)
if response.status_code != 200:
raise Exception(f"Error fetching data: {response.status_code} - {response.text}")
@@ -134,36 +136,128 @@ def search_line_items(
def get_insider_trades(
ticker: str,
end_date: str,
start_date: str | None = None,
limit: int = 1000,
) -> list[InsiderTrade]:
"""Fetch insider trades from cache or API."""
# Check cache first
if cached_data := _cache.get_insider_trades(ticker):
# Filter cached data by date and limit
filtered_data = [InsiderTrade(**trade) for trade in cached_data if (trade.get("transaction_date") or trade["filing_date"]) <= end_date]
# Sort by transaction_date if available, otherwise filing_date
# Filter cached data by date range
filtered_data = [InsiderTrade(**trade) for trade in cached_data
if (start_date is None or (trade.get("transaction_date") or trade["filing_date"]) >= start_date)
and (trade.get("transaction_date") or trade["filing_date"]) <= end_date]
filtered_data.sort(key=lambda x: x.transaction_date or x.filing_date, reverse=True)
if filtered_data:
return filtered_data[:limit]
return filtered_data

# If not in cache or insufficient data, fetch from API
headers = {}
if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"):
headers["X-API-KEY"] = api_key

url = f"https://api.financialdatasets.ai/insider-trades/" f"?ticker={ticker}" f"&filing_date_lte={end_date}" f"&limit={limit}"
response = requests.get(url, headers=headers)
if response.status_code != 200:
raise Exception(f"Error fetching data: {response.status_code} - {response.text}")
data = response.json()
response_model = InsiderTradeResponse(**data)
insider_trades = response_model.insider_trades
if not insider_trades:
all_trades = []
current_end_date = end_date

while True:
url = f"https://api.financialdatasets.ai/insider-trades/?ticker={ticker}&filing_date_lte={current_end_date}"
if start_date:
url += f"&filing_date_gte={start_date}"
url += f"&limit={limit}"

response = requests.get(url, headers=headers)
if response.status_code != 200:
raise Exception(f"Error fetching data: {response.status_code} - {response.text}")

data = response.json()
response_model = InsiderTradeResponse(**data)
insider_trades = response_model.insider_trades

if not insider_trades:
break

all_trades.extend(insider_trades)

# Only continue pagination if we have a start_date and got a full page
if not start_date or len(insider_trades) < limit:
break

# Update end_date to the oldest filing date from current batch for next iteration
current_end_date = min(trade.filing_date for trade in insider_trades).split('T')[0]

# If we've reached or passed the start_date, we can stop
if current_end_date <= start_date:
break

if not all_trades:
return []

# Cache the results
_cache.set_insider_trades(ticker, [trade.model_dump() for trade in insider_trades])
return insider_trades[:limit]
_cache.set_insider_trades(ticker, [trade.model_dump() for trade in all_trades])
return all_trades


def get_company_news(
ticker: str,
end_date: str,
start_date: str | None = None,
limit: int = 1000,
) -> list[CompanyNews]:
"""Fetch company news from cache or API."""
# Check cache first
if cached_data := _cache.get_company_news(ticker):
# Filter cached data by date range
filtered_data = [CompanyNews(**news) for news in cached_data
if (start_date is None or news["date"] >= start_date)
and news["date"] <= end_date]
filtered_data.sort(key=lambda x: x.date, reverse=True)
if filtered_data:
return filtered_data

# If not in cache or insufficient data, fetch from API
headers = {}
if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"):
headers["X-API-KEY"] = api_key

all_news = []
current_end_date = end_date

while True:
url = f"https://api.financialdatasets.ai/news/?ticker={ticker}&end_date={current_end_date}"
if start_date:
url += f"&start_date={start_date}"
url += f"&limit={limit}"

response = requests.get(url, headers=headers)
if response.status_code != 200:
raise Exception(f"Error fetching data: {response.status_code} - {response.text}")

data = response.json()
response_model = CompanyNewsResponse(**data)
company_news = response_model.news

if not company_news:
break

all_news.extend(company_news)

# Only continue pagination if we have a start_date and got a full page
if not start_date or len(company_news) < limit:
break

# Update end_date to the oldest date from current batch for next iteration
current_end_date = min(news.date for news in company_news).split('T')[0]

# If we've reached or passed the start_date, we can stop
if current_end_date <= start_date:
break

if not all_news:
return []

# Cache the results
_cache.set_company_news(ticker, [news.model_dump() for news in all_news])
return all_news



def get_market_cap(
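
Putting the new helper together, a usage sketch (FINANCIAL_DATASETS_API_KEY must be set in the environment; the ticker and dates are placeholders):

from tools.api import get_company_news

# With start_date set, the helper pages backwards -- end_date moves to the
# oldest date seen in each batch -- until the window is covered, then caches
# everything it fetched.
news = get_company_news("AAPL", end_date="2025-01-19",
                        start_date="2024-12-01", limit=100)
for item in news[:5]:
    print(item.date, item.sentiment, item.title)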
