Skip to content

Commit badb126

Browse files
authored
enhance: sdk: list full model objects, instead of just names (#919)
Signed-off-by: Grant Linville <[email protected]>
1 parent eb03680 commit badb126

File tree

6 files changed

+33
-17
lines changed

6 files changed

+33
-17
lines changed

pkg/cli/gptscript.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,10 @@ func (r *GPTScript) listModels(ctx context.Context, gptScript *gptscript.GPTScri
276276
if err != nil {
277277
return err
278278
}
279-
fmt.Println(strings.Join(models, "\n"))
279+
280+
for _, model := range models {
281+
fmt.Println(model.ID)
282+
}
280283
return nil
281284
}
282285

pkg/gptscript/gptscript.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"slices"
1111
"strings"
1212

13+
openai2 "github.com/gptscript-ai/chat-completion-client"
1314
"github.com/gptscript-ai/gptscript/pkg/builtin"
1415
"github.com/gptscript-ai/gptscript/pkg/cache"
1516
"github.com/gptscript-ai/gptscript/pkg/config"
@@ -275,7 +276,7 @@ func (g *GPTScript) ListTools(_ context.Context, prg types.Program) []types.Tool
275276
return prg.TopLevelTools()
276277
}
277278

278-
func (g *GPTScript) ListModels(ctx context.Context, providers ...string) ([]string, error) {
279+
func (g *GPTScript) ListModels(ctx context.Context, providers ...string) ([]openai2.Model, error) {
279280
return g.Registry.ListModels(ctx, providers...)
280281
}
281282

pkg/llm/registry.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"sync"
99

1010
"github.com/google/uuid"
11+
openai2 "github.com/gptscript-ai/chat-completion-client"
1112
"github.com/gptscript-ai/gptscript/pkg/env"
1213
"github.com/gptscript-ai/gptscript/pkg/openai"
1314
"github.com/gptscript-ai/gptscript/pkg/remote"
@@ -16,7 +17,7 @@ import (
1617

1718
type Client interface {
1819
Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
19-
ListModels(ctx context.Context, providers ...string) (result []string, _ error)
20+
ListModels(ctx context.Context, providers ...string) (result []openai2.Model, _ error)
2021
Supports(ctx context.Context, modelName string) (bool, error)
2122
}
2223

@@ -38,15 +39,17 @@ func (r *Registry) AddClient(client Client) error {
3839
return nil
3940
}
4041

41-
func (r *Registry) ListModels(ctx context.Context, providers ...string) (result []string, _ error) {
42+
func (r *Registry) ListModels(ctx context.Context, providers ...string) (result []openai2.Model, _ error) {
4243
for _, v := range r.clients {
4344
models, err := v.ListModels(ctx, providers...)
4445
if err != nil {
4546
return nil, err
4647
}
4748
result = append(result, models...)
4849
}
49-
sort.Strings(result)
50+
sort.Slice(result, func(i, j int) bool {
51+
return result[i].ID < result[j].ID
52+
})
5053
return result, nil
5154
}
5255

pkg/openai/client.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,15 @@ func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
157157
return false, InvalidAuthError{}
158158
}
159159

160-
return slices.Contains(models, modelName), nil
160+
for _, model := range models {
161+
if model.ID == modelName {
162+
return true, nil
163+
}
164+
}
165+
return false, nil
161166
}
162167

163-
func (c *Client) ListModels(ctx context.Context, providers ...string) (result []string, _ error) {
168+
func (c *Client) ListModels(ctx context.Context, providers ...string) ([]openai.Model, error) {
164169
// Only serve if providers is empty or "" is in the list
165170
if len(providers) != 0 && !slices.Contains(providers, "") {
166171
return nil, nil
@@ -179,11 +184,10 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
179184
if err != nil {
180185
return nil, err
181186
}
182-
for _, model := range models.Models {
183-
result = append(result, model.ID)
184-
}
185-
sort.Strings(result)
186-
return result, nil
187+
sort.Slice(models.Models, func(i, j int) bool {
188+
return models.Models[i].ID < models.Models[j].ID
189+
})
190+
return models.Models, nil
187191
}
188192

189193
func (c *Client) cacheKey(request openai.ChatCompletionRequest) any {

pkg/remote/remote.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"strings"
1010
"sync"
1111

12+
openai2 "github.com/gptscript-ai/chat-completion-client"
1213
"github.com/gptscript-ai/gptscript/pkg/cache"
1314
"github.com/gptscript-ai/gptscript/pkg/credentials"
1415
"github.com/gptscript-ai/gptscript/pkg/engine"
@@ -62,7 +63,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
6263
return client.Call(ctx, messageRequest, env, status)
6364
}
6465

65-
func (c *Client) ListModels(ctx context.Context, providers ...string) (result []string, _ error) {
66+
func (c *Client) ListModels(ctx context.Context, providers ...string) (result []openai2.Model, _ error) {
6667
for _, provider := range providers {
6768
client, err := c.load(ctx, provider)
6869
if err != nil {
@@ -72,12 +73,16 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
7273
if err != nil {
7374
return nil, err
7475
}
75-
for _, model := range models {
76-
result = append(result, model+" from "+provider)
76+
for i := range models {
77+
models[i].ID = fmt.Sprintf("%s from %s", models[i].ID, provider)
7778
}
79+
80+
result = append(result, models...)
7881
}
7982

80-
sort.Strings(result)
83+
sort.Slice(result, func(i, j int) bool {
84+
return result[i].ID < result[j].ID
85+
})
8186
return
8287
}
8388

pkg/sdkserver/routes.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ func (s *server) listModels(w http.ResponseWriter, r *http.Request) {
145145
return
146146
}
147147

148-
writeResponse(logger, w, map[string]any{"stdout": strings.Join(out, "\n")})
148+
writeResponse(logger, w, map[string]any{"stdout": out})
149149
}
150150

151151
// execHandler is a general handler for executing tools with gptscript. This is mainly responsible for parsing the request body.

0 commit comments

Comments
 (0)