Skip to content

Commit

Permalink
Merge pull request #970 from WolframResearch/feature/partial-service-…
Browse files Browse the repository at this point in the history
…tool-calling

Hybrid text/service tool calling method
  • Loading branch information
rhennigan authored Dec 4, 2024
2 parents 096c9f5 + c2d3d77 commit 34ff06f
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 31 deletions.
5 changes: 3 additions & 2 deletions Source/Chatbook/ChatMessages.wl
Original file line number Diff line number Diff line change
Expand Up @@ -796,8 +796,9 @@ getToolPrompt // endDefinition;
(*toolPromptData*)
toolPromptData // beginDefinition;

toolPromptData[ args: KeyValuePattern @ { "Tools" -> tools_List } ] :=
Append[ args, "Tools" -> Replace[ tools, t_LLMTool :> TemplateVerbatim @ t, { 1 } ] ];
(* What was this definition doing? It was causing template errors, so it's now disabled: *)
(* toolPromptData[ args: KeyValuePattern @ { "Tools" -> tools_List } ] :=
Append[ args, "Tools" -> Replace[ tools, t_LLMTool :> TemplateVerbatim @ t, { 1 } ] ]; *)

toolPromptData[ expr_ ] := expr;

Expand Down
1 change: 1 addition & 0 deletions Source/Chatbook/ChatState.wl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ withChatState[ eval_ ] :=
$openToolCallBoxes = Automatic,
$progressContainer = None,
$showProgressText = $showProgressText,
$receivedToolCall = False,

(* Values used for token budgets during cell serialization: *)
$cellStringBudget = $cellStringBudget,
Expand Down
3 changes: 3 additions & 0 deletions Source/Chatbook/CommonSymbols.wl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ BeginPackage[ "Wolfram`Chatbook`Common`" ];
`$progressContainer;
`$progressText;
`$progressWidth;
`$receivedToolCall;
`$resultCellCache;
`$rightSelectionIndicator;
`$sandboxKernelCommandLine;
Expand All @@ -77,6 +78,7 @@ BeginPackage[ "Wolfram`Chatbook`Common`" ];
`$showProgressBar;
`$showProgressCancelButton;
`$showProgressText;
`$simpleToolMethod;
`$statelessProgressIndicator;
`$suppressButtonAppearance;
`$timingLog;
Expand Down Expand Up @@ -301,6 +303,7 @@ BeginPackage[ "Wolfram`Chatbook`Common`" ];
`toolRequestParser;
`toolSelectedQ;
`toolsEnabledQ;
`toolShortName;
`topParentCell;
`toSmallSettings;
`trackedDynamic;
Expand Down
99 changes: 91 additions & 8 deletions Source/Chatbook/LLMUtilities.wl
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,11 @@ extractBodyChunks // beginDefinition;

extractBodyChunks[ data_ ] := Enclose[
Catch[
ConfirmMatch[ DeleteCases[ Flatten @ { extractBodyChunks0 @ data }, "" ], { ___String }, "Result" ],
ConfirmMatch[
DeleteCases[ Flatten @ { extractBodyChunks0 @ data }, "" ],
{ (_String|_LLMToolRequest)... },
"Result"
],
$bodyChunksTag
],
throwInternalFailure
Expand All @@ -215,13 +219,31 @@ extractBodyChunks // endDefinition;


extractBodyChunks0 // beginDefinition;
extractBodyChunks0[ content_String ] := content;
extractBodyChunks0[ content_List ] := extractBodyChunks0 /@ content;
extractBodyChunks0[ KeyValuePattern[ "BodyChunkProcessed" -> content_ ] ] := extractBodyChunks0 @ content;
extractBodyChunks0[ KeyValuePattern[ "ContentChunk"|"ContentDelta" -> content_ ] ] := extractBodyChunks0 @ content;
extractBodyChunks0[ KeyValuePattern @ { "Type" -> "Text", "Data" -> content_ } ] := extractBodyChunks0 @ content;
extractBodyChunks0[ KeyValuePattern @ { } ] := { };
extractBodyChunks0[ bag_Internal`Bag ] := extractBodyChunks0 @ Internal`BagPart[ bag, All ];

extractBodyChunks0[ content_String ] :=
content;

extractBodyChunks0[ content_List ] :=
extractBodyChunks0 /@ content;

extractBodyChunks0[ as: KeyValuePattern[ "ToolRequestsChunk" -> t_ ] ] :=
{ extractBodyChunks0 @ KeyDrop[ as, "ToolRequestsChunk" ], toolRequestsToStrings @ t };

extractBodyChunks0[ KeyValuePattern[ "BodyChunkProcessed" -> content_ ] ] :=
extractBodyChunks0 @ content;

extractBodyChunks0[ KeyValuePattern[ "ContentChunk"|"ContentDelta" -> content_ ] ] :=
extractBodyChunks0 @ content;

extractBodyChunks0[ KeyValuePattern @ { "Type" -> "Text", "Data" -> content_ } ] :=
extractBodyChunks0 @ content;

extractBodyChunks0[ KeyValuePattern @ { } ] :=
{ };

extractBodyChunks0[ bag_Internal`Bag ] :=
extractBodyChunks0 @ Internal`BagPart[ bag, All ];

extractBodyChunks0[ Null ] := { };

extractBodyChunks0[ Failure[
Expand All @@ -234,6 +256,67 @@ extractBodyChunks0[ fail_Failure? apiFailureQ ] :=

extractBodyChunks0 // endDefinition;

(* ::**************************************************************************************************************:: *)
(* ::Subsubsection::Closed:: *)
(*toolRequestsToStrings*)
toolRequestsToStrings // beginDefinition;
toolRequestsToStrings[ requests_List ] := toolRequestsToStrings /@ requests;
toolRequestsToStrings[ req: HoldPattern[ _LLMToolRequest ] ] := toolRequestToString @ req;
toolRequestsToStrings // endDefinition;

(* ::**************************************************************************************************************:: *)
(* ::Subsubsection::Closed:: *)
(*toolRequestToString*)
toolRequestToString // beginDefinition;

(* TODO: we currently only support one tool call at a time, so extras are discarded *)
toolRequestString[ req_ ] /; $receivedToolCall := "";

toolRequestToString[ req: HoldPattern[ _LLMToolRequest ] ] /; $simpleToolMethod := Enclose[
Module[ { name, tool, command, params, argString },
name = ConfirmBy[ req[ "Name" ], StringQ, "Name" ];
tool = ConfirmMatch[ getToolByName @ name, _LLMTool, "Tool" ];
command = ConfirmBy[ toolShortName @ tool, StringQ, "Command" ];
params = ToString /@ ConfirmBy[ Association @ req[ "ParameterValues" ], AssociationQ, "ParameterValues" ];
argString = ConfirmBy[ simpleParameterString @ params, StringQ, "ArgumentString" ];
$receivedToolCall = True;
StringJoin[ "\n/", command, "\n", argString ]
],
throwInternalFailure
];

toolRequestToString[ req: HoldPattern[ _LLMToolRequest ] ] := Enclose[
Module[ { name, params, json },
name = ConfirmBy[ req[ "Name" ], StringQ, "Name" ];
params = ConfirmMatch[ req[ "ParameterValues" ], KeyValuePattern @ { }, "ParameterValues" ];
json = ConfirmBy[ Developer`ToJSON @ params, StringQ, "JSON" ];
$receivedToolCall = True;
StringJoin[ "\nTOOLCALL: ", name, "\n", json, "\n", "ENDARGUMENTS" ]
],
throwInternalFailure
];

toolRequestToString // endDefinition;

(* ::**************************************************************************************************************:: *)
(* ::Subsubsection::Closed:: *)
(*simpleParameterString*)
simpleParameterString // beginDefinition;

simpleParameterString[ params_Association ] /; Length @ params === 1 :=
ToString @ First @ params;

simpleParameterString[ params_Association ] /; AllTrue[ params, StringFreeQ[ "\n" ] ] :=
StringRiffle[ Values @ params, "\n" ];

simpleParameterString[ params_Association ] :=
StringRiffle[ KeyValueMap[ simpleParameterString, params ], "\n" ];

simpleParameterString[ name_String, value_ ] :=
StringJoin[ name, ": ", ToString @ value ];

simpleParameterString // endDefinition;

(* ::**************************************************************************************************************:: *)
(* ::Subsubsection::Closed:: *)
(*apiFailureQ*)
Expand Down
2 changes: 1 addition & 1 deletion Source/Chatbook/Models.wl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Needs[ "Wolfram`Chatbook`UI`" ];
(* ::Section::Closed:: *)
(*Configuration*)
$defaultLLMKitService := Replace[ $llmKitService, Except[ _String ] :> "AzureOpenAI" ];
$defaultLLMKitModelName = "gpt-4o-2024-05-13";
$defaultLLMKitModelName = "gpt-4o-2024-08-06";

$$modelVersion = DigitCharacter.. ~~ (("." ~~ DigitCharacter...) | "");

Expand Down
2 changes: 2 additions & 0 deletions Source/Chatbook/Sandbox.wl
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,8 @@ preprocessSandboxString[ s_String ] := sandboxStringNormalize[ s ] = StringRepla
{
"\[FreeformPrompt][" ~~ query: Except[ "\"" ].. ~~ "]" /; StringFreeQ[ query, "[" | "]" ] :>
"\[FreeformPrompt][\"" <> query <> "\"]",
"\[FreeformPrompt]\"" ~~ query: Except[ "\"" ].. ~~ "\"" /; StringFreeQ[ query, "[" | "]" ] :>
"\[FreeformPrompt][\"" <> query <> "\"]",
("Import"|"Get") ~~ "[\"<!" ~~ uri: Except[ "!" ].. ~~ "!>\"]" :>
"InlinedExpression[\"" <> uri <> "\"]",
("Import"|"Get") ~~ "[\"!["~~___~~"](" ~~ uri: (__ ~~ "://" ~~ key__) ~~ ")\"]" /; expressionURIKeyQ @ key :>
Expand Down
8 changes: 8 additions & 0 deletions Source/Chatbook/SendChat.wl
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,14 @@ makeLLMConfiguration // beginDefinition;
makeLLMConfiguration[ as: KeyValuePattern[ "Model" -> model_String ] ] :=
makeLLMConfiguration @ Append[ as, "Model" -> { "OpenAI", model } ];

makeLLMConfiguration[ as_Association ] /; as[ "HybridToolMethod" ] :=
$lastLLMConfiguration = LLMConfiguration @ DeleteMissing @ Association[
KeyTake[ as, { "Model", "MaxTokens", "Temperature", "PresencePenalty" } ],
"Tools" -> Cases[ Flatten @ { as[ "Tools" ] }, _LLMTool ],
"StopTokens" -> makeStopTokens @ as,
"ToolMethod" -> "Service"
];

makeLLMConfiguration[ as_Association ] :=
$lastLLMConfiguration = LLMConfiguration @ DeleteMissing @ Association[
KeyTake[ as, { "Model", "MaxTokens", "Temperature", "PresencePenalty" } ],
Expand Down
44 changes: 24 additions & 20 deletions Source/Chatbook/Settings.wl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ $defaultChatSettings = <|
"FrequencyPenalty" -> 0.1,
"HandlerFunctions" :> $DefaultChatHandlerFunctions,
"HandlerFunctionsKeys" -> Automatic,
"HybridToolMethod" -> Automatic,
"IncludeHistory" -> Automatic,
"InitialChatCell" -> True,
"LLMEvaluator" -> "CodeAssistant",
Expand Down Expand Up @@ -109,6 +110,8 @@ $absoluteCurrentSettingsCache = None;
(*Argument Patterns*)
$$validRootSettingValue = Inherited | _? (AssociationQ@*Association);
$$frontEndObject = HoldPattern[ $FrontEnd | _FrontEndObject ];
$$hybridToolService = "OpenAI"|"AzureOpenAI"|"LLMKit";
$$hybridToolModel = _String | { $$hybridToolService, _ } | KeyValuePattern[ "Service" -> $$hybridToolService ];

(* ::**************************************************************************************************************:: *)
(* ::Section::Closed:: *)
Expand Down Expand Up @@ -316,16 +319,18 @@ setLLMKitFlags // endDefinition;
(* ::Subsubsection::Closed:: *)
(*overrideSettings*)
overrideSettings // beginDefinition;
overrideSettings[ settings_Association? llmKitQ ] := <| settings, $llmKitOverrides |>;
overrideSettings[ settings_Association? o1ModelQ ] := <| settings, $o1Overrides |>;
overrideSettings[ settings_Association? gpt4oTextToolsQ ] := <| settings, $gpt4oTextToolOverrides |>;
overrideSettings[ settings_Association ] := settings;

overrideSettings[ settings_Association ] := <|
settings,
If[ llmKitQ @ settings, $llmKitOverrides, <| |> ],
If[ o1ModelQ @ settings, $o1Overrides, <| |> ]
|>;

overrideSettings // endDefinition;

(* TODO: these shouldn't be mutually exclusive: *)
$llmKitOverrides = <| "Authentication" -> "LLMKit" |>;
$o1Overrides = <| "PresencePenalty" -> 0, "Temperature" -> 1 |>;
$gpt4oTextToolOverrides = <| "Model" -> <| "Service" -> "OpenAI", "Name" -> "gpt-4o-2024-05-13" |> |>;
$llmKitOverrides = <| "Authentication" -> "LLMKit" |>;
$o1Overrides = <| "PresencePenalty" -> 0, "Temperature" -> 1 |>;

(* ::**************************************************************************************************************:: *)
(* ::Subsubsubsection::Closed:: *)
Expand All @@ -340,19 +345,6 @@ llmKitQ[ as_Association ] := TrueQ @ Or[

llmKitQ // endDefinition;

(* ::**************************************************************************************************************:: *)
(* ::Subsubsubsection::Closed:: *)
(*gpt4oTextToolsQ*)
gpt4oTextToolsQ // beginDefinition;

gpt4oTextToolsQ[ settings_Association ] := TrueQ @ And[
settings[ "ToolsEnabled" ],
toModelName @ settings === "gpt-4o",
settings[ "ToolMethod" ] =!= "Service"
];

gpt4oTextToolsQ // endDefinition;

(* ::**************************************************************************************************************:: *)
(* ::Subsubsection::Closed:: *)
(*evaluateSettings*)
Expand Down Expand Up @@ -380,6 +372,7 @@ resolveAutoSetting0[ as_, "DynamicAutoFormat" ] := dynamicAutoForma
resolveAutoSetting0[ as_, "EnableLLMServices" ] := $useLLMServices;
resolveAutoSetting0[ as_, "ForceSynchronous" ] := forceSynchronousQ @ as;
resolveAutoSetting0[ as_, "HandlerFunctionsKeys" ] := chatHandlerFunctionsKeys @ as;
resolveAutoSetting0[ as_, "HybridToolMethod" ] := hybridToolMethodQ @ as;
resolveAutoSetting0[ as_, "IncludeHistory" ] := Automatic;
resolveAutoSetting0[ as_, "MaxCellStringLength" ] := chooseMaxCellStringLength @ as;
resolveAutoSetting0[ as_, "MaxContextTokens" ] := autoMaxContextTokens @ as;
Expand Down Expand Up @@ -412,6 +405,7 @@ $autoSettingKeyDependencies = <|
"BypassResponseChecking" -> "ForceSynchronous",
"ForceSynchronous" -> "Model",
"HandlerFunctionsKeys" -> "EnableLLMServices",
"HybridToolMethod" -> { "Model", "ToolsEnabled" },
"MaxCellStringLength" -> { "Model", "MaxContextTokens" },
"MaxContextTokens" -> "Model",
"MaxOutputCellStringLength" -> "MaxCellStringLength",
Expand Down Expand Up @@ -444,6 +438,16 @@ $autoSettingKeyPriority := Enclose[
* ChatContextPreprompt
*)

(* ::**************************************************************************************************************:: *)
(* ::Subsubsection::Closed:: *)
(*hybridToolMethodQ*)
hybridToolMethodQ // beginDefinition;
hybridToolMethodQ[ KeyValuePattern[ "ToolsEnabled" -> False ] ] := False;
hybridToolMethodQ[ as_Association ] := hybridToolMethodQ[ as, as[ "Model" ] ];
hybridToolMethodQ[ as_, $$hybridToolModel ] := True;
hybridToolMethodQ[ as_, _ ] := False;
hybridToolMethodQ // endDefinition;

(* ::**************************************************************************************************************:: *)
(* ::Subsubsection::Closed:: *)
(*toolCallRetryMessageQ*)
Expand Down

0 comments on commit 34ff06f

Please sign in to comment.