72 changes: 69 additions & 3 deletions libs/partners/deepseek/langchain_deepseek/chat_models.py
@@ -173,9 +173,16 @@ class Joke(BaseModel):
default_factory=from_env("DEEPSEEK_API_BASE", default=DEFAULT_API_BASE),
)
"""DeepSeek API base URL"""
strict: bool | None = Field(
default=None,
description=(
"Whether to enable strict mode for function calling. "
"When enabled, uses the Beta API endpoint and ensures "
"outputs strictly comply with the defined JSON schema."
),
)

model_config = ConfigDict(populate_by_name=True)

@property
def _llm_type(self) -> str:
"""Return type of chat model."""
@@ -198,16 +205,22 @@ def _get_ls_params(
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate necessary environment vars and client params."""
if self.api_base == DEFAULT_API_BASE and not (
# Use Beta API if strict mode is enabled
api_base = self.api_base
if self.strict and self.api_base == DEFAULT_API_BASE:
api_base = "https://api.deepseek.com/beta"

if api_base == DEFAULT_API_BASE and not (
self.api_key and self.api_key.get_secret_value()
):
msg = "If using default api base, DEEPSEEK_API_KEY must be set."
raise ValueError(msg)

client_params: dict = {
k: v
for k, v in {
"api_key": self.api_key.get_secret_value() if self.api_key else None,
"base_url": self.api_base,
"base_url": api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
@@ -229,6 +242,59 @@ def validate_environment(self) -> Self:
self.async_client = self.root_async_client.chat.completions
return self

def bind_tools(
self,
tools: list,
*,
tool_choice: str | dict | None = None,
strict: bool | None = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tools to the model with optional strict mode.

Args:
tools: A list of tool definitions or Pydantic models.
tool_choice: Which tool the model should use.
strict: Whether to enable strict mode for these tools.
If not provided, uses the instance's strict setting.
**kwargs: Additional arguments to pass to the parent method.

Returns:
A Runnable that will call the model with the bound tools.
"""
# Use instance strict setting if not explicitly provided
use_strict = strict if strict is not None else self.strict

# If strict mode is enabled, add strict: true to each tool
if use_strict:
from langchain_core.utils.function_calling import convert_to_openai_tool

formatted_tools = []
for tool in tools:
# Convert to OpenAI tool format if needed
if not isinstance(tool, dict):
tool_dict = convert_to_openai_tool(tool)
else:
tool_dict = tool.copy()

# Add strict: true to the function definition
if "function" in tool_dict:
tool_dict["function"]["strict"] = True

formatted_tools.append(tool_dict)

tools = formatted_tools

# Add strict to kwargs if it's being used
if use_strict is not None:
kwargs["strict"] = use_strict

return super().bind_tools(
tools,
tool_choice=tool_choice,
**kwargs,
)

def _get_request_payload(
self,
input_: LanguageModelInput,
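A minimal usage sketch of what this change enables (illustrative only; the `GetWeather` tool mirrors the unit tests below, and a DEEPSEEK_API_KEY env var or explicit api_key is still required):

from pydantic import BaseModel, Field

from langchain_deepseek import ChatDeepSeek


class GetWeather(BaseModel):
    """Get the current weather in a given location."""

    location: str = Field(..., description="The city and state")


# strict=True routes requests to the Beta endpoint and marks each bound tool
# with "strict": true so tool arguments must match the declared JSON schema.
llm = ChatDeepSeek(model="deepseek-chat", strict=True)
llm_with_tools = llm.bind_tools([GetWeather])
# A per-call override also works: llm.bind_tools([GetWeather], strict=True)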
82 changes: 82 additions & 0 deletions libs/partners/deepseek/tests/unit_tests/test_chat_models.py
@@ -311,3 +311,85 @@ def test_create_chat_result_with_model_provider_multiple_generations(
assert (
generation.message.response_metadata.get("model_provider") == "deepseek"
)


class TestChatDeepSeekStrictMode:
"""Test strict mode functionality."""

def test_strict_mode_uses_beta_api(self) -> None:
"""Test that strict mode switches to Beta API endpoint."""
model = ChatDeepSeek(
model=MODEL_NAME,
api_key=SecretStr("test-key"),
strict=True,
)

# Check that the client uses the beta endpoint
assert str(model.root_client.base_url) == "https://api.deepseek.com/beta/"

def test_strict_mode_disabled_uses_default_api(self) -> None:
"""Test that without strict mode, default API is used."""
model = ChatDeepSeek(
model=MODEL_NAME,
api_key=SecretStr("test-key"),
strict=False,
)

# Check that the client uses the default endpoint
assert str(model.root_client.base_url) == "https://api.deepseek.com/v1/"

def test_strict_mode_none_uses_default_api(self) -> None:
"""Test that strict=None uses default API."""
model = ChatDeepSeek(
model=MODEL_NAME,
api_key=SecretStr("test-key"),
)

# Check that the client uses the default endpoint
assert str(model.root_client.base_url) == "https://api.deepseek.com/v1/"

def test_bind_tools_with_strict_mode(self) -> None:
"""Test that bind_tools adds strict to tool definitions."""
from pydantic import BaseModel, Field

class GetWeather(BaseModel):
"""Get the current weather in a given location."""
location: str = Field(..., description="The city and state")

model = ChatDeepSeek(
model=MODEL_NAME,
api_key=SecretStr("test-key"),
strict=True,
)

# Bind tools
model_with_tools = model.bind_tools([GetWeather])

# Check that tools were bound
assert "tools" in model_with_tools.kwargs

# Verify that tools have the strict property set
tools = model_with_tools.kwargs["tools"]
assert len(tools) > 0
assert tools[0]["function"]["strict"] is True

def test_bind_tools_override_strict(self) -> None:
"""Test that bind_tools can override instance strict setting."""
from pydantic import BaseModel, Field

class GetWeather(BaseModel):
"""Get the current weather in a given location."""
location: str = Field(..., description="The city and state")

model = ChatDeepSeek(
model=MODEL_NAME,
api_key=SecretStr("test-key"),
strict=False,
)

# Override with strict=True in bind_tools
model_with_tools = model.bind_tools([GetWeather], strict=True)

# Verify that strict was applied to the bound tools
assert "tools" in model_with_tools.kwargs
tools = model_with_tools.kwargs["tools"]
assert tools[0]["function"]["strict"] is True
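For reference, a sketch of the tool entry that ends up in `kwargs["tools"]` after `bind_tools([GetWeather])` with strict mode on (shape inferred from `convert_to_openai_tool` output plus the added flag; exact schema details may vary and are not asserted verbatim by the tests above):

tool_entry = {
    "type": "function",
    "function": {
        "name": "GetWeather",
        "description": "Get the current weather in a given location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "The city and state"},
            },
            "required": ["location"],
        },
        "strict": True,  # added by the new bind_tools logic
    },
}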
43 changes: 31 additions & 12 deletions libs/partners/openai/langchain_openai/chat_models/base.py
@@ -1268,31 +1268,53 @@ def _create_chat_result(
generation_info: dict | None = None,
) -> ChatResult:
generations = []

response_dict = (
response if isinstance(response, dict) else response.model_dump()
)

# Handle response serialization more robustly for non-OpenAI APIs
if isinstance(response, dict):
response_dict = response
else:
# Try model_dump() first
try:
response_dict = response.model_dump()
except Exception as e:
# Fallback: some OpenAI-compatible endpoints return objects whose
# model_dump() fails; try a JSON round-trip before giving up.
if hasattr(response, "model_dump_json"):
import json

try:
response_dict = json.loads(response.model_dump_json())
except Exception:
# If that also fails, raise the original error
raise e
else:
raise

# Sometimes the API call returns an error; raise it here (an error is
# typically accompanied by a null value for `choices`, which is handled
# separately below).
if response_dict.get("error"):
raise ValueError(response_dict.get("error"))

# Raise informative error messages for non-OpenAI chat completions APIs
# that return malformed responses.
try:
choices = response_dict["choices"]
except KeyError as e:
msg = f"Response missing `choices` key: {response_dict.keys()}"
raise KeyError(msg) from e


# `choices` can be null when the endpoint returned an error or is incompatible
if choices is None:
msg = "Received response with null value for `choices`."
# Provide more debugging info for non-OpenAI APIs
msg = (
f"Received response with null value for `choices`. "
f"Response keys: {list(response_dict.keys())}. "
f"This may indicate an incompatibility with the API endpoint. "
f"Raw response type: {type(response).__name__}"
)
raise TypeError(msg)

token_usage = response_dict.get("usage")
service_tier = response_dict.get("service_tier")

for res in choices:
message = _convert_dict_to_message(res["message"])
if token_usage and isinstance(message, AIMessage):
@@ -1319,7 +1341,6 @@ def _create_chat_result(
llm_output["id"] = response_dict["id"]
if service_tier:
llm_output["service_tier"] = service_tier

if isinstance(response, openai.BaseModel) and getattr(
response, "choices", None
):
@@ -1328,9 +1349,7 @@ def _create_chat_result(
generations[0].message.additional_kwargs["parsed"] = message.parsed
if hasattr(message, "refusal"):
generations[0].message.additional_kwargs["refusal"] = message.refusal

return ChatResult(generations=generations, llm_output=llm_output)

async def _astream(
self,
messages: list[BaseMessage],
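A self-contained sketch of the fallback order the new serialization code follows, useful for reasoning about OpenAI-compatible endpoints whose response objects break `model_dump()` (the helper name is hypothetical; only the ordering mirrors the diff):

import json
from typing import Any


def to_response_dict(response: Any) -> dict:
    """Dict as-is, then model_dump(), then a JSON round-trip, else re-raise."""
    if isinstance(response, dict):
        return response
    try:
        return response.model_dump()
    except Exception as original_error:
        if hasattr(response, "model_dump_json"):
            try:
                return json.loads(response.model_dump_json())
            except Exception:
                raise original_error
        raise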
60 changes: 55 additions & 5 deletions libs/partners/openai/langchain_openai/embeddings/base.py
@@ -15,6 +15,9 @@
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

# OpenAI API limits
MAX_TOKENS_PER_REQUEST = 300000 # OpenAI's max tokens per embedding request

logger = logging.getLogger(__name__)


@@ -476,13 +479,37 @@ def _get_len_safe_embeddings(
client_kwargs = {**self._invocation_params, **kwargs}
_iter, tokens, indices = self._tokenize(texts, _chunk_size)
batched_embeddings: list[list[float]] = []
for i in _iter:
# Token counts per chunk (exact for token-id lists; a rough word-count
# approximation when a chunk is a plain string)
token_counts = [
len(t) if isinstance(t, list) else len(t.split()) for t in tokens
]

# Process in batches respecting the token limit
i = 0
while i < len(tokens):
# Determine how many chunks we can include in this batch
batch_token_count = 0
batch_end = i

for j in range(i, min(i + _chunk_size, len(tokens))):
chunk_tokens = token_counts[j]
# Check if adding this chunk would exceed the limit
if batch_token_count + chunk_tokens > MAX_TOKENS_PER_REQUEST:
if batch_end == i:
# A single chunk alone exceeds the limit - send it on its own
batch_end = j + 1
break
batch_token_count += chunk_tokens
batch_end = j + 1

# Make API call with this batch
batch_tokens = tokens[i:batch_end]
response = self.client.create(
input=tokens[i : i + _chunk_size], **client_kwargs
input=batch_tokens, **client_kwargs
)
if not isinstance(response, dict):
response = response.model_dump()
batched_embeddings.extend(r["embedding"] for r in response["data"])

i = batch_end

embeddings = _process_batched_chunked_embeddings(
len(texts), tokens, batched_embeddings, indices, self.skip_empty
@@ -530,14 +557,37 @@ async def _aget_len_safe_embeddings(
None, self._tokenize, texts, _chunk_size
)
batched_embeddings: list[list[float]] = []
for i in range(0, len(tokens), _chunk_size):
# Token counts per chunk (exact for token-id lists; a rough word-count
# approximation when a chunk is a plain string)
token_counts = [
len(t) if isinstance(t, list) else len(t.split()) for t in tokens
]

# Process in batches respecting the token limit
i = 0
while i < len(tokens):
# Determine how many chunks we can include in this batch
batch_token_count = 0
batch_end = i

for j in range(i, min(i + _chunk_size, len(tokens))):
chunk_tokens = token_counts[j]
# Check if adding this chunk would exceed the limit
if batch_token_count + chunk_tokens > MAX_TOKENS_PER_REQUEST:
if batch_end == i:
# A single chunk alone exceeds the limit - send it on its own
batch_end = j + 1
break
batch_token_count += chunk_tokens
batch_end = j + 1

# Make API call with this batch
batch_tokens = tokens[i:batch_end]
response = await self.async_client.create(
input=tokens[i : i + _chunk_size], **client_kwargs
input=batch_tokens, **client_kwargs
)

if not isinstance(response, dict):
response = response.model_dump()
batched_embeddings.extend(r["embedding"] for r in response["data"])

i = batch_end

embeddings = _process_batched_chunked_embeddings(
len(texts), tokens, batched_embeddings, indices, self.skip_empty
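Both the sync and async paths apply the same greedy packing rule; a standalone sketch of that rule under the stated 300k-token limit (function name and example numbers are illustrative):

MAX_TOKENS_PER_REQUEST = 300_000  # same limit the diff introduces


def plan_batches(token_counts: list[int], chunk_size: int) -> list[tuple[int, int]]:
    """Greedily pack consecutive chunks into (start, end) request batches,
    never exceeding MAX_TOKENS_PER_REQUEST unless a single chunk already does."""
    batches: list[tuple[int, int]] = []
    i = 0
    while i < len(token_counts):
        total = 0
        end = i
        for j in range(i, min(i + chunk_size, len(token_counts))):
            if total + token_counts[j] > MAX_TOKENS_PER_REQUEST:
                if end == i:  # a single oversize chunk is sent on its own
                    end = j + 1
                break
            total += token_counts[j]
            end = j + 1
        batches.append((i, end))
        i = end
    return batches


# Example: three chunks of 200k tokens each with chunk_size=1000 become
# [(0, 1), (1, 2), (2, 3)] -- one request per chunk, each under the limit.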