Skip to content

Add built-in support for tool control parameters #347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions docs/_core_features/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,27 @@ puts response.content
# => "Current weather at 52.52, 13.4: Temperature: 12.5°C, Wind Speed: 8.3 km/h, Conditions: Mainly clear, partly cloudy, and overcast."
```

### Tool Choice Control

Control when and how tools are called using `choice` and `parallel` options:

```ruby
chat = RubyLLM.chat(model: 'gpt-4o')

# Choice options
chat.with_tool(Weather, choice: :auto) # Model decides whether to call any provided tools or not (default)
chat.with_tool(Weather, choice: :any) # Model must use one of the provided tools
chat.with_tool(Weather, choice: :none) # No tools
chat.with_tool(Weather, choice: :weather) # Force specific tool

# Parallel tool calls
chat.with_tools(Weather, Calculator, parallel: true) # Model can output multiple tool calls at once (default)
chat.with_tools(Weather, Calculator, parallel: false) # At most one tool call
```

> With `:any` or specific tool choices, tool results are not automatically sent back to the AI model (see The Tool Execution Flow section below) to prevent infinite loops.
{: .note }

### Model Compatibility
{: .d-inline-block }

Expand Down
44 changes: 37 additions & 7 deletions lib/ruby_llm/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module RubyLLM
class Chat
include Enumerable

attr_reader :model, :messages, :tools, :params, :headers, :schema
attr_reader :model, :messages, :tools, :tool_choice, :parallel_tool_calls, :params, :headers, :schema

def initialize(model: nil, provider: nil, assume_model_exists: false, context: nil)
if assume_model_exists && !provider
Expand All @@ -23,6 +23,8 @@ def initialize(model: nil, provider: nil, assume_model_exists: false, context: n
model_id = model || @config.default_model
with_model(model_id, provider: provider, assume_exists: assume_model_exists)
@temperature = 0.7
@tool_choice = nil
@parallel_tool_calls = nil
@messages = []
@tools = {}
@params = {}
Expand Down Expand Up @@ -50,15 +52,19 @@ def with_instructions(instructions, replace: false)
self
end

def with_tool(tool)
tool_instance = tool.is_a?(Class) ? tool.new : tool
@tools[tool_instance.name.to_sym] = tool_instance
def with_tool(tool, choice: nil, parallel: nil)
unless tool.nil?
tool_instance = tool.is_a?(Class) ? tool.new : tool
@tools[tool_instance.name.to_sym] = tool_instance
end
update_tool_options(choice:, parallel:)
self
end

def with_tools(*tools, replace: false)
def with_tools(*tools, replace: false, choice: nil, parallel: nil)
@tools.clear if replace
tools.compact.each { |tool| with_tool tool }
update_tool_options(choice:, parallel:)
self
end

Expand Down Expand Up @@ -136,6 +142,8 @@ def complete(&) # rubocop:disable Metrics/PerceivedComplexity
params: @params,
headers: @headers,
schema: @schema,
tool_choice: @tool_choice,
parallel_tool_calls: @parallel_tool_calls,
&wrap_streaming_block(&)
)

Expand Down Expand Up @@ -189,7 +197,7 @@ def wrap_streaming_block(&block)
end
end

def handle_tool_calls(response, &)
def handle_tool_calls(response, &) # rubocop:disable Metrics/PerceivedComplexity
halt_result = nil

response.tool_calls.each_value do |tool_call|
Expand All @@ -203,7 +211,9 @@ def handle_tool_calls(response, &)
halt_result = result if result.is_a?(Tool::Halt)
end

halt_result || complete(&)
return halt_result if halt_result

should_continue_after_tools? ? complete(&) : response
end

def execute_tool(tool_call)
Expand All @@ -212,6 +222,26 @@ def execute_tool(tool_call)
tool.call(args)
end

def update_tool_options(choice:, parallel:)
unless choice.nil?
valid_tool_choices = %i[auto none any] + tools.keys
unless valid_tool_choices.include?(choice.to_sym)
raise InvalidToolChoiceError,
"Invalid tool choice: #{choice}. Valid choices are: #{valid_tool_choices.join(', ')}"
end

@tool_choice = choice.to_sym
end

@parallel_tool_calls = !!parallel unless parallel.nil?
end

def should_continue_after_tools?
# Continue conversation only with :auto tool choice to avoid infinite loops.
# With :any or specific tool choices, the model would keep calling tools repeatedly.
tool_choice.nil? || tool_choice == :auto
end

def instance_variables
super - %i[@connection @config]
end
Expand Down
1 change: 1 addition & 0 deletions lib/ruby_llm/error.rb
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def initialize(response = nil, message = nil)
# Error classes for non-HTTP errors
class ConfigurationError < StandardError; end
class InvalidRoleError < StandardError; end
class InvalidToolChoiceError < StandardError; end
class ModelNotFoundError < StandardError; end
class UnsupportedAttachmentError < StandardError; end

Expand Down
7 changes: 6 additions & 1 deletion lib/ruby_llm/provider.rb
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,18 @@ def configuration_requirements
self.class.configuration_requirements
end

def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil, &) # rubocop:disable Metrics/ParameterLists
# rubocop:disable Metrics/ParameterLists
def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil,
tool_choice: nil, parallel_tool_calls: nil, &)
normalized_temperature = maybe_normalize_temperature(temperature, model)

payload = Utils.deep_merge(
params,
render_payload(
messages,
tools: tools,
tool_choice: tool_choice,
parallel_tool_calls: parallel_tool_calls,
temperature: normalized_temperature,
model: model,
stream: block_given?,
Expand All @@ -61,6 +65,7 @@ def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, sc
sync_response @connection, payload, headers
end
end
# rubocop:enable Metrics/ParameterLists

def list_models
response = @connection.get models_url
Expand Down
15 changes: 11 additions & 4 deletions lib/ruby_llm/providers/anthropic/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@ def completion_url
'/v1/messages'
end

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
# rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
def render_payload(messages, tools:, tool_choice:, parallel_tool_calls:,
temperature:, model:, stream: false, schema: nil)
system_messages, chat_messages = separate_messages(messages)
system_content = build_system_content(system_messages)

build_base_payload(chat_messages, model, stream).tap do |payload|
add_optional_fields(payload, system_content:, tools:, temperature:)
add_optional_fields(payload, system_content:, tools:, tool_choice:, parallel_tool_calls:, temperature:)
end
end
# rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument

def separate_messages(messages)
messages.partition { |msg| msg.role == :system }
Expand All @@ -44,8 +47,12 @@ def build_base_payload(chat_messages, model, stream)
}
end

def add_optional_fields(payload, system_content:, tools:, temperature:)
payload[:tools] = tools.values.map { |t| Tools.function_for(t) } if tools.any?
def add_optional_fields(payload, system_content:, tools:, tool_choice:, parallel_tool_calls:, temperature:) # rubocop:disable Metrics/ParameterLists
if tools.any?
payload[:tools] = tools.values.map { |t| Tools.function_for(t) }
payload[:tool_choice] = build_tool_choice(tool_choice, parallel_tool_calls) unless tool_choice.nil?
end

payload[:system] = system_content unless system_content.empty?
payload[:temperature] = temperature unless temperature.nil?
end
Expand Down
9 changes: 9 additions & 0 deletions lib/ruby_llm/providers/anthropic/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@ def clean_parameters(parameters)
def required_parameters(parameters)
parameters.select { |_, param| param.required }.keys
end

def build_tool_choice(tool_choice, parallel_tool_calls)
{
type: %i[auto any none].include?(tool_choice) ? tool_choice : :tool
}.tap do |tc|
tc[:name] = tool_choice if tc[:type] == :tool
tc[:disable_parallel_tool_use] = !parallel_tool_calls unless tc[:type] == :none || parallel_tool_calls.nil?
end
end
end
end
end
Expand Down
8 changes: 6 additions & 2 deletions lib/ruby_llm/providers/bedrock/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,21 @@ def completion_url
"model/#{@model_id}/invoke"
end

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Lint/UnusedMethodArgument,Metrics/ParameterLists
# rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
def render_payload(messages, tools:, tool_choice:, parallel_tool_calls:,
temperature:, model:, stream: false, schema: nil)
# Hold model_id in instance variable for use in completion_url and stream_url
@model_id = model

system_messages, chat_messages = Anthropic::Chat.separate_messages(messages)
system_content = Anthropic::Chat.build_system_content(system_messages)

build_base_payload(chat_messages, model).tap do |payload|
Anthropic::Chat.add_optional_fields(payload, system_content:, tools:, temperature:)
Anthropic::Chat.add_optional_fields(payload, system_content:, tools:, tool_choice:,
parallel_tool_calls:, temperature:)
end
end
# rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument

def build_base_payload(chat_messages, model)
{
Expand Down
12 changes: 10 additions & 2 deletions lib/ruby_llm/providers/gemini/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ def completion_url
"models/#{@model}:generateContent"
end

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
# rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
def render_payload(messages, tools:, tool_choice:, parallel_tool_calls:,
temperature:, model:, stream: false, schema: nil)
@model = model # Store model for completion_url/stream_url
payload = {
contents: format_messages(messages),
Expand All @@ -25,9 +27,15 @@ def render_payload(messages, tools:, temperature:, model:, stream: false, schema
payload[:generationConfig][:responseSchema] = convert_schema_to_gemini(schema)
end

payload[:tools] = format_tools(tools) if tools.any?
if tools.any?
payload[:tools] = format_tools(tools)
# Gemini doesn't support controlling parallel tool calls
payload[:toolConfig] = build_tool_config(tool_choice) unless tool_choice.nil?
end

payload
end
# rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument

private

Expand Down
15 changes: 15 additions & 0 deletions lib/ruby_llm/providers/gemini/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,21 @@ def param_type_for_gemini(type)
else 'STRING'
end
end

def build_tool_config(tool_choice)
{
functionCallingConfig: {
mode: specific_tool_choice?(tool_choice) ? 'any' : tool_choice
}.tap do |config|
# Use allowedFunctionNames to simulate specific tool choice
config[:allowedFunctionNames] = [tool_choice] if specific_tool_choice?(tool_choice)
end
}
end

def specific_tool_choice?(tool_choice)
!%i[auto none any].include?(tool_choice)
end
end
end
end
Expand Down
3 changes: 2 additions & 1 deletion lib/ruby_llm/providers/mistral/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def format_role(role)
end

# rubocop:disable Metrics/ParameterLists
def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil)
def render_payload(messages, tools:, tool_choice:, parallel_tool_calls:, temperature:, model:, stream: false,
schema: nil)
payload = super
# Mistral doesn't support stream_options
payload.delete(:stream_options)
Expand Down
11 changes: 9 additions & 2 deletions lib/ruby_llm/providers/openai/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ def completion_url

module_function

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists
# rubocop:disable Metrics/ParameterLists
def render_payload(messages, tools:, tool_choice:, parallel_tool_calls:,
temperature:, model:, stream: false, schema: nil)
payload = {
model: model,
messages: format_messages(messages),
Expand All @@ -21,7 +23,11 @@ def render_payload(messages, tools:, temperature:, model:, stream: false, schema
# Only include temperature if it's not nil (some models don't accept it)
payload[:temperature] = temperature unless temperature.nil?

payload[:tools] = tools.map { |_, tool| tool_for(tool) } if tools.any?
if tools.any?
payload[:tools] = tools.map { |_, tool| tool_for(tool) }
payload[:tool_choice] = build_tool_choice(tool_choice) unless tool_choice.nil?
payload[:parallel_tool_calls] = parallel_tool_calls unless parallel_tool_calls.nil?
end

if schema
# Use strict mode from schema if specified, default to true
Expand All @@ -40,6 +46,7 @@ def render_payload(messages, tools:, temperature:, model:, stream: false, schema
payload[:stream_options] = { include_usage: true } if stream
payload
end
# rubocop:enable Metrics/ParameterLists

def parse_completion_response(response)
data = response.body
Expand Down
16 changes: 16 additions & 0 deletions lib/ruby_llm/providers/openai/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,22 @@ def parse_tool_calls(tool_calls, parse_arguments: true)
]
end
end

def build_tool_choice(tool_choice)
case tool_choice
when :auto, :none
tool_choice
when :any
:required
else
{
type: 'function',
function: {
name: tool_choice
}
}
end
end
end
end
end
Expand Down