Skip to content

Commit cd76120

Browse files
authored
feat: add dataset functions (#59)
Signed-off-by: Grant Linville <[email protected]>
1 parent 97d819c commit cd76120

File tree

6 files changed

+151
-2
lines changed

6 files changed

+151
-2
lines changed

gptscript/datasets.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from typing import Dict
2+
from pydantic import BaseModel
3+
4+
class DatasetElementMeta(BaseModel):
    """Metadata describing one dataset element, without its contents.

    Returned by element-listing/creation calls where the (possibly large)
    contents are not needed.
    """

    name: str  # element name; used as the lookup key when fetching the element
    description: str  # human-readable description of the element
7+
8+
9+
class DatasetElement(BaseModel):
    """A single dataset element, including its contents.

    Superset of DatasetElementMeta; returned when an element is fetched
    individually.
    """

    name: str  # element name within its dataset
    description: str  # human-readable description of the element
    contents: str  # the element's actual data payload
13+
14+
15+
class DatasetMeta(BaseModel):
    """Metadata describing a dataset, without its elements.

    Returned by dataset-listing calls.
    """

    id: str  # unique dataset identifier assigned on creation
    name: str
    description: str
19+
20+
21+
class Dataset(BaseModel):
    """A dataset plus the metadata of the elements it contains."""

    id: str  # unique dataset identifier
    name: str
    description: str
    # Element metadata by key — presumably keyed by element name (matches the
    # name-based element lookups in the API); TODO confirm against the dataset tool.
    elements: Dict[str, DatasetElementMeta]

gptscript/gptscript.py

+81
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from gptscript.confirm import AuthResponse
99
from gptscript.credentials import Credential, to_credential
10+
from gptscript.datasets import DatasetMeta, Dataset, DatasetElementMeta, DatasetElement
1011
from gptscript.frame import RunFrame, CallFrame, PromptFrame, Program
1112
from gptscript.opts import GlobalOptions
1213
from gptscript.prompt import PromptResponse
@@ -210,6 +211,86 @@ async def delete_credential(self, context: str = "default", name: str = "") -> s
210211
{"context": [context], "name": name}
211212
)
212213

214+
async def list_datasets(self, workspace: str) -> List[DatasetMeta]:
    """Return metadata for every dataset in *workspace*.

    An empty *workspace* falls back to the GPTSCRIPT_WORKSPACE_DIR
    environment variable (KeyError if that is unset too).
    """
    if workspace == "":
        workspace = os.environ["GPTSCRIPT_WORKSPACE_DIR"]

    args = {
        "input": "{}",
        "workspace": workspace,
        "datasetToolRepo": self.opts.DatasetToolRepo,
    }
    raw = await self._run_basic_command("datasets", args)
    # The command returns a JSON array of dataset metadata objects.
    return [DatasetMeta.model_validate(entry) for entry in json.loads(raw)]
223+
224+
async def create_dataset(self, workspace: str, name: str, description: str = "") -> Dataset:
    """Create a new dataset named *name* in *workspace* and return it.

    An empty *workspace* falls back to $GPTSCRIPT_WORKSPACE_DIR.
    Raises ValueError when *name* is empty.
    """
    if workspace == "":
        workspace = os.environ["GPTSCRIPT_WORKSPACE_DIR"]

    if name == "":
        raise ValueError("name cannot be empty")

    payload = json.dumps({"datasetName": name, "datasetDescription": description})
    raw = await self._run_basic_command(
        "datasets/create",
        {
            "input": payload,
            "workspace": workspace,
            "datasetToolRepo": self.opts.DatasetToolRepo,
        },
    )
    return Dataset.model_validate_json(raw)
238+
239+
async def add_dataset_element(self, workspace: str, datasetID: str, elementName: str, elementContent: str,
                              elementDescription: str = "") -> DatasetElementMeta:
    """Add one element to the dataset *datasetID* and return its metadata.

    An empty *workspace* falls back to $GPTSCRIPT_WORKSPACE_DIR.
    Raises ValueError when datasetID, elementName, or elementContent is empty.
    """
    if workspace == "":
        workspace = os.environ["GPTSCRIPT_WORKSPACE_DIR"]

    # Guard clauses: each required argument must be non-empty.
    if datasetID == "":
        raise ValueError("datasetID cannot be empty")
    if elementName == "":
        raise ValueError("elementName cannot be empty")
    if elementContent == "":
        raise ValueError("elementContent cannot be empty")

    payload = json.dumps({
        "datasetID": datasetID,
        "elementName": elementName,
        "elementContent": elementContent,
        "elementDescription": elementDescription,
    })
    raw = await self._run_basic_command(
        "datasets/add-element",
        {
            "input": payload,
            "workspace": workspace,
            "datasetToolRepo": self.opts.DatasetToolRepo,
        },
    )
    return DatasetElementMeta.model_validate_json(raw)
261+
262+
async def list_dataset_elements(self, workspace: str, datasetID: str) -> List[DatasetElementMeta]:
    """Return metadata for every element stored in the dataset *datasetID*.

    An empty *workspace* falls back to $GPTSCRIPT_WORKSPACE_DIR.
    Raises ValueError when *datasetID* is empty.
    """
    if workspace == "":
        workspace = os.environ["GPTSCRIPT_WORKSPACE_DIR"]

    if datasetID == "":
        raise ValueError("datasetID cannot be empty")

    payload = json.dumps({"datasetID": datasetID})
    raw = await self._run_basic_command(
        "datasets/list-elements",
        {
            "input": payload,
            "workspace": workspace,
            "datasetToolRepo": self.opts.DatasetToolRepo,
        },
    )
    # The command returns a JSON array of element metadata objects.
    return [DatasetElementMeta.model_validate(entry) for entry in json.loads(raw)]
276+
277+
async def get_dataset_element(self, workspace: str, datasetID: str, elementName: str) -> DatasetElement:
    """Fetch one element (including its contents) from the dataset *datasetID*.

    An empty *workspace* falls back to $GPTSCRIPT_WORKSPACE_DIR.
    Raises ValueError when datasetID or elementName is empty.
    """
    if workspace == "":
        workspace = os.environ["GPTSCRIPT_WORKSPACE_DIR"]

    # Guard clauses: both identifiers must be non-empty.
    if datasetID == "":
        raise ValueError("datasetID cannot be empty")
    if elementName == "":
        raise ValueError("elementName cannot be empty")

    payload = json.dumps({"datasetID": datasetID, "element": elementName})
    raw = await self._run_basic_command(
        "datasets/get-element",
        {
            "input": payload,
            "workspace": workspace,
            "datasetToolRepo": self.opts.DatasetToolRepo,
        },
    )
    return DatasetElement.model_validate_json(raw)
293+
213294

214295
def _get_command():
215296
if os.getenv("GPTSCRIPT_BIN") is not None:

gptscript/opts.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def __init__(
1212
defaultModelProvider: str = "",
1313
defaultModel: str = "",
1414
cacheDir: str = "",
15+
datasetToolRepo: str = "",
1516
env: list[str] = None,
1617
):
1718
self.URL = url
@@ -21,6 +22,7 @@ def __init__(
2122
self.DefaultModel = defaultModel
2223
self.DefaultModelProvider = defaultModelProvider
2324
self.CacheDir = cacheDir
25+
self.DatasetToolRepo = datasetToolRepo
2426
if env is None:
2527
env = [f"{k}={v}" for k, v in os.environ.items()]
2628
elif isinstance(env, dict):
@@ -38,6 +40,7 @@ def merge(self, other: Self) -> Self:
3840
cp.DefaultModel = other.DefaultModel if other.DefaultModel != "" else self.DefaultModel
3941
cp.DefaultModelProvider = other.DefaultModelProvider if other.DefaultModelProvider != "" else self.DefaultModelProvider
4042
cp.CacheDir = other.CacheDir if other.CacheDir != "" else self.CacheDir
43+
cp.DatasetToolRepo = other.DatasetToolRepo if other.DatasetToolRepo != "" else self.DatasetToolRepo
4144
cp.Env = (other.Env or [])
4245
cp.Env.extend(self.Env or [])
4346
return cp
@@ -77,8 +80,9 @@ def __init__(self,
7780
defaultModelProvider: str = "",
7881
defaultModel: str = "",
7982
cacheDir: str = "",
83+
datasetToolDir: str = "",
8084
):
81-
super().__init__(url, token, apiKey, baseURL, defaultModelProvider, defaultModel, cacheDir, env)
85+
super().__init__(url, token, apiKey, baseURL, defaultModelProvider, defaultModel, cacheDir, datasetToolDir, env)
8286
self.input = input
8387
self.disableCache = disableCache
8488
self.subTool = subTool

requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ setuptools==69.1.1
1010
twine==5.0.0
1111
build==1.1.1
1212
httpx==0.27.0
13-
pywin32==306; sys_platform == 'win32'
13+
pydantic==2.9.2
14+
pywin32==306; sys_platform == 'win32'

tests/test_gptscript.py

+37
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import platform
66
import subprocess
7+
import tempfile
78
from datetime import datetime, timedelta, timezone
89
from time import sleep
910

@@ -755,3 +756,39 @@ async def test_credentials(gptscript):
755756

756757
res = await gptscript.delete_credential(name=name)
757758
assert not res.startswith("an error occurred"), "Unexpected error deleting credential: " + res
759+
760+
@pytest.mark.asyncio
async def test_datasets(gptscript):
    """End-to-end exercise of the dataset API against a temp workspace.

    Covers create_dataset, add_dataset_element, get_dataset_element,
    list_dataset_elements, and list_datasets, in that order, using a
    throwaway temporary directory as the workspace.
    """
    with tempfile.TemporaryDirectory(prefix="py-gptscript_") as tempdir:
        # Random hex name so repeated runs never collide.
        dataset_name = str(os.urandom(8).hex())

        # Create dataset
        dataset = await gptscript.create_dataset(tempdir, dataset_name, "this is a test dataset")
        assert dataset.id != "", "Expected dataset id to be set"
        assert dataset.name == dataset_name, "Expected dataset name to match"
        assert dataset.description == "this is a test dataset", "Expected dataset description to match"
        assert len(dataset.elements) == 0, "Expected dataset elements to be empty"

        # Add an element
        element_meta = await gptscript.add_dataset_element(tempdir, dataset.id, "element1", "element1 contents", "element1 description")
        assert element_meta.name == "element1", "Expected element name to match"
        assert element_meta.description == "element1 description", "Expected element description to match"

        # Get the element
        element = await gptscript.get_dataset_element(tempdir, dataset.id, "element1")
        assert element.name == "element1", "Expected element name to match"
        assert element.contents == "element1 contents", "Expected element contents to match"
        assert element.description == "element1 description", "Expected element description to match"

        # List elements in the dataset
        elements = await gptscript.list_dataset_elements(tempdir, dataset.id)
        assert len(elements) == 1, "Expected one element in the dataset"
        assert elements[0].name == "element1", "Expected element name to match"
        assert elements[0].description == "element1 description", "Expected element description to match"

        # List datasets
        datasets = await gptscript.list_datasets(tempdir)
        assert len(datasets) > 0, "Expected at least one dataset"
        assert datasets[0].id == dataset.id, "Expected dataset id to match"
        assert datasets[0].name == dataset_name, "Expected dataset name to match"
        assert datasets[0].description == "this is a test dataset", "Expected dataset description to match"

tox.ini

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ deps =
66
httpx
77
pytest
88
pytest-asyncio
9+
pydantic
910

1011
passenv =
1112
OPENAI_API_KEY

0 commit comments

Comments
 (0)