Skip to content

INTPYTHON-667 Support Azure OpenAI in tests #159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions libs/langchain-mongodb/tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_ollama.embeddings import OllamaEmbeddings
from langchain_openai import OpenAIEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from pymongo import MongoClient

from ..utils import CONNECTION_STRING
Expand Down Expand Up @@ -34,12 +34,14 @@ def embedding() -> Embeddings:
openai_api_key=os.environ["OPENAI_API_KEY"], # type: ignore # noqa
model="text-embedding-3-small",
)
if os.environ.get("AZURE_OPENAI_ENDPOINT"):
return AzureOpenAIEmbeddings(model="text-embedding-3-small")

return OllamaEmbeddings(model="all-minilm:l6-v2")


@pytest.fixture(scope="session")
def dimensions() -> int:
if os.environ.get("OPENAI_API_KEY"):
if os.environ.get("OPENAI_API_KEY") or os.environ.get("AZURE_OPENAI_ENDPOINT"):
return 1536
return 384
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import requests
from flaky import flaky # type:ignore[import-untyped]
from langchain_openai import ChatOpenAI
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langgraph.prebuilt import create_react_agent
from pymongo import MongoClient

Expand Down Expand Up @@ -47,20 +47,24 @@ def db(client: MongoClient) -> MongoDBDatabase:

@flaky(max_runs=5, min_passes=4)
@pytest.mark.skipif(
"OPENAI_API_KEY" not in os.environ, reason="test requires OpenAI for chat responses"
"OPENAI_API_KEY" not in os.environ and "AZURE_OPENAI_ENDPOINT" not in os.environ,
reason="test requires OpenAI for chat responses",
)
def test_toolkit_response(db):
db_wrapper = MongoDBDatabase.from_connection_string(
CONNECTION_STRING, database=DB_NAME
)
llm = ChatOpenAI(model="gpt-4o-mini", timeout=60)
if "AZURE_OPENAI_ENDPOINT" in os.environ:
llm = AzureChatOpenAI(model="gpt-4o-mini", timeout=60)
else:
llm = ChatOpenAI(model="gpt-4o-mini", timeout=60)

toolkit = MongoDBDatabaseToolkit(db=db_wrapper, llm=llm)

system_message = MONGODB_AGENT_SYSTEM_PROMPT.format(top_k=5)
prompt = MONGODB_AGENT_SYSTEM_PROMPT.format(top_k=5)

test_query = "Which country's customers spent the most?"
agent = create_react_agent(llm, toolkit.get_tools(), state_modifier=system_message)
agent = create_react_agent(llm, toolkit.get_tools(), prompt=prompt)
agent.step_timeout = 60
events = agent.stream(
{"messages": [("user", test_query)]},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langchain_openai.chat_models.base import BaseChatOpenAI
from pymongo import MongoClient
from pymongo.collection import Collection

Expand Down Expand Up @@ -50,7 +51,7 @@ def collection(client: MongoClient) -> Collection:


@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY"),
not os.environ.get("OPENAI_API_KEY") and "AZURE_OPENAI_ENDPOINT" not in os.environ,
reason="Requires OpenAI for chat responses.",
)
def test_chain(
Expand Down Expand Up @@ -120,7 +121,10 @@ def test_chain(
"""
prompt = ChatPromptTemplate.from_template(template)

model = ChatOpenAI()
if "AZURE_OPENAI_ENDPOINT" in os.environ:
model: BaseChatOpenAI = AzureChatOpenAI(model="o4-mini")
else:
model = ChatOpenAI()

chain = (
{"context": retriever, "question": RunnablePassthrough()} # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langchain_openai.chat_models.base import BaseChatOpenAI

from langchain_mongodb import MongoDBAtlasVectorSearch, index
from langchain_mongodb.retrievers import MongoDBAtlasSelfQueryRetriever
Expand All @@ -17,7 +18,7 @@
COLLECTION_NAME = "test_self_querying_retriever"
TIMEOUT = 120

if "OPENAI_API_KEY" not in os.environ:
if "OPENAI_API_KEY" not in os.environ and "AZURE_OPENAI_ENDPOINT" not in os.environ:
pytest.skip("Requires OpenAI for chat responses.", allow_module_level=True)


Expand Down Expand Up @@ -161,8 +162,10 @@ def vectorstore(


@pytest.fixture
def llm() -> ChatOpenAI:
def llm() -> BaseChatOpenAI:
"""Model used for interpreting query."""
if "AZURE_OPENAI_ENDPOINT" in os.environ:
return AzureChatOpenAI(model="gpt-4o", temperature=0.0, cache=False)
return ChatOpenAI(model="gpt-4o", temperature=0.0, cache=False)


Expand Down
4 changes: 3 additions & 1 deletion libs/langchain-mongodb/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from pydantic import model_validator
from pymongo import MongoClient
from pymongo.collection import Collection
Expand All @@ -46,6 +46,8 @@ def create_database() -> MongoDBDatabase:


def create_llm() -> BaseChatModel:
if os.environ.get("AZURE_OPENAI_ENDPOINT"):
return AzureChatOpenAI(model="o4-mini", timeout=60, cache=False)
if os.environ.get("OPENAI_API_KEY"):
return ChatOpenAI(model="gpt-4o-mini", timeout=60, cache=False)
return ChatOllama(model="llama3:8b", cache=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,23 @@
import pytest
from langchain_core.embeddings import Embeddings
from langchain_ollama.embeddings import OllamaEmbeddings
from langchain_openai import OpenAIEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings


@pytest.fixture(scope="session")
def embedding() -> Embeddings:
if os.environ.get("AZURE_OPENAI_ENDPOINT"):
return AzureOpenAIEmbeddings(model="text-embedding-3-small")
if os.environ.get("OPENAI_API_KEY"):
return OpenAIEmbeddings(
openai_api_key=os.environ["OPENAI_API_KEY"], # type: ignore # noqa
model="text-embedding-3-small",
)

return OllamaEmbeddings(model="all-minilm:l6-v2")


@pytest.fixture(scope="session")
def dimensions() -> int:
if os.environ.get("OPENAI_API_KEY"):
if os.environ.get("OPENAI_API_KEY") or os.environ.get("AZURE_OPENAI_ENDPOINT"):
return 1536
return 384
Loading