Skip to content

Commit bb8d463

Browse files
authored
Merge pull request #1018 from WolframResearch/feature/register-documentation-source
Added `RegisterVectorDatabase`
2 parents a5e4f75 + 86a86a2 commit bb8d463

File tree

6 files changed

+400
-59
lines changed

6 files changed

+400
-59
lines changed

Source/Chatbook/Main.wl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ BeginPackage[ "Wolfram`Chatbook`" ];
7676
`LogChatTiming;
7777
`MakeExpressionURI;
7878
`RebuildChatSearchIndex;
79+
`RegisterVectorDatabase;
7980
`RelatedDocumentation;
8081
`RelatedWolframAlphaQueries;
8182
`RemoveChatFromSearchIndex;
@@ -245,6 +246,7 @@ $ChatbookProtectedNames = "Wolfram`Chatbook`" <> # & /@ {
245246
"LogChatTiming",
246247
"MakeExpressionURI",
247248
"RebuildChatSearchIndex",
249+
"RegisterVectorDatabase",
248250
"RelatedDocumentation",
249251
"RelatedWolframAlphaQueries",
250252
"RemoveChatFromSearchIndex",

Source/Chatbook/PromptGenerators/Common.wl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ BeginPackage[ "Wolfram`Chatbook`PromptGenerators`Common`" ];
44

55
HoldComplete[
66
`$$prompt,
7+
`$defaultSources,
78
`getSmallContextString,
9+
`getSnippets,
810
`insertContextPrompt,
911
`vectorDBSearch
1012
];

Source/Chatbook/PromptGenerators/RelatedDocumentation.wl

Lines changed: 141 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@ $unfilteredItemsPerSource = 10;
3939
(* ::**************************************************************************************************************:: *)
4040
(* ::Section::Closed:: *)
4141
(*Messages*)
42-
Chatbook::CloudDownloadError = "Unable to download required data from the cloud. Please try again later.";
43-
Chatbook::InvalidSources = "Invalid value for the \"Sources\" option: `1`.";
42+
Chatbook::CloudDownloadError = "Unable to download required data from the cloud. Please try again later.";
43+
Chatbook::InvalidSources = "Invalid value for the \"Sources\" option: `1`.";
44+
(* Issued when a snippet function returns something other than a list of strings; `1` is the function, `2` is its output. *)
Chatbook::SnippetFunctionOutputFailure = "The snippet function `1` returned an invalid result: `2`. Expected a list of strings.";
45+
Chatbook::SnippetFunctionLengthFailure = "The snippet function `1` returned a list of length `2` for `3` values.";
4446

4547
(* ::**************************************************************************************************************:: *)
4648
(* ::Section::Closed:: *)
4749
(*$RelatedDocumentationSources*)
48-
$RelatedDocumentationSources = $defaultSources;
50+
$RelatedDocumentationSources := $defaultSources;
4951

5052
(* ::**************************************************************************************************************:: *)
5153
(* ::Section::Closed:: *)
@@ -94,20 +96,28 @@ RelatedDocumentation[ prompt_, Automatic, count_, opts: OptionsPattern[ ] ] :=
9496

9597
RelatedDocumentation[ prompt: $$prompt, "URIs", Automatic, opts: OptionsPattern[ ] ] := catchMine @ Enclose[
9698
(* TODO: filter results *)
97-
ConfirmMatch[ vectorDBSearch[ getSources @ OptionValue[ "Sources" ], prompt, "Values" ], { ___String }, "Queries" ],
99+
URL /@ ConfirmMatch[
100+
vectorDBSearch[ getSources @ OptionValue[ "Sources" ], prompt, "Values" ],
101+
{ ___String },
102+
"Values"
103+
],
98104
throwInternalFailure
99105
];
100106

101107
RelatedDocumentation[ All, "URIs", Automatic, opts: OptionsPattern[ ] ] := catchMine @ Enclose[
102108
(* TODO: filter results *)
103-
Union @ ConfirmMatch[ vectorDBSearch[ getSources @ OptionValue[ "Sources" ], All ], { __String }, "QueryList" ],
109+
URL /@ Union @ ConfirmMatch[
110+
vectorDBSearch[ getSources @ OptionValue[ "Sources" ], All ],
111+
{ ___String },
112+
"Values"
113+
],
104114
throwInternalFailure
105115
];
106116

107117
RelatedDocumentation[ prompt: $$prompt, "Snippets", Automatic, opts: OptionsPattern[ ] ] := catchMine @ Enclose[
108118
ConfirmMatch[
109119
(* TODO: filter results *)
110-
DeleteMissing[ makeDocSnippets @ vectorDBSearch[ getSources @ OptionValue[ "Sources" ], prompt, "Values" ] ],
120+
DeleteMissing[ makeDocSnippets @ vectorDBSearch[ getSources @ OptionValue[ "Sources" ], prompt, "Results" ] ],
111121
{ ___String },
112122
"Snippets"
113123
],
@@ -118,7 +128,7 @@ RelatedDocumentation[ prompt_, property_, UpTo[ n_Integer ], opts: OptionsPatter
118128
catchMine @ RelatedDocumentation[ prompt, property, n, opts ];
119129

120130
RelatedDocumentation[ prompt_, property_, n_Integer, opts: OptionsPattern[ ] ] := catchMine @ Enclose[
121-
Take[ ConfirmMatch[ RelatedDocumentation[ prompt, property, Automatic, opts ], { ___String } ], UpTo @ n ],
131+
Take[ ConfirmBy[ RelatedDocumentation[ prompt, property, Automatic, opts ], ListQ ], UpTo @ n ],
122132
throwInternalFailure
123133
];
124134

@@ -163,7 +173,8 @@ RelatedDocumentation[ prompt_, "Prompt", n_Integer, opts: OptionsPattern[ ] ] :=
163173
$rerankMethod = Replace[
164174
OptionValue[ "RerankMethod" ],
165175
$$unspecified :> CurrentChatSettings[ "DocumentationRerankMethod" ]
166-
]
176+
],
177+
$RelatedDocumentationSources = getSources @ OptionValue[ "Sources" ]
167178
},
168179
relatedDocumentationPrompt[
169180
ensureChatMessages @ prompt,
@@ -223,23 +234,25 @@ ensureChatMessages // endDefinition;
223234
relatedDocumentationPrompt // beginDefinition;
224235

225236
relatedDocumentationPrompt[ messages: $$chatMessages, count_, filter_, filterCount_ ] := Enclose[
226-
Catch @ Module[ { uris, filtered, string },
237+
Catch @ Module[ { results, filtered, string },
227238

228-
uris = ConfirmMatch[
229-
RelatedDocumentation[ messages, "URIs", count ],
230-
{ ___String },
231-
"URIs"
232-
] // LogChatTiming[ "RelatedDocumentationURIs" ] // withApproximateProgress[ "CheckingDocumentation", 0.2 ];
239+
results = ConfirmMatch[
240+
RelatedDocumentation[ messages, "Results", count ],
241+
{ ___Association },
242+
"Results"
243+
] // LogChatTiming[ "RelatedDocumentationResults" ] // withApproximateProgress[ "CheckingDocumentation", 0.2 ];
244+
245+
If[ results === { }, Throw[ "" ] ];
233246

234-
If[ uris === { }, Throw[ "" ] ];
247+
results = DeleteDuplicatesBy[ results, Lookup[ "Value" ] ];
235248

236249
filtered = ConfirmMatch[
237-
filterSnippets[ messages, uris, filter, filterCount ] // LogChatTiming[ "FilterSnippets" ],
250+
filterSnippets[ messages, results, filter, filterCount ] // LogChatTiming[ "FilterSnippets" ],
238251
{ ___String },
239252
"Filtered"
240253
];
241254

242-
string = StringTrim @ StringRiffle[ "# "<># & /@ DeleteCases[ filtered, "" ], "\n\n======\n\n" ];
255+
string = StringTrim @ StringRiffle[ DeleteCases[ filtered, "" ], "\n\n======\n\n" ];
243256

244257
If[ string === "",
245258
"",
@@ -272,20 +285,20 @@ $relatedDocsStringUnfilteredHeader =
272285
filterSnippets // beginDefinition;
273286

274287

275-
filterSnippets[ messages_, uris: { __String }, Except[ True ], filterCount_ ] := Enclose[
276-
ConfirmMatch[ makeDocSnippets @ uris, { ___String }, "Snippets" ],
288+
filterSnippets[ messages_, results_List, Except[ True ], filterCount_ ] := Enclose[
289+
ConfirmMatch[ makeDocSnippets @ results, { ___String }, "Snippets" ],
277290
throwInternalFailure
278291
];
279292

280293

281294
filterSnippets[
282295
messages_,
283-
uris: { __String },
296+
results_List,
284297
True,
285298
filterCount_Integer? Positive
286299
] /; $rerankMethod === None := Enclose[
287300
Catch @ Module[ { snippets },
288-
snippets = ConfirmMatch[ makeDocSnippets @ uris, { ___String }, "Snippets" ];
301+
snippets = ConfirmMatch[ makeDocSnippets @ results, { ___String }, "Snippets" ];
289302
Take[ snippets, UpTo[ filterCount ] ]
290303
],
291304
throwInternalFailure
@@ -294,13 +307,13 @@ filterSnippets[
294307

295308
filterSnippets[
296309
messages_,
297-
uris: { __String },
310+
results_List,
298311
True,
299312
filterCount_Integer? Positive
300313
] /; $rerankMethod === "rerank-english-v3.0" (* EXPERIMENTAL *) := Enclose[
301-
Catch @ Module[ { snippets, inserted, transcript, instructions, resp, results, idx, ranked },
314+
Catch @ Module[ { snippets, inserted, transcript, instructions, resp, respResults, idx, ranked },
302315

303-
snippets = ConfirmMatch[ makeDocSnippets @ uris, { ___String }, "Snippets" ];
316+
snippets = ConfirmMatch[ makeDocSnippets @ results, { ___String }, "Snippets" ];
304317
setProgressDisplay[ "ProgressTextChoosingDocumentation" ];
305318
inserted = insertContextPrompt @ messages;
306319
transcript = ConfirmBy[ getSmallContextString @ inserted, StringQ, "Transcript" ];
@@ -323,26 +336,33 @@ filterSnippets[
323336

324337
If[ FailureQ @ resp, throwTop @ resp ];
325338

326-
results = ConfirmMatch[ resp[ "results" ], { __Association }, "Results" ];
339+
respResults = ConfirmMatch[ resp[ "results" ], { __Association }, "Results" ];
327340

328341
idx = ConfirmMatch[
329-
Select[ results, #[ "relevance_score" ] > 0.01 & ][[ All, "index" ]] + 1,
342+
Select[ respResults, #[ "relevance_score" ] > 0.01 & ][[ All, "index" ]] + 1,
330343
{ ___Integer },
331344
"Indices"
332345
];
333346

334347
ranked = ConfirmMatch[ snippets[[ idx ]], { ___String }, "Ranked" ];
335348

349+
(* FIXME: need to add handler data here *)
350+
336351
Take[ ranked, UpTo[ filterCount ] ]
337352
],
338353
throwInternalFailure
339354
];
340355

341356

342-
filterSnippets[ messages_, uris: { __String }, True, filterCount_Integer? Positive ] := Enclose[
343-
Catch @ Module[ { snippets, inserted, transcript, xml, instructions, response, pages },
357+
filterSnippets[ messages_, results0_List, True, filterCount_Integer? Positive ] := Enclose[
358+
Catch @ Module[
359+
{
360+
results, snippets, inserted, transcript, xml,
361+
instructions, response, uriToSnippet, uris, selected, pages
362+
},
344363

345-
snippets = ConfirmMatch[ makeDocSnippets @ uris, { ___String }, "Snippets" ];
364+
results = ConfirmMatch[ addDocSnippets @ results0, { ___Association }, "Results" ];
365+
snippets = ConfirmMatch[ Lookup[ results, "Snippet" ], { ___String }, "Snippets" ];
346366
setProgressDisplay[ "ChoosingDocumentation" ];
347367
inserted = insertContextPrompt @ messages;
348368
transcript = ConfirmBy[ getSmallContextString @ inserted, StringQ, "Transcript" ];
@@ -370,7 +390,15 @@ filterSnippets[ messages_, uris: { __String }, True, filterCount_Integer? Positi
370390
"Response"
371391
];
372392

373-
pages = ConfirmMatch[ makeDocSnippets @ selectSnippetsFromJSON[ response, uris ], { ___String }, "Pages" ];
393+
$lastFilterInstructions = instructions;
394+
$lastFilterResponse = response;
395+
396+
uriToSnippet = <| #Value -> #Snippet & /@ results |>;
397+
uris = ConfirmMatch[ Keys @ uriToSnippet, { ___String }, "URIs" ];
398+
selected = ConfirmMatch[ selectSnippetsFromJSON[ response, uris ], { ___String }, "Pages" ];
399+
pages = ConfirmMatch[ Lookup[ uriToSnippet, selected ], { ___String }, "Pages" ];
400+
401+
addHandlerArguments[ "RelatedDocumentation" -> <| "Results" -> uris, "Filtered" -> selected |> ];
374402

375403
pages
376404
],
@@ -387,6 +415,7 @@ Your task is to read a chat transcript between a user and assistant, and then se
387415
documentation snippets that could help the assistant answer the user's latest message.
388416
389417
Each snippet is uniquely identified by a URI (always starts with 'paclet:' or 'https://*.wolframcloud.com').
418+
You must also include the fragment appearing after the '#' in the URI.
390419
391420
Choose up to %%FilteredCount%% documentation snippets that would help answer the user's MOST RECENT message.
392421
@@ -486,26 +515,100 @@ snippetXML // endDefinition;
486515
(*Documentation Snippets*)
487516
$documentationSnippets = <| |>;
488517

518+
(* ::**************************************************************************************************************:: *)
519+
(* ::Subsection::Closed:: *)
520+
(*addDocSnippets*)
521+
addDocSnippets // beginDefinition;
522+
523+
addDocSnippets[ results: { ___Association } ] := Enclose[
524+
Module[ { withOrdering, grouped, withSnippets, sorted },
525+
526+
withOrdering = MapIndexed[ <| "Position" -> First[ #2 ], #1 |> &, results ];
527+
grouped = GroupBy[ withOrdering, Lookup[ "SnippetFunction" ] ];
528+
529+
withSnippets = ConfirmMatch[
530+
Flatten @ KeyValueMap[ applySnippetFunction, grouped ],
531+
{ ___Association },
532+
"WithSnippets"
533+
];
534+
535+
sorted = ConfirmMatch[ SortBy[ withSnippets, Lookup[ "Position" ] ], { ___Association }, "Sorted" ];
536+
537+
ConfirmAssert[ Length @ sorted === Length @ results, "LengthCheck" ];
538+
539+
sorted
540+
],
541+
throwInternalFailure
542+
];
543+
544+
addDocSnippets // endDefinition;
545+
489546
(* ::**************************************************************************************************************:: *)
490547
(* ::Subsection::Closed:: *)
491548
(*makeDocSnippets*)
492549
makeDocSnippets // beginDefinition;
493550

494-
makeDocSnippets[ uris0: { ___String } ] := Enclose[
495-
Module[ { uris, data, snippets, strings },
496-
uris = DeleteDuplicates @ uris0;
551+
makeDocSnippets[ results: { ___Association } ] := Enclose[
552+
Module[ { sorted, snippets },
553+
sorted = ConfirmMatch[ addDocSnippets @ results, { ___Association }, "Sorted" ];
554+
snippets = ConfirmMatch[ Lookup[ sorted, "Snippet" ], { ___String }, "Snippets" ];
555+
ConfirmAssert[ Length @ snippets === Length @ results, "LengthCheck" ];
556+
DeleteDuplicates @ snippets
557+
],
558+
throwInternalFailure
559+
];
560+
561+
makeDocSnippets // endDefinition;
562+
563+
(* ::**************************************************************************************************************:: *)
564+
(* ::Subsubsection::Closed:: *)
565+
(*applySnippetFunction*)
566+
applySnippetFunction // beginDefinition;
567+
568+
applySnippetFunction[ f_, { } ] := { };
569+
570+
applySnippetFunction[ f_, data: { ___Association } ] := Enclose[
571+
Module[ { values, snippets, snippetLen, valuesLen },
572+
573+
values = ConfirmMatch[ Lookup[ data, "Value" ], { ___String }, "Values" ];
574+
snippets = f @ values;
575+
snippetLen = Length @ snippets;
576+
valuesLen = Length @ values;
577+
578+
If[ ! MatchQ[ snippets, { ___String } ], throwFailure[ "SnippetFunctionOutputFailure", f, snippets ] ];
579+
If[ snippetLen =!= valuesLen, throwFailure[ "SnippetFunctionLengthFailure", f, snippetLen, valuesLen ] ];
580+
581+
ConfirmBy[
582+
Association /@ Transpose @ { data, Thread[ "Snippet" -> snippets ] },
583+
AllTrue @ AssociationQ,
584+
"Result"
585+
]
586+
] // LogChatTiming @ { "ApplySnippetFunction", f },
587+
throwInternalFailure
588+
];
589+
590+
applySnippetFunction // endDefinition;
591+
592+
(* ::**************************************************************************************************************:: *)
593+
(* ::Subsection::Closed:: *)
594+
(*getSnippets*)
595+
getSnippets // beginDefinition;
596+
597+
getSnippets[ uris: { ___String } ] := Enclose[
598+
Module[ { data, snippets, strings },
497599
data = ConfirmBy[ getDocumentationSnippetData @ uris, AssociationQ, "Data" ];
498-
snippets = ConfirmMatch[ Values @ data, { ___Association }, "Snippets" ];
600+
snippets = ConfirmMatch[ Lookup[ data, uris ], { ___Association }, "Snippets" ];
499601
strings = ConfirmMatch[ Lookup[ "String" ] /@ snippets, { ___String }, "Strings" ];
500-
strings
602+
ConfirmAssert[ Length @ strings === Length @ uris, "LengthCheck" ];
603+
"# " <> # & /@ strings
501604
],
502605
throwInternalFailure
503606
];
504607

505-
makeDocSnippets[ uri_String ] :=
506-
First @ makeDocSnippets @ { uri };
608+
getSnippets[ uri_String ] :=
609+
First @ getSnippets @ { uri };
507610

508-
makeDocSnippets // endDefinition;
611+
getSnippets // endDefinition;
509612

510613
(* ::**************************************************************************************************************:: *)
511614
(* ::Subsection::Closed:: *)

0 commit comments

Comments
 (0)