Skip to content

Commit

Permalink
Use TypeVar default to remove overloads (#411)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackmpcollins authored Feb 2, 2025
1 parent 31d0e57 commit ee082fc
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 319 deletions.
61 changes: 10 additions & 51 deletions src/magentic/chat_model/anthropic_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from enum import Enum
from functools import singledispatch
from itertools import groupby
from typing import Any, Generic, TypeVar, cast, overload
from typing import Any, Generic, cast

from typing_extensions import TypeVar

from magentic._parsing import contains_parallel_function_call_type, contains_string_type
from magentic._streamed_response import AsyncStreamedResponse, StreamedResponse
from magentic.chat_model.base import ChatModel, aparse_stream, parse_stream
from magentic.chat_model.base import ChatModel, OutputT, aparse_stream, parse_stream
from magentic.chat_model.function_schema import (
BaseFunctionSchema,
FunctionCallFunctionSchema,
Expand Down Expand Up @@ -364,9 +366,6 @@ def _if_given(value: T | None) -> T | anthropic.NotGiven:
return value if value is not None else anthropic.NOT_GIVEN


R = TypeVar("R")


class AnthropicChatModel(ChatModel):
"""An LLM chat model that uses the `anthropic` python package."""

Expand Down Expand Up @@ -428,37 +427,17 @@ def _get_tool_choice(
)
return {"type": "any", "disable_parallel_tool_use": disable_parallel_tool_use}

@overload
def complete(
self,
messages: Iterable[Message[Any]],
functions: Any = ...,
output_types: None = ...,
*,
stop: list[str] | None = ...,
) -> AssistantMessage[str]: ...

@overload
def complete(
self,
messages: Iterable[Message[Any]],
functions: Any = ...,
output_types: Iterable[type[R]] = ...,
*,
stop: list[str] | None = ...,
) -> AssistantMessage[R]: ...

def complete(
self,
messages: Iterable[Message[Any]],
functions: Iterable[Callable[..., Any]] | None = None,
output_types: Iterable[type[R]] | None = None,
output_types: Iterable[type[OutputT]] | None = None,
*,
stop: list[str] | None = None,
) -> AssistantMessage[str] | AssistantMessage[R]:
) -> AssistantMessage[OutputT]:
"""Request an LLM message."""
if output_types is None:
output_types = [] if functions else cast(list[type[R]], [str])
output_types = [] if functions else cast(list[type[OutputT]], [str])

function_schemas = get_function_schemas(functions, output_types)
tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas]
Expand Down Expand Up @@ -489,37 +468,17 @@ def complete(
parse_stream(stream, output_types), usage_ref=stream.usage_ref
)

@overload
async def acomplete(
self,
messages: Iterable[Message[Any]],
functions: Any = ...,
output_types: None = ...,
*,
stop: list[str] | None = ...,
) -> AssistantMessage[str]: ...

@overload
async def acomplete(
self,
messages: Iterable[Message[Any]],
functions: Any = ...,
output_types: Iterable[type[R]] = ...,
*,
stop: list[str] | None = ...,
) -> AssistantMessage[R]: ...

async def acomplete(
self,
messages: Iterable[Message[Any]],
functions: Iterable[Callable[..., Any]] | None = None,
output_types: Iterable[type[R]] | None = None,
output_types: Iterable[type[OutputT]] | None = None,
*,
stop: list[str] | None = None,
) -> AssistantMessage[R] | AssistantMessage[str]:
) -> AssistantMessage[OutputT]:
"""Async version of `complete`."""
if output_types is None:
output_types = [] if functions else cast(list[type[R]], [str])
output_types = [] if functions else cast(list[type[OutputT]], [str])

function_schemas = get_async_function_schemas(functions, output_types)
tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas]
Expand Down
100 changes: 32 additions & 68 deletions src/magentic/chat_model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from collections.abc import AsyncIterator, Callable, Iterable, Iterator
from contextvars import ContextVar
from itertools import chain
from typing import Any, TypeVar, cast, get_origin, overload
from typing import Any, cast, get_origin

from pydantic import ValidationError
from typing_extensions import TypeVar

from magentic._streamed_response import AsyncStreamedResponse, StreamedResponse
from magentic.chat_model.message import AssistantMessage, Message
Expand All @@ -16,7 +17,7 @@
)
from magentic.streaming import AsyncStreamedStr, StreamedStr, achain, async_iter

R = TypeVar("R")
OutputT = TypeVar("OutputT", default=str)

_chat_model_context: ContextVar["ChatModel | None"] = ContextVar(
"chat_model", default=None
Expand Down Expand Up @@ -105,130 +106,93 @@ def __init__(

# TODO: Move this into _parsing
# TODO: Make this a stream class with a close method and context management
def parse_stream(stream: Iterator[Any], output_types: Iterable[type[R]]) -> R:
def parse_stream(
stream: Iterator[Any], output_types: Iterable[type[OutputT]]
) -> OutputT:
"""Parse and validate the LLM output stream against the allowed output types."""
output_type_origins = [get_origin(type_) or type_ for type_ in output_types]
# TODO: option to error/warn/ignore extra objects
# TODO: warn for degenerate output types ?
obj = next(stream)
if isinstance(obj, StreamedStr):
if StreamedResponse in output_type_origins:
return cast(R, StreamedResponse(chain([obj], stream)))
return cast(OutputT, StreamedResponse(chain([obj], stream)))
if StreamedStr in output_type_origins:
return cast(R, obj)
return cast(OutputT, obj)
if str in output_type_origins:
return cast(R, str(obj))
return cast(OutputT, str(obj))
raise StringNotAllowedError(obj.truncate(100))
if isinstance(obj, FunctionCall):
if StreamedResponse in output_type_origins:
return cast(R, StreamedResponse(chain([obj], stream)))
return cast(OutputT, StreamedResponse(chain([obj], stream)))
if ParallelFunctionCall in output_type_origins:
return cast(R, ParallelFunctionCall(chain([obj], stream)))
return cast(OutputT, ParallelFunctionCall(chain([obj], stream)))
if FunctionCall in output_type_origins:
# TODO: Check that FunctionCall type matches ?
return cast(R, obj)
return cast(OutputT, obj)
raise FunctionCallNotAllowedError(obj)
if isinstance(obj, tuple(output_type_origins)):
return cast(R, obj)
return cast(OutputT, obj)
raise ObjectNotAllowedError(obj)


async def aparse_stream(
stream: AsyncIterator[Any], output_types: Iterable[type[R]]
) -> R:
stream: AsyncIterator[Any], output_types: Iterable[type[OutputT]]
) -> OutputT:
"""Async version of `parse_stream`."""
output_type_origins = [get_origin(type_) or type_ for type_ in output_types]
obj = await anext(stream)
if isinstance(obj, AsyncStreamedStr):
if AsyncStreamedResponse in output_type_origins:
return cast(R, AsyncStreamedResponse(achain(async_iter([obj]), stream)))
return cast(
OutputT, AsyncStreamedResponse(achain(async_iter([obj]), stream))
)
if AsyncStreamedStr in output_type_origins:
return cast(R, obj)
return cast(OutputT, obj)
if str in output_type_origins:
return cast(R, await obj.to_string())
return cast(OutputT, await obj.to_string())
raise StringNotAllowedError(await obj.truncate(100))
if isinstance(obj, FunctionCall):
if AsyncStreamedResponse in output_type_origins:
return cast(R, AsyncStreamedResponse(achain(async_iter([obj]), stream)))
return cast(
OutputT, AsyncStreamedResponse(achain(async_iter([obj]), stream))
)
if AsyncParallelFunctionCall in output_type_origins:
return cast(R, AsyncParallelFunctionCall(achain(async_iter([obj]), stream)))
return cast(
OutputT, AsyncParallelFunctionCall(achain(async_iter([obj]), stream))
)
if FunctionCall in output_type_origins:
return cast(R, obj)
return cast(OutputT, obj)
raise FunctionCallNotAllowedError(obj)
if isinstance(obj, tuple(output_type_origins)):
return cast(R, obj)
return cast(OutputT, obj)
raise ObjectNotAllowedError(obj)


class ChatModel(ABC):
"""An LLM chat model."""

@overload
@abstractmethod
def complete(
self,
messages: Iterable[Message[Any]],
functions: Any = ...,
output_types: None = ...,
*,
stop: list[str] | None = ...,
) -> AssistantMessage[str]: ...

@overload
@abstractmethod
def complete(
self,
messages: Iterable[Message[Any]],
functions: Any = ...,
output_types: Iterable[type[R]] = ...,
*,
stop: list[str] | None = ...,
) -> AssistantMessage[R]: ...

@abstractmethod
def complete(
self,
messages: Iterable[Message[Any]],
functions: Iterable[Callable[..., Any]] | None = None,
# TODO: Set default of R to str in Python 3.13
output_types: Iterable[type[R | str]] | None = None,
output_types: Iterable[type[OutputT]] | None = None,
*,
stop: list[str] | None = None,
) -> AssistantMessage[str] | AssistantMessage[R]:
) -> AssistantMessage[OutputT]:
"""Request an LLM message."""
...

@overload
@abstractmethod
async def acomplete(
self,
messages: Iterable[Message[Any]],
functions: Any = ...,
output_types: None = ...,
*,
stop: list[str] | None = ...,
) -> AssistantMessage[str]: ...

@overload
@abstractmethod
async def acomplete(
self,
messages: Iterable[Message[Any]],
functions: Any = ...,
output_types: Iterable[type[R]] = ...,
*,
stop: list[str] | None = ...,
) -> AssistantMessage[R]: ...

@abstractmethod
async def acomplete(
self,
messages: Iterable[Message[Any]],
functions: Iterable[Callable[..., Any]] | None = None,
output_types: Iterable[type[R | str]] | None = None,
output_types: Iterable[type[OutputT]] | None = None,
*,
stop: list[str] | None = None,
) -> AssistantMessage[str] | AssistantMessage[R]:
) -> AssistantMessage[OutputT]:
"""Async version of `complete`."""
...

Expand Down
Loading

0 comments on commit ee082fc

Please sign in to comment.