@@ -172,11 +172,7 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra
172
172
return resp , err
173
173
}
174
174
175
- if state == nil || state .StartContinuation {
176
- if state != nil {
177
- state = state .WithResumeInput (& input )
178
- input = state .InputContextContinuationInput
179
- }
175
+ if state == nil {
180
176
state , err = r .start (callCtx , state , monitor , env , input )
181
177
if err != nil {
182
178
return resp , err
@@ -186,11 +182,9 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra
186
182
state .ResumeInput = & input
187
183
}
188
184
189
- if ! state .StartContinuation {
190
- state , err = r .resume (callCtx , monitor , env , state )
191
- if err != nil {
192
- return resp , err
193
- }
185
+ state , err = r .resume (callCtx , monitor , env , state )
186
+ if err != nil {
187
+ return resp , err
194
188
}
195
189
196
190
if state .Result != nil {
@@ -260,6 +254,10 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string)
260
254
targetArgs := prg .ToolSet [ref .ToolID ].Arguments
261
255
targetKeys := map [string ]string {}
262
256
257
+ if ref .Arg == "*" {
258
+ return input , nil
259
+ }
260
+
263
261
if targetArgs == nil {
264
262
return "" , nil
265
263
}
@@ -331,24 +329,10 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string)
331
329
return string (output ), err
332
330
}
333
331
334
- func (r * Runner ) getContext (callCtx engine.Context , state * State , monitor Monitor , env []string , input string ) (result []engine.InputContext , _ * State , _ error ) {
332
+ func (r * Runner ) getContext (callCtx engine.Context , state * State , monitor Monitor , env []string , input string ) (result []engine.InputContext , _ error ) {
335
333
toolRefs , err := callCtx .Tool .GetContextTools (* callCtx .Program )
336
334
if err != nil {
337
- return nil , nil , err
338
- }
339
-
340
- var newState * State
341
- if state != nil {
342
- cp := * state
343
- newState = & cp
344
- if newState .InputContextContinuation != nil {
345
- newState .InputContexts = nil
346
- newState .InputContextContinuation = nil
347
- newState .InputContextContinuationInput = ""
348
- newState .ResumeInput = state .InputContextContinuationResumeInput
349
-
350
- input = state .InputContextContinuationInput
351
- }
335
+ return nil , err
352
336
}
353
337
354
338
for i , toolRef := range toolRefs {
@@ -359,47 +343,31 @@ func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monito
359
343
360
344
contextInput , err := getToolRefInput (callCtx .Program , toolRef , input )
361
345
if err != nil {
362
- return nil , nil , err
346
+ return nil , err
363
347
}
364
348
365
349
var content * State
366
- if state != nil && state .InputContextContinuation != nil {
367
- content , err = r .subCallResume (callCtx .Ctx , callCtx , monitor , env , toolRef .ToolID , "" , state .InputContextContinuation .WithResumeInput (state .ResumeInput ), engine .ContextToolCategory )
368
- } else {
369
- content , err = r .subCall (callCtx .Ctx , callCtx , monitor , env , toolRef .ToolID , contextInput , "" , engine .ContextToolCategory )
370
- }
350
+ content , err = r .subCall (callCtx .Ctx , callCtx , monitor , env , toolRef .ToolID , contextInput , "" , engine .ContextToolCategory )
371
351
if err != nil {
372
- return nil , nil , err
352
+ return nil , err
373
353
}
374
354
if content .Continuation != nil {
375
- if newState == nil {
376
- newState = & State {}
377
- }
378
- newState .InputContexts = result
379
- newState .InputContextContinuation = content
380
- newState .InputContextContinuationInput = input
381
- if state != nil {
382
- newState .InputContextContinuationResumeInput = state .ResumeInput
383
- }
384
- return nil , newState , nil
355
+ return nil , fmt .Errorf ("invalid state: context tool [%s] can not result in a continuation" , toolRef .ToolID )
385
356
}
386
357
result = append (result , engine.InputContext {
387
358
ToolID : toolRef .ToolID ,
388
359
Content : * content .Result ,
389
360
})
390
361
}
391
362
392
- return result , newState , nil
363
+ return result , nil
393
364
}
394
365
395
366
func (r * Runner ) call (callCtx engine.Context , monitor Monitor , env []string , input string ) (* State , error ) {
396
367
result , err := r .start (callCtx , nil , monitor , env , input )
397
368
if err != nil {
398
369
return nil , err
399
370
}
400
- if result .StartContinuation {
401
- return result , nil
402
- }
403
371
return r .resume (callCtx , monitor , env , result )
404
372
}
405
373
@@ -431,15 +399,10 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
431
399
}
432
400
}
433
401
434
- var newState * State
435
- callCtx .InputContext , newState , err = r .getContext (callCtx , state , monitor , env , input )
402
+ callCtx .InputContext , err = r .getContext (callCtx , state , monitor , env , input )
436
403
if err != nil {
437
404
return nil , err
438
405
}
439
- if newState != nil && newState .InputContextContinuation != nil {
440
- newState .StartContinuation = true
441
- return newState , nil
442
- }
443
406
444
407
e := engine.Engine {
445
408
Model : r .c ,
@@ -489,11 +452,7 @@ type State struct {
489
452
SubCalls []SubCallResult `json:"subCalls,omitempty"`
490
453
SubCallID string `json:"subCallID,omitempty"`
491
454
492
- InputContexts []engine.InputContext `json:"inputContexts,omitempty"`
493
- InputContextContinuation * State `json:"inputContextContinuation,omitempty"`
494
- InputContextContinuationInput string `json:"inputContextContinuationInput,omitempty"`
495
- InputContextContinuationResumeInput * string `json:"inputContextContinuationResumeInput,omitempty"`
496
- StartContinuation bool `json:"startContinuation,omitempty"`
455
+ InputContexts []engine.InputContext `json:"inputContexts,omitempty"`
497
456
}
498
457
499
458
func (s State ) WithResumeInput (input * string ) * State {
@@ -506,10 +465,6 @@ func (s State) ContinuationContentToolID() (string, error) {
506
465
return s .ContinuationToolID , nil
507
466
}
508
467
509
- if s .InputContextContinuation != nil {
510
- return s .InputContextContinuation .ContinuationContentToolID ()
511
- }
512
-
513
468
for _ , subCall := range s .SubCalls {
514
469
if s .SubCallID == subCall .CallID {
515
470
return subCall .State .ContinuationContentToolID ()
@@ -523,10 +478,6 @@ func (s State) ContinuationContent() (string, error) {
523
478
return * s .Continuation .Result , nil
524
479
}
525
480
526
- if s .InputContextContinuation != nil {
527
- return s .InputContextContinuation .ContinuationContent ()
528
- }
529
-
530
481
for _ , subCall := range s .SubCalls {
531
482
if s .SubCallID == subCall .CallID {
532
483
return subCall .State .ContinuationContent ()
@@ -545,10 +496,6 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
545
496
retState , retErr = r .handleOutput (callCtx , monitor , env , retState , retErr )
546
497
}()
547
498
548
- if state .StartContinuation {
549
- return nil , fmt .Errorf ("invalid state, resume should not have StartContinuation set to true" )
550
- }
551
-
552
499
if state .Continuation == nil {
553
500
return nil , errors .New ("invalid state, resume should have Continuation data" )
554
501
}
@@ -653,8 +600,12 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
653
600
contentInput = state .Continuation .State .Input
654
601
}
655
602
656
- callCtx .InputContext , state , err = r .getContext (callCtx , state , monitor , env , contentInput )
657
- if err != nil || state .InputContextContinuation != nil {
603
+ if state .ResumeInput != nil {
604
+ contentInput = * state .ResumeInput
605
+ }
606
+
607
+ callCtx .InputContext , err = r .getContext (callCtx , state , monitor , env , contentInput )
608
+ if err != nil {
658
609
return state , err
659
610
}
660
611
@@ -764,10 +715,6 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
764
715
callCtx .LastReturn = state .Continuation
765
716
}
766
717
767
- if state .InputContextContinuation != nil {
768
- return state , nil , nil
769
- }
770
-
771
718
if state .SubCallID != "" {
772
719
if state .ResumeInput == nil {
773
720
return nil , nil , fmt .Errorf ("invalid state, input must be set for sub call continuation on callID [%s]" , state .SubCallID )
0 commit comments