Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions go/ai/format_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,30 @@ func (a arrayHandler) ParseMessage(m *Message) (*Message, error) {

return m, nil
}

// ParseChunk parses a streaming chunk and returns parsed array data.
// Based on JS version: js/ai/src/formats/array.ts parseChunk method
func (a arrayHandler) ParseChunk(chunk *ModelResponseChunk, accumulatedText string) (interface{}, error) {
if chunk == nil || len(chunk.Content) == 0 {
return nil, nil
}

// Try to extract array items from accumulated text
items := base.ExtractItems(accumulatedText)
if len(items) > 0 {
return items, nil
}

// If no items found, try partial JSON parsing
data, err := base.ParsePartialJSON(accumulatedText)
if err != nil {
return nil, nil
}

// Check if data is an array
if arr, ok := data.([]interface{}); ok {
return arr, nil
}

return nil, nil
}
28 changes: 28 additions & 0 deletions go/ai/format_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,34 @@ func (e enumHandler) ParseMessage(m *Message) (*Message, error) {
return m, nil
}

// ParseChunk parses a streaming chunk and returns parsed enum data.
// Based on JS version: js/ai/src/formats/enum.ts parseChunk method
func (e enumHandler) ParseChunk(chunk *ModelResponseChunk, accumulatedText string) (interface{}, error) {
if chunk == nil || len(chunk.Content) == 0 {
return nil, nil
}

// Clean and trim the accumulated text
re := regexp.MustCompile(`['"]`)
clean := re.ReplaceAllString(accumulatedText, "")
trimmed := strings.TrimSpace(clean)

// Check if the trimmed text matches any enum value
if slices.Contains(e.enums, trimmed) {
return trimmed, nil
}

// Check if any enum starts with the current accumulated text (partial match)
for _, enum := range e.enums {
if strings.HasPrefix(enum, trimmed) {
// Still accumulating, return nil
return nil, nil
}
}

return nil, nil
}

// Get enum strings from json schema
func objectEnums(schema map[string]any) []string {
var enums []string
Expand Down
40 changes: 40 additions & 0 deletions go/ai/format_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,43 @@ func (j jsonHandler) ParseMessage(m *Message) (*Message, error) {

return m, nil
}

// ParseChunk parses a streaming chunk and returns parsed JSON data.
// Based on JS version: js/ai/src/formats/json.ts parseChunk method
func (j jsonHandler) ParseChunk(chunk *ModelResponseChunk, accumulatedText string) (interface{}, error) {
if chunk == nil || len(chunk.Content) == 0 {
return nil, nil
}

// Try to extract JSON from accumulated text
data, err := base.ExtractJSON(accumulatedText)
if err != nil {
// If extraction fails, try partial JSON parsing
data, err = base.ParsePartialJSON(accumulatedText)
if err != nil {
return nil, nil
}
}

// If we have a schema, validate the data
if j.config.Schema != nil && data != nil {
jsonBytes, err := json.Marshal(data)
if err != nil {
return nil, err
}

schemaBytes, err := json.Marshal(j.config.Schema)
if err != nil {
return nil, err
}

// Only validate if we have complete JSON (not partial)
if base.ValidJSON(string(jsonBytes)) {
if err := base.ValidateRaw(jsonBytes, schemaBytes); err != nil {
return nil, err
}
}
}

return data, nil
}
26 changes: 26 additions & 0 deletions go/ai/format_jsonl.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,29 @@ func (j jsonlHandler) ParseMessage(m *Message) (*Message, error) {

return m, nil
}

// ParseChunk parses a streaming chunk and returns parsed JSONL data.
// Based on JS version: js/ai/src/formats/jsonl.ts parseChunk method
func (j jsonlHandler) ParseChunk(chunk *ModelResponseChunk, accumulatedText string) (interface{}, error) {
if chunk == nil || len(chunk.Content) == 0 {
return nil, nil
}

// Extract JSON objects from accumulated text (one per line)
lines := base.GetJsonObjectLines(accumulatedText)
if len(lines) > 0 {
// For JSONL, return array of parsed objects
var items []interface{}
for _, line := range lines {
var item interface{}
if err := json.Unmarshal([]byte(line), &item); err == nil {
items = append(items, item)
}
}
if len(items) > 0 {
return items, nil
}
}

return nil, nil
}
5 changes: 5 additions & 0 deletions go/ai/format_text.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,8 @@ func (t textHandler) Instructions() string {
func (t textHandler) ParseMessage(m *Message) (*Message, error) {
return m, nil
}

// ParseChunk for text format simply returns the accumulated text
func (t textHandler) ParseChunk(chunk *ModelResponseChunk, accumulatedText string) (interface{}, error) {
return accumulatedText, nil
}
3 changes: 3 additions & 0 deletions go/ai/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ type Formatter interface {
type FormatHandler interface {
// ParseMessage parses the message and returns a new formatted message.
ParseMessage(message *Message) (*Message, error)
// ParseChunk parses a streaming chunk and returns parsed data (optional).
// Based on JS version: js/ai/src/formats/types.ts parseChunk method
ParseChunk(chunk *ModelResponseChunk, accumulatedText string) (interface{}, error)
// Instructions returns the formatter instructions to embed in the prompt.
Instructions() string
// Config returns the output config for the model request.
Expand Down
4 changes: 4 additions & 0 deletions go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,10 @@ type ModelResponseChunk struct {
Custom any `json:"custom,omitempty"`
Index int `json:"index,omitempty"`
Role Role `json:"role,omitempty"`

// Metadata holds additional information for streaming.
// Based on JS version: js/ai/src/generate/chunk.ts
Metadata map[string]any `json:"metadata,omitempty"`
}

// OutputConfig describes the structure that the model's output
Expand Down
84 changes: 83 additions & 1 deletion go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,11 +491,93 @@ func GenerateText(ctx context.Context, r *registry.Registry, opts ...GenerateOpt
}

// Generate run generate request for this model. Returns ModelResponse struct.
// TODO: Stream GenerateData with partial JSON
func GenerateData[Out any](ctx context.Context, r *registry.Registry, opts ...GenerateOption) (*Out, *ModelResponse, error) {
var value Out
opts = append(opts, WithOutputType(value))

// Parse options to get streaming callback and output format
genOpts := &generateOptions{}
for _, opt := range opts {
if err := opt.applyGenerate(genOpts); err != nil {
return nil, nil, err
}
}

// If streaming is requested, wrap the callback to support partial JSON parsing
if genOpts.Stream != nil {
// Determine the output format
outputFormat := genOpts.OutputFormat
if outputFormat == "" && genOpts.OutputSchema != nil {
outputFormat = OutputFormatJSON
}

// Get the appropriate formatter
var formatter Formatter
var formatHandler FormatHandler
if outputFormat != "" {
var err error
formatter, err = resolveFormat(r, genOpts.OutputSchema, outputFormat)
if err == nil && formatter != nil {
formatHandler, err = formatter.Handler(genOpts.OutputSchema)
if err != nil {
// Log error but continue without formatter
formatHandler = nil
}
}
}

// Wrap the original streaming callback
originalCallback := genOpts.Stream
var accumulatedText string

wrappedCallback := func(ctx context.Context, chunk *ModelResponseChunk) error {
// Accumulate text
accumulatedText += chunk.Text()

// Store accumulated text in chunk metadata
if chunk.Metadata == nil {
chunk.Metadata = make(map[string]any)
}
chunk.Metadata["accumulatedText"] = accumulatedText

// Try to parse using formatter if available
if formatHandler != nil {
parsedData, err := formatHandler.ParseChunk(chunk, accumulatedText)
if err == nil && parsedData != nil {
// Try to convert to the expected type
jsonBytes, err := json.Marshal(parsedData)
if err == nil {
var partialValue Out
if err := json.Unmarshal(jsonBytes, &partialValue); err == nil {
// Store parsed data in metadata
chunk.Metadata["parsedData"] = partialValue
}
}
}
}

// Call the original callback with the enhanced chunk
return originalCallback(ctx, chunk)
}

// Replace the streaming callback in options
newOpts := make([]GenerateOption, 0, len(opts))
for _, opt := range opts {
// Skip the original streaming option by checking if it's an ExecutionOption with Stream
if execOpt, ok := opt.(ExecutionOption); ok {
var tempOpts executionOptions
execOpt.applyExecution(&tempOpts)
if tempOpts.Stream != nil {
// Skip this option as it has streaming
continue
}
}
newOpts = append(newOpts, opt)
}
newOpts = append(newOpts, WithStreaming(wrappedCallback))
opts = newOpts
}

resp, err := Generate(ctx, r, opts...)
if err != nil {
return nil, nil, err
Expand Down
Loading