Skip to content

Commit c60a501

Browse files
committed
enhance: get more information about models
Signed-off-by: Grant Linville <[email protected]>
1 parent 64f1b28 commit c60a501

File tree

3 files changed

+37
-7
lines changed

3 files changed

+37
-7
lines changed

gptscript/gptscript.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from gptscript.datasets import DatasetElementMeta, DatasetElement, DatasetMeta
1212
from gptscript.fileinfo import FileInfo
1313
from gptscript.frame import RunFrame, CallFrame, PromptFrame, Program
14+
from gptscript.openai import Model
1415
from gptscript.opts import GlobalOptions
1516
from gptscript.prompt import PromptResponse
1617
from gptscript.run import Run, RunBasicCommand, Options
@@ -164,16 +165,17 @@ async def _run_basic_command(self, sub_command: str, request_body: Any = None):
164165
async def version(self) -> str:
165166
return await self._run_basic_command("version")
166167

167-
async def list_models(self, providers: list[str] = None, credential_overrides: list[str] = None) -> list[str]:
168+
async def list_models(self, providers: list[str] = None, credential_overrides: list[str] = None) -> list[Model]:
168169
if self.opts.DefaultModelProvider != "":
169170
if providers is None:
170171
providers = []
171172
providers.append(self.opts.DefaultModelProvider)
172173

173-
return (await self._run_basic_command(
174+
res = await self._run_basic_command(
174175
"list-models",
175176
{"providers": providers, "credentialOverrides": credential_overrides}
176-
)).split("\n")
177+
)
178+
return [Model(**model) for model in json.loads(res)]
177179

178180
async def list_credentials(self, contexts: List[str] = None, all_contexts: bool = False) -> list[Credential] | str:
179181
if contexts is None:

gptscript/openai.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from pydantic import BaseModel, conlist
2+
from typing import Any, Dict, Optional
3+
4+
5+
class Permission(BaseModel):
6+
created: int
7+
id: str
8+
object: str
9+
allow_create_engine: bool
10+
allow_sampling: bool
11+
allow_logprobs: bool
12+
allow_search_indices: bool
13+
allow_view: bool
14+
allow_fine_tuning: bool
15+
organization: str
16+
group: Any
17+
is_blocking: bool
18+
19+
20+
class Model(BaseModel):
21+
created: Optional[int]
22+
id: str
23+
object: str
24+
owned_by: str
25+
permission: Optional[conlist(Permission)]
26+
root: Optional[str]
27+
parent: Optional[str]
28+
metadata: Optional[Dict[str, str]]

tests/test_gptscript.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ async def test_list_models_from_provider(gptscript):
126126
)
127127
assert isinstance(models, list) and len(models) > 1, "Expected list_models to return a list"
128128
for model in models:
129-
assert model.startswith("claude-3-"), "Unexpected model name"
130-
assert model.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name"
129+
assert model.id.startswith("claude-3-"), "Unexpected model name"
130+
assert model.id.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name"
131131

132132

133133
@pytest.mark.asyncio
@@ -140,8 +140,8 @@ async def test_list_models_from_default_provider():
140140
)
141141
assert isinstance(models, list) and len(models) > 1, "Expected list_models to return a list"
142142
for model in models:
143-
assert model.startswith("claude-3-"), "Unexpected model name"
144-
assert model.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name"
143+
assert model.id.startswith("claude-3-"), "Unexpected model name"
144+
assert model.id.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name"
145145
finally:
146146
g.close()
147147

0 commit comments

Comments
 (0)