Skip to content

Commit 2ad1d65

Browse files
committed
Reify structured output chunks. Move JSON parsing to the depths of Completion
1 parent 80ac7d5 commit 2ad1d65

File tree

10 files changed

+269
-74
lines changed

10 files changed

+269
-74
lines changed

lib/completions/endpoints/base.rb

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,18 @@ def perform_completion!(
106106

107107
prompt = dialect.translate
108108

109+
structured_output = nil
110+
111+
if model_params[:response_format].present?
112+
response_structure =
113+
model_params[:response_format].dig(:json_schema, :schema, :required)
114+
115+
if response_structure.present?
116+
structured_output =
117+
DiscourseAi::Completions::StructuredOutput.new(response_structure.map(&:to_sym))
118+
end
119+
end
120+
109121
FinalDestination::HTTP.start(
110122
model_uri.host,
111123
model_uri.port,
@@ -140,10 +152,17 @@ def perform_completion!(
140152
xml_stripper =
141153
DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
142154

143-
if @streaming_mode && xml_stripper
155+
if @streaming_mode
144156
blk =
145157
lambda do |partial, cancel|
146-
partial = xml_stripper << partial if partial.is_a?(String)
158+
if partial.is_a?(String)
159+
partial = xml_stripper << partial if xml_stripper
160+
161+
if structured_output.present?
162+
structured_output << partial
163+
partial = structured_output
164+
end
165+
end
147166
orig_blk.call(partial, cancel) if partial
148167
end
149168
end
@@ -167,6 +186,7 @@ def perform_completion!(
167186
xml_stripper: xml_stripper,
168187
partials_raw: partials_raw,
169188
response_raw: response_raw,
189+
structured_output: structured_output,
170190
)
171191
return response_data
172192
end
@@ -373,7 +393,8 @@ def non_streaming_response(
373393
xml_tool_processor:,
374394
xml_stripper:,
375395
partials_raw:,
376-
response_raw:
396+
response_raw:,
397+
structured_output:
377398
)
378399
response_raw << response.read_body
379400
response_data = decode(response_raw)
@@ -403,6 +424,26 @@ def non_streaming_response(
403424

404425
response_data.reject!(&:blank?)
405426

427+
if structured_output.present?
428+
has_string_response = false
429+
430+
response_data =
431+
response_data.reduce([]) do |memo, data|
432+
if data.is_a?(String)
433+
structured_output << data
434+
has_string_response = true
435+
next(memo)
436+
else
437+
memo << data
438+
end
439+
440+
memo
441+
end
442+
443+
# We only include the structured output if there was actually a structured response
444+
response_data << structured_output if has_string_response
445+
end
446+
406447
# this is to keep stuff backwards compatible
407448
response_data = response_data.first if response_data.length == 1
408449

lib/completions/endpoints/canned_response.rb

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def perform_completion!(
4040
"The number of completions you requested exceed the number of canned responses"
4141
end
4242

43-
response = transform_from_schema(response) if model_params[:response_format].present?
43+
response = as_structured_output(response) if model_params[:response_format].present?
4444

4545
raise response if response.is_a?(StandardError)
4646

@@ -56,6 +56,8 @@ def perform_completion!(
5656
yield(response, cancel_fn)
5757
elsif is_thinking?(response)
5858
yield(response, cancel_fn)
59+
elsif is_structured_output?(response)
60+
yield(response, cancel_fn)
5961
else
6062
response.each_char do |char|
6163
break if cancelled
@@ -83,11 +85,18 @@ def is_tool?(response)
8385
response.is_a?(DiscourseAi::Completions::ToolCall)
8486
end
8587

86-
def transform_from_schema(response)
87-
key = model_params[:response_format].dig(:json_schema, :schema, :properties)&.keys&.first
88-
return response if key.nil?
88+
# Tells whether +response+ is a StructuredOutput instance, as opposed to a
# plain string, tool call, or thinking block.
def is_structured_output?(response)
  structured_output_class = DiscourseAi::Completions::StructuredOutput

  response.kind_of?(structured_output_class)
end
91+
92+
# Wraps a canned response in a StructuredOutput.
#
# The property names come from the JSON schema found in
# model_params[:response_format]. The response is serialized as the value of
# the first property and fed through the StructuredOutput parser. When the
# schema declares no properties, the response is returned untouched.
def as_structured_output(response)
  property_names = model_params[:response_format].dig(:json_schema, :schema, :properties)&.keys
  return response if property_names.blank?

  serialized = { property_names.first => response }.to_json

  DiscourseAi::Completions::StructuredOutput
    .new(property_names)
    .tap { |structured| structured << serialized }
end
92101
end
93102
end

lib/completions/structured_output.rb

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# frozen_string_literal: true
2+
3+
module DiscourseAi
  module Completions
    # Incrementally extracts string properties from a JSON object that arrives
    # as a stream of arbitrarily-split text chunks.
    #
    # Chunks are pushed with #<<. After each push:
    #   * #full_output holds everything decoded so far, per property.
    #   * #last_chunk_output holds only what the most recent chunk contributed.
    #
    # Both hashes are keyed by symbol and always contain every declared
    # property (an empty string until data for it shows up).
    #
    # Limitations: this is a small state machine, not a full JSON parser.
    # Only string values of a flat object are extracted. Non-string scalars
    # (numbers, booleans, null) are skipped. Nested objects and arrays are not
    # supported and may confuse key/value tracking.
    class StructuredOutput
      # Single-character JSON escapes and the character each one stands for.
      ESCAPES = {
        "\"" => "\"",
        "\\" => "\\",
        "/" => "/",
        "b" => "\b",
        "f" => "\f",
        "n" => "\n",
        "r" => "\r",
        "t" => "\t",
      }.freeze

      # Emitted in place of an unpaired UTF-16 surrogate escape.
      REPLACEMENT_CHAR = "\uFFFD".freeze
      HIGH_SURROGATES = (0xD800..0xDBFF).freeze
      LOW_SURROGATES = (0xDC00..0xDFFF).freeze

      attr_reader :full_output, :last_chunk_output

      # @param property_names [Array<String, Symbol>] the properties to extract.
      def initialize(property_names)
        @raw_response = +""
        @state = :awaiting_key
        @current_key = +""
        @escape = false

        # Hex digits collected so far for an in-flight \uXXXX escape (nil when
        # no such escape is being read) and a high surrogate waiting for its
        # low half. Both survive chunk boundaries.
        @unicode_digits = nil
        @high_surrogate = nil

        # String -> Symbol lookup so keys coming from the model never have to
        # be turned into symbols themselves.
        @known_keys =
          property_names.each_with_object({}) { |name, memo| memo[name.to_s] = name.to_sym }
        # Property currently being filled, or nil when the key is not declared.
        @target_key = nil

        @full_output = blank_output
        @last_chunk_output = blank_output
      end

      # Feeds the next chunk of raw JSON text into the parser.
      #
      # Resets #last_chunk_output and then appends every decoded character to
      # both #full_output and #last_chunk_output. Values that belong to keys
      # which were not declared at construction time are parsed but discarded.
      #
      # @param raw [String] the next piece of the response.
      # @return [String] the chunk that was passed in.
      def <<(raw)
        @raw_response << raw
        @last_chunk_output = blank_output

        raw.each_char { |char| consume(char) }
      end

      private

      # A fresh hash with an empty, mutable string for each declared property.
      def blank_output
        @known_keys.values.each_with_object({}) { |key, memo| memo[key] = +"" }
      end

      # Advances the state machine by one character.
      def consume(char)
        case @state
        when :awaiting_key
          if char == "\""
            @current_key = +""
            @escape = false
            @state = :parsing_key
          end
        when :parsing_key
          if char == "\""
            @target_key = @known_keys[@current_key]
            @state = :awaiting_colon
          else
            @current_key << char
          end
        when :awaiting_colon
          @state = :awaiting_value if char == ":"
        when :awaiting_value
          case char
          when "\""
            @escape = false
            @state = :parsing_value
          when ","
            # The value was a non-string scalar; skip it instead of mistaking
            # the next key's opening quote for the start of a value.
            @state = :awaiting_key
          end
        when :parsing_value
          consume_value_char(char)
        when :awaiting_key_or_end
          # End of object and whitespace are ignored here.
          @state = :awaiting_key if char == ","
        end
      end

      # Handles one character inside a string value.
      def consume_value_char(char)
        return if consume_unicode_digit(char)

        if @escape
          @escape = false

          if char == "u"
            @unicode_digits = +""
          else
            flush_high_surrogate
            # Unknown escapes are passed through as the bare character.
            emit(ESCAPES.fetch(char, char))
          end
        elsif char == "\\"
          # Hold the escape until we know which sequence it starts.
          @escape = true
        elsif char == "\""
          flush_high_surrogate
          @state = :awaiting_key_or_end
        else
          flush_high_surrogate
          emit(char)
        end
      end

      # Collects the hex digits of a \uXXXX escape.
      #
      # @return [Boolean] true when +char+ was consumed as part of the escape.
      def consume_unicode_digit(char)
        return false if @unicode_digits.nil?

        if char.match?(/\h/)
          @unicode_digits << char
          finish_unicode_escape if @unicode_digits.length == 4
          return true
        end

        # Malformed escape: pass through what was read literally and let the
        # caller process +char+ as a regular character.
        abandoned = @unicode_digits
        @unicode_digits = nil
        flush_high_surrogate
        emit("u#{abandoned}")
        false
      end

      # Decodes a completed \uXXXX escape, joining UTF-16 surrogate pairs.
      def finish_unicode_escape
        unit = @unicode_digits.hex
        @unicode_digits = nil

        if @high_surrogate
          high = @high_surrogate
          @high_surrogate = nil

          if LOW_SURROGATES.cover?(unit)
            codepoint = 0x10000 + ((high - 0xD800) << 10) + (unit - 0xDC00)
            emit([codepoint].pack("U"))
            return
          end

          emit(REPLACEMENT_CHAR)
        end

        if HIGH_SURROGATES.cover?(unit)
          @high_surrogate = unit
        elsif LOW_SURROGATES.cover?(unit)
          emit(REPLACEMENT_CHAR)
        else
          emit([unit].pack("U"))
        end
      end

      # A high surrogate that is not followed by a low one cannot be decoded.
      def flush_high_surrogate
        return if @high_surrogate.nil?

        @high_surrogate = nil
        emit(REPLACEMENT_CHAR)
      end

      # Appends decoded text to the current property, if it is a declared one.
      def emit(text)
        return if @target_key.nil?

        @full_output[@target_key] << text
        @last_chunk_output[@target_key] << text
      end
    end
  end
end

lib/personas/bot.rb

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ def reply(context, llm_args: {}, &update_blk)
151151
raw_context << partial
152152
current_thinking << partial
153153
end
154+
elsif partial.is_a?(DiscourseAi::Completions::StructuredOutput)
155+
update_blk.call(partial.last_chunk_output, cancel, nil, :structured_output)
154156
else
155157
update_blk.call(partial, cancel)
156158
end

lib/summarization/fold_content.rb

Lines changed: 12 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@ def initialize(bot, strategy, persist_summaries: true)
2727
def summarize(user, &on_partial_blk)
2828
truncated_content = content_to_summarize.map { |cts| truncate(cts) }
2929

30-
# Done here to cover non-streaming mode.
31-
json_reply_end = "\"}"
32-
summary = fold(truncated_content, user, &on_partial_blk).chomp(json_reply_end)
30+
summary = fold(truncated_content, user, &on_partial_blk)
3331

3432
if persist_summaries
3533
AiSummary.store!(strategy, llm_model, summary, truncated_content, human: user&.human?)
@@ -113,67 +111,24 @@ def fold(items, user, &on_partial_blk)
113111

114112
summary = +""
115113

116-
# Auxiliary variables to get the summary content from the JSON response.
117-
json_start_buffer = +""
118-
json_start_found = false
119-
# { is optional because Claude uses prefill, so it's not incldued.
120-
# TODO(roman): Maybe extraction should happen in the bot?
121-
json_summary_schema_keys = bot.persona.response_format&.first.to_h
122-
json_reply_start_regex = /\{?\s*"#{json_summary_schema_keys[:key]}"\s*:\s*"/
123-
# We need to buffer escaped newlines as the API likes to send \\ and n in different chunks.
124-
partial_unescape_buffer = +""
125-
unescape_regex = %r{\\(["/bfnrt])}
126-
json_reply_end = "\"}"
127-
128114
buffer_blk =
129115
Proc.new do |partial, cancel, _, type|
130-
if type.blank?
131-
if bot.returns_json?
132-
# Extract summary from JSON.
133-
if json_start_found
134-
if partial.end_with?("\\")
135-
partial_unescape_buffer << partial
136-
else
137-
unescaped_partial = partial_unescape_buffer
138-
139-
buffered_newline = !partial_unescape_buffer.empty? && partial.first == "n"
140-
if buffered_newline
141-
unescaped_partial << partial.first
142-
143-
unescaped_partial = unescaped_partial.gsub("\\n", "\n")
144-
unescaped_partial << partial[1..].to_s
145-
else
146-
unescaped_partial << partial.gsub("\\n", "\n")
147-
end
148-
partial_unescape_buffer = +""
149-
150-
summary << unescaped_partial
151-
152-
on_partial_blk.call(unescaped_partial, cancel) if on_partial_blk
153-
end
154-
else
155-
json_start_buffer << partial
156-
157-
if json_start_buffer.match?(json_reply_start_regex)
158-
buffered_start = json_start_buffer.gsub(json_reply_start_regex, "")
159-
summary << buffered_start
160-
161-
on_partial_blk.call(buffered_start, cancel) if on_partial_blk
162-
163-
json_start_found = true
164-
end
165-
end
166-
else
167-
# Assume response is a regular completion.
168-
summary << partial
169-
on_partial_blk.call(partial, cancel) if on_partial_blk
170-
end
116+
if type == :structured_output
117+
json_summary_schema_key = bot.persona.response_format&.first.to_h
118+
partial_summary = partial[json_summary_schema_key[:key].to_sym]
119+
120+
summary << partial_summary
121+
on_partial_blk.call(partial_summary, cancel) if on_partial_blk
122+
elsif type.blank?
123+
# Assume response is a regular completion.
124+
summary << partial
125+
on_partial_blk.call(partial, cancel) if on_partial_blk
171126
end
172127
end
173128

174129
bot.reply(context, &buffer_blk)
175130

176-
summary.chomp(json_reply_end)
131+
summary
177132
end
178133

179134
def available_tokens

spec/lib/completions/endpoints/anthropic_spec.rb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -837,15 +837,15 @@
837837
},
838838
).to_return(status: 200, body: body)
839839

840-
result = +""
840+
structured_output = nil
841841
llm.generate(
842842
prompt,
843843
user: Discourse.system_user,
844844
feature_name: "testing",
845845
response_format: schema,
846-
) { |partial, cancel| result << partial }
846+
) { |partial, cancel| structured_output = partial }
847847

848-
expect(result).to eq("\"key\":\"Hello!\"}")
848+
expect(structured_output.full_output).to eq({ key: "Hello!" })
849849

850850
expected_body = {
851851
model: "claude-3-opus-20240229",

spec/lib/completions/endpoints/aws_bedrock_spec.rb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -591,9 +591,9 @@ def encode_message(message)
591591
end
592592
.to_return(status: 200, body: messages)
593593

594-
response = +""
594+
structured_output = nil
595595
proxy.generate("hello world", response_format: schema, user: user) do |partial|
596-
response << partial
596+
structured_output = partial
597597
end
598598

599599
expected = {
@@ -607,7 +607,7 @@ def encode_message(message)
607607
}
608608
expect(JSON.parse(request.body)).to eq(expected)
609609

610-
expect(response).to eq("\"key\":\"Hello!\"}")
610+
expect(structured_output.full_output).to eq({ key: "Hello!" })
611611
end
612612
end
613613
end

0 commit comments

Comments
 (0)