[WIP] feat: add reasoning flag in dspy.LM #7994

Draft · wants to merge 2 commits into main
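
In brief, reading from the diff below: `dspy.LM` gains a `reasoning_model: bool = False` constructor flag, and `dspy.ChainOfThought` now builds two predictors, one with the prepended `reasoning` output field and one with the plain signature. Its `forward` routes to the plain predictor whenever the active LM reports `reasoning_model=True`, presumably because such models already perform chain-of-thought internally. Aside from the three `reasoning_model` lines, the remaining `lm.py` hunks are formatting-only line wraps.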
41 changes: 34 additions & 7 deletions dspy/clients/lm.py
@@ -42,6 +42,7 @@ def __init__(
num_retries: int = 8,
provider=None,
finetuning_model: Optional[str] = None,
+ reasoning_model: bool = False,
launch_kwargs: Optional[dict[str, Any]] = None,
train_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
@@ -65,6 +66,7 @@ def __init__(
provider: The provider to use. If not specified, the provider will be inferred from the model.
finetuning_model: The model to finetune. In some providers, the models available for finetuning are different
from the models available for inference.
+ reasoning_model: Whether the model is a reasoning model; checked by the ChainOfThought module when deciding whether to request an explicit reasoning field.
"""
# Remember to update LM.copy() if you modify the constructor!
self.model = model
@@ -77,6 +79,7 @@ def __init__(
self.callbacks = callbacks or []
self.num_retries = num_retries
self.finetuning_model = finetuning_model
+ self.reasoning_model = reasoning_model
self.launch_kwargs = launch_kwargs or {}
self.train_kwargs = train_kwargs or {}

@@ -91,7 +94,9 @@ def __init__(
assert (
max_tokens >= 5000 and temperature == 1.0
), "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
- self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs)
+ self.kwargs = dict(
+     temperature=temperature, max_completion_tokens=max_tokens, **kwargs
+ )
else:
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)

@@ -106,14 +111,22 @@ def forward(self, prompt=None, messages=None, **kwargs):

# Make the request and handle LRU & disk caching.
if cache_in_memory:
- completion = cached_litellm_completion if self.model_type == "chat" else cached_litellm_text_completion
+ completion = (
+     cached_litellm_completion
+     if self.model_type == "chat"
+     else cached_litellm_text_completion
+ )

return completion(
request=dict(model=self.model, messages=messages, **kwargs),
num_retries=self.num_retries,
)
else:
- completion = litellm_completion if self.model_type == "chat" else litellm_text_completion
+ completion = (
+     litellm_completion
+     if self.model_type == "chat"
+     else litellm_text_completion
+ )

return completion(
request=dict(model=self.model, messages=messages, **kwargs),
@@ -228,7 +241,11 @@ def transform_value(value):
return value.model_json_schema()
elif isinstance(value, pydantic.BaseModel):
return value.model_dump()
- elif callable(value) and hasattr(value, "__code__") and hasattr(value.__code__, "co_code"):
+ elif (
+     callable(value)
+     and hasattr(value, "__code__")
+     and hasattr(value.__code__, "co_code")
+ ):
return value.__code__.co_code.decode("utf-8")
else:
# Note: We don't attempt to compute a hash of the value, since the default
@@ -275,7 +292,11 @@ def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
)


- def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
+ def litellm_completion(
+     request: Dict[str, Any],
+     num_retries: int,
+     cache={"no-cache": True, "no-store": True},
+ ):
retry_kwargs = dict(
retry_policy=_get_litellm_retry_policy(num_retries),
# In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
@@ -321,7 +342,11 @@ def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
)


- def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
+ def litellm_text_completion(
+     request: Dict[str, Any],
+     num_retries: int,
+     cache={"no-cache": True, "no-store": True},
+ ):
# Extract the provider and model from the model string.
# TODO: Not all the models are in the format of "provider/model"
model = request.pop("model").split("/", 1)
@@ -332,7 +357,9 @@ def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"n
api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")

# Build the prompt from the messages.
- prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])
+ prompt = "\n\n".join(
+     [x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"]
+ )

return litellm.text_completion(
cache=cache,
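
For reviewers, a minimal usage sketch of the new flag (assumes this branch is installed; `openai/o1-mini` is only an example model name). The pre-existing assertion shown above still requires `temperature=1.0` and `max_tokens >= 5000` for OpenAI reasoning models:

```python
import dspy

# Sketch under the assumption that this PR's `reasoning_model` flag is available.
lm = dspy.LM(
    "openai/o1-mini",      # example reasoning model (assumption)
    reasoning_model=True,  # new flag: lets ChainOfThought skip the extra reasoning field
    temperature=1.0,       # required by the assertion above for OpenAI reasoning models
    max_tokens=5000,
)
dspy.settings.configure(lm=lm)
```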
12 changes: 11 additions & 1 deletion dspy/predict/chain_of_thought.py
@@ -1,4 +1,6 @@
import dspy
+ from dspy.clients.base_lm import BaseLM
+ from dspy.dsp.utils import settings
from dspy.primitives.program import Module
from dspy.signatures.signature import ensure_signature

@@ -13,8 +15,16 @@ def __init__(self, signature, rationale_type=None, **config):
desc = "${reasoning}"
rationale_type = rationale_type or dspy.OutputField(prefix=prefix, desc=desc)
extended_signature = signature.prepend("reasoning", rationale_type, type_=str)


+ self.plain_predict = dspy.Predict(signature, **config)
self.predict = dspy.Predict(extended_signature, **config)

def forward(self, **kwargs):
+     # Keep the same logic as `dspy.Predict`.
+     lm = kwargs.pop("lm", getattr(self, "lm", None)) or settings.lm
+     assert isinstance(lm, BaseLM), "No LM is loaded."
+
+     # Custom LMs that subclass `BaseLM` may not define this attribute, hence the default.
+     if getattr(lm, "reasoning_model", False):
+         return self.plain_predict(**kwargs)
return self.predict(**kwargs)
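
To make the dispatch concrete, a short sketch (again assuming this branch; `plain_predict` only exists with this patch applied). With an LM whose `reasoning_model` attribute is truthy, `forward` routes to the plain predictor, so no `reasoning` output field is requested:

```python
import dspy

cot = dspy.ChainOfThought("question -> answer")

# Both predictors are built in __init__; forward() picks one per call.
print(list(cot.predict.signature.output_fields.keys()))        # ['reasoning', 'answer']
print(list(cot.plain_predict.signature.output_fields.keys()))  # ['answer']

# With lm.reasoning_model == True on the configured LM, cot(question="...")
# goes through plain_predict and the prediction carries only the `answer` field.
```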
12 changes: 12 additions & 0 deletions tests/predict/test_chain_of_thought.py
@@ -12,3 +12,15 @@ def test_initialization_with_string_signature():
"answer",
]
assert predict(question="What is 1+1?").answer == "2"


+ def test_cot_skips_with_reasoning_model():
+     lm = DummyLM([{"answer": "2"}])
+     lm.reasoning_model = True
+     dspy.settings.configure(lm=lm)
+     signature = dspy.Signature("question -> answer")
+     predict = ChainOfThought(signature)
+     assert list(predict.plain_predict.signature.output_fields.keys()) == [
+         "answer",
+     ]
+     assert predict(question="What is 1+1?").answer == "2"
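
The new test can be run in isolation with `pytest tests/predict/test_chain_of_thought.py -k reasoning_model` (the `-k` filter matches the test name). Since `DummyLM` sits in dspy's LM hierarchy and therefore derives from `BaseLM`, the `isinstance` check in the new `forward` passes, and setting `reasoning_model = True` directly on the instance stands in for the new constructor flag.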