Skip to content

Commit 44a3679

Browse files
committed
Structured ollama
* Activate structured output for ollamaChat * Update documentation
1 parent 3b1db9d commit 44a3679

10 files changed

+225
-160
lines changed

+llms/+internal/callOllamaChatAPI.m

+9
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@
8484
parameters.stream = ~isempty(nvp.StreamFun);
8585

8686
options = struct;
87+
88+
if strcmp(nvp.ResponseFormat,"json")
89+
parameters.format = struct('type','json_object');
90+
elseif isstruct(nvp.ResponseFormat)
91+
parameters.format = llms.internal.jsonSchemaFromPrototype(nvp.ResponseFormat);
92+
elseif startsWith(string(nvp.ResponseFormat), asManyOfPattern(whitespacePattern)+"{")
93+
parameters.format = llms.internal.verbatimJSON(nvp.ResponseFormat);
94+
end
95+
8796
if ~isempty(nvp.Seed)
8897
options.seed = nvp.Seed;
8998
end

+llms/+internal/useSameFieldTypes.m

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
case "struct"
2222
prototype = prototype(1);
2323
if isscalar(data)
24-
if isequal(fieldnames(data),fieldnames(prototype))
24+
if isequal(sort(fieldnames(data)),sort(fieldnames(prototype)))
2525
for field_c = fieldnames(data).'
2626
field = field_c{1};
2727
data.(field) = alignTypes(data.(field),prototype.(field));

+llms/+utils/errorMessageCatalog.m

+1
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,5 @@
6666
catalog("llms:stream:responseStreamer:InvalidInput") = "Input does not have the expected json format, got ""{1}"".";
6767
catalog("llms:unsupportedDatatypeInPrototype") = "Invalid data type ''{1}'' in prototype. Prototype must be a struct, composed of numerical, string, logical, categorical, or struct.";
6868
catalog("llms:incorrectResponseFormat") = "Invalid response format. Response format must be ""text"", ""json"", a struct, or a string with a JSON Schema definition.";
69+
catalog("llms:OllamaStructuredOutputNeeds05") = "Structured output is not supported for Ollama version {1}. Use version 0.5.0 or newer.";
6970
end
+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
function tf = requestsStructuredOutput(format)
2+
% This function is undocumented and will change in a future release
3+
4+
% Simple function to check if requested format triggers structured output
5+
6+
% Copyright 2024 The MathWorks, Inc.
7+
tf = isstruct(format) || startsWith(format,asManyOfPattern(whitespacePattern)+"{");
8+
end
9+

doc/functions/ollamaChat.md

+21-4
Original file line numberDiff line numberDiff line change
@@ -139,23 +139,40 @@ If the server does not respond within the timeout, then the function throws an e
139139

140140
### `ResponseFormat` — Response format
141141

142-
`"text"` (default) | `"json"`
142+
`"text"` (default) | `"json"` | string scalar | structure array
143143

144144

145145
After construction, this property is read\-only.
146146

147147

148-
Format of generated output.
148+
Format of the `generatedOutput` output argument of the `generate` function. You can request unformatted output, JSON mode, or structured output.
149149

150150

151-
If you set the response format to `"text"`, then the generated output is a string.
151+
#### Unformatted Output
152152

153153

154-
If you set the response format to `"json"`, then the generated output is a string containing JSON encoded data.
154+
If you set the response format to `"text"`, then the generated output is an unformatted string.
155+
156+
157+
#### JSON Mode
158+
159+
160+
If you set the response format to `"json"`, then the generated output is a formatted string containing JSON encoded data.
155161

156162

157163
To configure the format of the generated JSON file, describe the format using natural language and provide it to the model either in the system prompt or as a user message. The prompt or message describing the format must contain the word `"json"` or `"JSON"`.
158164

165+
#### Structured Output
166+
167+
168+
This option is only supported for Ollama version 0.5.0 and later.
169+
170+
171+
To ensure that the model follows the required format, use structured output. To do this, set `ReponseFormat` to:
172+
173+
- A string scalar containing a valid JSON Schema.
174+
- A structure array containing an example that adheres to the required format, for example: `ResponseFormat=struct("Name","Rudolph","NoseColor",[255 0 0])`
175+
159176
# Other Properties
160177
### `SystemPrompt` — System prompt
161178

ollamaChat.m

+28-5
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,9 @@
173173
% value is CHAT.StopSequences.
174174
% Example: ["The end.", "And that's all she wrote."]
175175
%
176-
%
177-
% ResponseFormat - The format of response the model returns.
178-
% The default value is CHAT.ResponseFormat.
179-
% "text" (default) | "json"
176+
% ResponseFormat - The format of response the call returns.
177+
% Default value is CHAT.ResponseFormat.
178+
% "text" | "json" | struct | string with JSON Schema
180179
%
181180
% StreamFun - Function to callback when streaming the
182181
% result. The default value is CHAT.StreamFun.
@@ -193,7 +192,7 @@
193192
nvp.MinP {llms.utils.mustBeValidProbability} = this.MinP
194193
nvp.TopK (1,1) {mustBeReal,mustBePositive} = this.TopK
195194
nvp.StopSequences {llms.utils.mustBeValidStop} = this.StopSequences
196-
nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = this.ResponseFormat
195+
nvp.ResponseFormat {llms.utils.mustBeResponseFormat} = this.ResponseFormat
197196
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = this.TimeOut
198197
nvp.TailFreeSamplingZ (1,1) {mustBeReal} = this.TailFreeSamplingZ
199198
nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
@@ -234,9 +233,16 @@
234233
end
235234

236235
if isfield(response.Body.Data,"error")
236+
[versionStr, versionList] = serverVersion(nvp.Endpoint);
237+
if llms.utils.requestsStructuredOutput(nvp.ResponseFormat) && ...
238+
~versionIsAtLeast(versionList, [0,5,0])
239+
error("llms:OllamaStructuredOutputNeeds05",llms.utils.errorMessageCatalog.getMessage("llms:OllamaStructuredOutputNeeds05", versionStr));
240+
end
237241
err = response.Body.Data.error;
238242
error("llms:apiReturnedError",llms.utils.errorMessageCatalog.getMessage("llms:apiReturnedError",err));
239243
end
244+
245+
text = llms.internal.reformatOutput(text,nvp.ResponseFormat);
240246
end
241247
end
242248

@@ -310,3 +316,20 @@ function mustBeIntegerOrEmpty(value)
310316
mustBeInteger(value)
311317
end
312318
end
319+
320+
function [versionStr, versionList] = serverVersion(endpoint)
321+
URL = endpoint + "/api/version";
322+
if ~startsWith(URL,"http")
323+
URL = "http://" + URL;
324+
end
325+
versionStr = webread(URL).version;
326+
versionList = split(versionStr,'.');
327+
versionList = str2double(versionList);
328+
end
329+
330+
function tf = versionIsAtLeast(version,minVersion)
331+
tf = version(1) > minVersion(1) || ...
332+
(version(1) == minVersion(1) && (...
333+
version(2) > minVersion(2) || ...
334+
(version(2) == minVersion(2) && version(3) >= minVersion(3))));
335+
end

tests/hopenAIChat.m

+1-148
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
classdef (Abstract) hopenAIChat < matlab.unittest.TestCase
1+
classdef (Abstract) hopenAIChat < hstructuredOutput
22
% Tests for OpenAI-based chats (openAIChat, azureChat)
33

44
% Copyright 2023-2024 The MathWorks, Inc.
@@ -17,8 +17,6 @@
1717
constructor
1818
defaultModel
1919
visionModel
20-
structuredModel
21-
noStructuredOutputModel
2220
end
2321

2422
methods(Test)
@@ -195,66 +193,6 @@ function generateOverridesProperties(testCase)
195193
testCase.verifyThat(text, EndsWithSubstring("3, "));
196194
end
197195

198-
function generateWithStructuredOutput(testCase)
199-
import matlab.unittest.constraints.IsEqualTo
200-
import matlab.unittest.constraints.StartsWithSubstring
201-
res = generate(testCase.structuredModel,"Which animal produces honey?",...
202-
ResponseFormat = struct(commonName = "dog", scientificName = "Canis familiaris"));
203-
testCase.assertClass(res,"struct");
204-
testCase.verifySize(fieldnames(res),[2,1]);
205-
testCase.verifyThat(res.commonName, IsEqualTo("Honeybee") | IsEqualTo("Honey bee") | IsEqualTo("Honey Bee"));
206-
testCase.verifyThat(res.scientificName, StartsWithSubstring("Apis"));
207-
end
208-
209-
function generateListWithStructuredOutput(testCase)
210-
prototype = struct("plantName",{"appletree","pear"}, ...
211-
"fruit",{"apple","pear"}, ...
212-
"edible",[true,true], ...
213-
"ignore", missing);
214-
res = generate(testCase.structuredModel,"What is harvested in August?", ResponseFormat = prototype);
215-
testCase.verifyCompatibleStructs(res, prototype);
216-
end
217-
218-
function generateWithNestedStructs(testCase)
219-
stepsPrototype = struct("explanation",{"a","b"},"assumptions",{"a","b"});
220-
prototype = struct("steps",stepsPrototype,"final_answer","a");
221-
res = generate(testCase.structuredModel,"What is the positive root of x^2-2*x+1?", ...
222-
ResponseFormat=prototype);
223-
testCase.verifyCompatibleStructs(res,prototype);
224-
end
225-
226-
function incompleteJSONResponse(testCase)
227-
country = ["USA";"UK"];
228-
capital = ["Washington, D.C.";"London"];
229-
population = [345716792;69203012];
230-
prototype = struct("country",country,"capital",capital,"population",population);
231-
232-
testCase.verifyError(@() generate(testCase.structuredModel, ...
233-
"What are the five largest countries whose English names" + ...
234-
" start with the letter A?", ...
235-
ResponseFormat = prototype, MaxNumTokens=3), "llms:apiReturnedIncompleteJSON");
236-
end
237-
238-
function generateWithExplicitSchema(testCase)
239-
import matlab.unittest.constraints.IsSameSetAs
240-
schema = iGetSchema();
241-
242-
genUser = generate(testCase.structuredModel,"Create a sample user",ResponseFormat=schema);
243-
genAddress = generate(testCase.structuredModel,"Create a sample address",ResponseFormat=schema);
244-
245-
testCase.verifyClass(genUser,"string");
246-
genUserDecoded = jsondecode(genUser);
247-
testCase.verifyClass(genUserDecoded.item,"struct");
248-
testCase.verifyThat(fieldnames(genUserDecoded.item),...
249-
IsSameSetAs({'name','age'}));
250-
251-
testCase.verifyClass(genAddress,"string");
252-
genAddressDecoded = jsondecode(genAddress);
253-
testCase.verifyClass(genAddressDecoded.item,"struct");
254-
testCase.verifyThat(fieldnames(genAddressDecoded.item),...
255-
IsSameSetAs({'number','street','city'}));
256-
end
257-
258196
function invalidInputsGenerate(testCase, InvalidGenerateInput)
259197
f = openAIFunction("validfunction");
260198
chat = testCase.constructor(Tools=f, APIKey="this-is-not-a-real-key");
@@ -321,89 +259,4 @@ function keyNotFound(testCase)
321259
testCase.verifyError(testCase.constructor, "llms:keyMustBeSpecified");
322260
end
323261
end
324-
325-
methods
326-
function verifyCompatibleStructs(testCase,data,prototype)
327-
import matlab.unittest.constraints.IsSameSetAs
328-
testCase.assertClass(data,"struct");
329-
if ~isscalar(data)
330-
arrayfun(@(d) testCase.verifyCompatibleStructs(d,prototype), data);
331-
return
332-
end
333-
testCase.assertClass(prototype,"struct");
334-
if ~isscalar(prototype)
335-
prototype = prototype(1);
336-
end
337-
testCase.assertThat(fieldnames(data),IsSameSetAs(fieldnames(prototype)));
338-
for name = fieldnames(data).'
339-
field = name{1};
340-
testCase.verifyClass(data.(field),class(prototype.(field)));
341-
if isstruct(data.(field))
342-
testCase.verifyCompatibleStructs(data.(field),prototype.(field));
343-
end
344-
end
345-
end
346-
end
347-
end
348-
349-
function str = iGetSchema()
350-
% an example from https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
351-
str = string(join({
352-
'{'
353-
' "type": "object",'
354-
' "properties": {'
355-
' "item": {'
356-
' "anyOf": ['
357-
' {'
358-
' "type": "object",'
359-
' "description": "The user object to insert into the database",'
360-
' "properties": {'
361-
' "name": {'
362-
' "type": "string",'
363-
' "description": "The name of the user"'
364-
' },'
365-
' "age": {'
366-
' "type": "number",'
367-
' "description": "The age of the user"'
368-
' }'
369-
' },'
370-
' "additionalProperties": false,'
371-
' "required": ['
372-
' "name",'
373-
' "age"'
374-
' ]'
375-
' },'
376-
' {'
377-
' "type": "object",'
378-
' "description": "The address object to insert into the database",'
379-
' "properties": {'
380-
' "number": {'
381-
' "type": "string",'
382-
' "description": "The number of the address. Eg. for 123 main st, this would be 123"'
383-
' },'
384-
' "street": {'
385-
' "type": "string",'
386-
' "description": "The street name. Eg. for 123 main st, this would be main st"'
387-
' },'
388-
' "city": {'
389-
' "type": "string",'
390-
' "description": "The city of the address"'
391-
' }'
392-
' },'
393-
' "additionalProperties": false,'
394-
' "required": ['
395-
' "number",'
396-
' "street",'
397-
' "city"'
398-
' ]'
399-
' }'
400-
' ]'
401-
' }'
402-
' },'
403-
' "additionalProperties": false,'
404-
' "required": ['
405-
' "item"'
406-
' ]'
407-
'}'
408-
}, newline));
409262
end

0 commit comments

Comments
 (0)