Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Groq chat vanna integration #757

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
SNOWFLAKE_ACCOUNT: ${{ secrets.SNOWFLAKE_ACCOUNT }}
SNOWFLAKE_USERNAME: ${{ secrets.SNOWFLAKE_USERNAME }}
SNOWFLAKE_PASSWORD: ${{ secrets.SNOWFLAKE_PASSWORD }}
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
google = ["google-generativeai", "google-cloud-aiplatform"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres", "xinference-client"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres", "xinference-client", "groq"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
Expand All @@ -57,3 +57,4 @@ pgvector = ["langchain-postgres>=0.0.12"]
faiss-cpu = ["faiss-cpu"]
faiss-gpu = ["faiss-gpu"]
xinference-client = ["xinference-client"]
groq = ["groq"]
1 change: 1 addition & 0 deletions src/vanna/groq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .groq_chat import Groq_Chat
116 changes: 116 additions & 0 deletions src/vanna/groq/groq_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import os

from groq import Groq

from ..base import VannaBase


class Groq_Chat(VannaBase):
def __init__(self, client=None, config=None):
VannaBase.__init__(self, config=config)

# default parameters - can be overridden using config
self.temperature = 0.7

if "temperature" in config:
self.temperature = config["temperature"]

if "model" in config:
model = config["model"]

if client is not None:
self.client = client
return

if config is None and client is None:
self.client = Groq(api_key=os.getenv("GROQ_API_KEY"))
return

if "api_key" in config:
self.client = Groq(api_key=config["api_key"])

def system_message(self, message: str) -> any:
return {"role": "system", "content": message}

def user_message(self, message: str) -> any:
return {"role": "user", "content": message}

def assistant_message(self, message: str) -> any:
return {"role": "assistant", "content": message}

def submit_prompt(self, prompt, **kwargs) -> str:
if prompt is None:
raise Exception("Prompt is None")

if len(prompt) == 0:
raise Exception("Prompt is empty")

# Count the number of tokens in the message log
# Use 4 as an approximation for the number of characters per token
num_tokens = 0
for message in prompt:
num_tokens += len(message["content"]) / 4

if kwargs.get("model", None) is not None:
model = kwargs.get("model", None)
print(
f"Using model {model} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
model=model,
messages=prompt,
stop=None,
temperature=self.temperature,
)
elif kwargs.get("engine", None) is not None:
engine = kwargs.get("engine", None)
print(
f"Using model {engine} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
engine=engine,
messages=prompt,
stop=None,
temperature=self.temperature,
)
elif self.config is not None and "engine" in self.config:
print(
f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
engine=self.config["engine"],
messages=prompt,
stop=None,
temperature=self.temperature,
)
elif self.config is not None and "model" in self.config:
print(
f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
model=self.config["model"],
messages=prompt,
stop=None,
temperature=self.temperature,
)
else:
if num_tokens > 3500:
model = "llama-3.1-8b-instant"
else:
model = "llama3-8b-8192"

print(f"Using model {model} for {num_tokens} tokens (approx)")
response = self.client.chat.completions.create(
model=model,
messages=prompt,
stop=None,
temperature=self.temperature,
)

# Find the first response from the chatbot that has text in it (some responses may not have text)
for choice in response.choices:
if "text" in choice:
return choice.text

# If no response with text is found, return the first response's content (which may be empty)
return response.choices[0].message.content
14 changes: 14 additions & 0 deletions tests/test_vanna.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from vanna.openai.openai_chat import OpenAI_Chat
from vanna.remote import VannaDefault
from vanna.vannadb.vannadb_vector import VannaDB_VectorStore
from vanna.groq import Groq_Chat

try:
print("Trying to load .env")
Expand Down Expand Up @@ -241,3 +242,16 @@ def test_training_plan():

plan = vn_dummy.get_training_plan_generic(df_information_schema)
assert len(plan._plan) == 8

class VannaGroq(VannaDB_VectorStore, Groq_Chat):
def __init__(self, config=None):
VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config)
Groq_Chat.__init__(self, config=config)

vn_groq = VannaGroq(config={'api_key': os.environ['GROQ_API_KEY']})
vn_groq.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

def test_vn_groq():
sql = vn_groq.generate_sql("What are the top 10 customers by sales?")
df = vn_groq.run_sql(sql)
assert len(df) == 10