
Commit 0cebee3

chore: update for dataset rewrite (#66)
Signed-off-by: Grant Linville <[email protected]>
1 parent e9f3b2f commit 0cebee3

4 files changed, +68 −124

gptscript/datasets.py

+11 −17
@@ -1,37 +1,31 @@
 import base64
-from typing import Dict
 from pydantic import BaseModel, field_serializer, field_validator, BeforeValidator
 
 
+class DatasetMeta(BaseModel):
+    id: str
+    name: str
+    description: str
+
+
 class DatasetElementMeta(BaseModel):
     name: str
     description: str
 
 
 class DatasetElement(BaseModel):
     name: str
-    description: str
-    contents: bytes
+    description: str = ""
+    contents: str = ""
+    binaryContents: bytes = b""
 
-    @field_serializer("contents")
+    @field_serializer("binaryContents")
     def serialize_contents(self, value: bytes) -> str:
         return base64.b64encode(value).decode("utf-8")
 
-    @field_validator("contents", mode="before")
+    @field_validator("binaryContents", mode="before")
     def deserialize_contents(cls, value) -> bytes:
         if isinstance(value, str):
             return base64.b64decode(value)
         return value
 
-
-class DatasetMeta(BaseModel):
-    id: str
-    name: str
-    description: str
-
-
-class Dataset(BaseModel):
-    id: str
-    name: str
-    description: str
-    elements: Dict[str, DatasetElementMeta]
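
The practical effect of this model change: text payloads now travel in contents as plain strings, while binary payloads move to the new binaryContents field, which keeps the base64 wire encoding. A minimal round-trip sketch (illustrative, not part of the commit):

import base64

from gptscript.datasets import DatasetElement

# Text goes in `contents`; binary data goes in the new `binaryContents`.
elem = DatasetElement(name="logo", binaryContents=b"\x89PNG")

# model_dump() runs the field_serializer: bytes -> base64 string on the wire.
wire = elem.model_dump()
assert wire["binaryContents"] == base64.b64encode(b"\x89PNG").decode("utf-8")

# model_validate() runs the mode="before" validator: base64 string -> bytes.
assert DatasetElement.model_validate(wire).binaryContents == b"\x89PNG"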

gptscript/gptscript.py

+25 −74
@@ -8,7 +8,7 @@
 
 from gptscript.confirm import AuthResponse
 from gptscript.credentials import Credential, to_credential
-from gptscript.datasets import DatasetMeta, Dataset, DatasetElementMeta, DatasetElement
+from gptscript.datasets import DatasetElementMeta, DatasetElement, DatasetMeta
 from gptscript.fileinfo import FileInfo
 from gptscript.frame import RunFrame, CallFrame, PromptFrame, Program
 from gptscript.opts import GlobalOptions
@@ -213,109 +213,58 @@ async def delete_credential(self, context: str = "default", name: str = "") -> str:
             {"context": [context], "name": name}
         )
 
-    async def list_datasets(self, workspace_id: str) -> List[DatasetMeta]:
-        if workspace_id == "":
-            workspace_id = os.environ["GPTSCRIPT_WORKSPACE_ID"]
-
+    # list_datasets returns an array of dataset IDs
+    async def list_datasets(self) -> List[DatasetMeta]:
         res = await self._run_basic_command(
             "datasets",
-            {"input": "{}", "workspaceID": workspace_id, "datasetToolRepo": self.opts.DatasetToolRepo,
-             "env": self.opts.Env}
-        )
-        return [DatasetMeta.model_validate(d) for d in json.loads(res)]
-
-    async def create_dataset(self, workspace_id: str, name: str, description: str = "") -> Dataset:
-        if workspace_id == "":
-            workspace_id = os.environ["GPTSCRIPT_WORKSPACE_ID"]
-
-        if name == "":
-            raise ValueError("name cannot be empty")
-
-        res = await self._run_basic_command(
-            "datasets/create",
             {
-                "input": json.dumps({"datasetName": name, "datasetDescription": description}),
-                "workspaceID": workspace_id,
-                "datasetToolRepo": self.opts.DatasetToolRepo,
-                "env": self.opts.Env,
-            }
-        )
-        return Dataset.model_validate_json(res)
-
-    async def add_dataset_element(self, workspace_id: str, datasetID: str, elementName: str, elementContent: bytes,
-                                  elementDescription: str = "") -> DatasetElementMeta:
-        if workspace_id == "":
-            workspace_id = os.environ["GPTSCRIPT_WORKSPACE_ID"]
-
-        if datasetID == "":
-            raise ValueError("datasetID cannot be empty")
-        elif elementName == "":
-            raise ValueError("elementName cannot be empty")
-        elif not elementContent:
-            raise ValueError("elementContent cannot be empty")
-
-        res = await self._run_basic_command(
-            "datasets/add-element",
-            {
-                "input": json.dumps({
-                    "datasetID": datasetID,
-                    "elementName": elementName,
-                    "elementContent": base64.b64encode(elementContent).decode("utf-8"),
-                    "elementDescription": elementDescription,
-                }),
-                "workspaceID": workspace_id,
-                "datasetToolRepo": self.opts.DatasetToolRepo,
+                "input": "{}",
+                "datasetTool": self.opts.DatasetTool,
                 "env": self.opts.Env
             }
         )
-        return DatasetElementMeta.model_validate_json(res)
-
-    async def add_dataset_elements(self, workspace_id: str, datasetID: str, elements: List[DatasetElement]) -> str:
-        if workspace_id == "":
-            workspace_id = os.environ["GPTSCRIPT_WORKSPACE_ID"]
+        return [DatasetMeta.model_validate(d) for d in json.loads(res)]
 
-        if datasetID == "":
-            raise ValueError("datasetID cannot be empty")
-        elif not elements:
+    async def add_dataset_elements(
+            self,
+            elements: List[DatasetElement],
+            datasetID: str = "",
+            name: str = "",
+            description: str = ""
+    ) -> str:
+        if not elements:
             raise ValueError("elements cannot be empty")
 
         res = await self._run_basic_command(
             "datasets/add-elements",
             {
                 "input": json.dumps({
                     "datasetID": datasetID,
+                    "name": name,
+                    "description": description,
                     "elements": [element.model_dump() for element in elements],
                 }),
-                "workspaceID": workspace_id,
-                "datasetToolRepo": self.opts.DatasetToolRepo,
+                "datasetTool": self.opts.DatasetTool,
                 "env": self.opts.Env
             }
         )
         return res
 
-
-    async def list_dataset_elements(self, workspace_id: str, datasetID: str) -> List[DatasetElementMeta]:
-        if workspace_id == "":
-            workspace_id = os.environ["GPTSCRIPT_WORKSPACE_ID"]
-
+    async def list_dataset_elements(self, datasetID: str) -> List[DatasetElementMeta]:
         if datasetID == "":
             raise ValueError("datasetID cannot be empty")
 
         res = await self._run_basic_command(
             "datasets/list-elements",
             {
                 "input": json.dumps({"datasetID": datasetID}),
-                "workspaceID": workspace_id,
-                "datasetToolRepo": self.opts.DatasetToolRepo,
+                "datasetTool": self.opts.DatasetTool,
                 "env": self.opts.Env
             }
         )
         return [DatasetElementMeta.model_validate(d) for d in json.loads(res)]
 
-    async def get_dataset_element(self, workspace_id: str, datasetID: str, elementName: str) -> DatasetElement:
-        if workspace_id == "":
-            workspace_id = os.environ["GPTSCRIPT_WORKSPACE_ID"]
-
+    async def get_dataset_element(self, datasetID: str, elementName: str) -> DatasetElement:
         if datasetID == "":
             raise ValueError("datasetID cannot be empty")
         elif elementName == "":
@@ -324,9 +273,11 @@ async def get_dataset_element(self, workspace_id: str, datasetID: str, elementName: str) -> DatasetElement:
         res = await self._run_basic_command(
             "datasets/get-element",
             {
-                "input": json.dumps({"datasetID": datasetID, "element": elementName}),
-                "workspaceID": workspace_id,
-                "datasetToolRepo": self.opts.DatasetToolRepo,
+                "input": json.dumps({
+                    "datasetID": datasetID,
+                    "name": elementName,
+                }),
+                "datasetTool": self.opts.DatasetTool,
                 "env": self.opts.Env,
             }
         )
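
The net effect of the rewrite: every dataset method drops the workspace_id parameter (the workspace is now resolved from the client's environment), create_dataset and add_dataset_element are folded into add_dataset_elements, and the tool is addressed via opts.DatasetTool. A hedged usage sketch of the new surface; it assumes GPTSCRIPT_WORKSPACE_ID is already exported, as the test below arranges:

import asyncio
import os

from gptscript.datasets import DatasetElement
from gptscript.gptscript import GPTScript
from gptscript.opts import GlobalOptions

async def main():
    # The client picks the workspace up from its env snapshot;
    # GPTSCRIPT_WORKSPACE_ID is assumed to be set already.
    g = GPTScript(GlobalOptions(env=[f"{k}={v}" for k, v in os.environ.items()]))
    try:
        # With no datasetID, the tool creates a new dataset and returns its ID.
        dataset_id = await g.add_dataset_elements(
            [DatasetElement(name="greeting", contents="hello world")],
            name="demo", description="scratch dataset",
        )
        element = await g.get_dataset_element(dataset_id, "greeting")
        print(element.contents)  # "hello world"
    finally:
        g.close()  # shut down the underlying gptscript process

asyncio.run(main())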

gptscript/opts.py

+3 −3
@@ -12,7 +12,7 @@ def __init__(
             defaultModelProvider: str = "",
             defaultModel: str = "",
             cacheDir: str = "",
-            datasetToolRepo: str = "",
+            datasetTool: str = "",
             workspaceTool: str = "",
             env: list[str] = None,
     ):
@@ -23,7 +23,7 @@
         self.DefaultModel = defaultModel
         self.DefaultModelProvider = defaultModelProvider
         self.CacheDir = cacheDir
-        self.DatasetToolRepo = datasetToolRepo
+        self.DatasetTool = datasetTool
         self.WorkspaceTool = workspaceTool
         if env is None:
             env = [f"{k}={v}" for k, v in os.environ.items()]
@@ -42,7 +42,7 @@ def merge(self, other: Self) -> Self:
         cp.DefaultModel = other.DefaultModel if other.DefaultModel != "" else self.DefaultModel
         cp.DefaultModelProvider = other.DefaultModelProvider if other.DefaultModelProvider != "" else self.DefaultModelProvider
         cp.CacheDir = other.CacheDir if other.CacheDir != "" else self.CacheDir
-        cp.DatasetToolRepo = other.DatasetToolRepo if other.DatasetToolRepo != "" else self.DatasetToolRepo
+        cp.DatasetTool = other.DatasetTool if other.DatasetTool != "" else self.DatasetTool
         cp.WorkspaceTool = other.WorkspaceTool if other.WorkspaceTool != "" else self.WorkspaceTool
         cp.Env = (other.Env or [])
         cp.Env.extend(self.Env or [])
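
Because the option was renamed rather than aliased, code that still passes datasetToolRepo= will raise a TypeError; merge() keeps its usual rule that the other side's non-empty value wins, otherwise self's is kept. A small sketch, with a placeholder tool reference (not a value from the commit):

from gptscript.opts import GlobalOptions

base = GlobalOptions(datasetTool="github.com/example/my-dataset-tool")  # placeholder ref
override = GlobalOptions()  # leaves DatasetTool empty

# merge() prefers the other side's value only when it is non-empty,
# so the base tool reference survives here.
merged = base.merge(override)
assert merged.DatasetTool == "github.com/example/my-dataset-tool"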

tests/test_gptscript.py

+29 −30
@@ -760,60 +760,59 @@ async def test_credentials(gptscript):
 
 @pytest.mark.asyncio
 async def test_datasets(gptscript):
-    workspace_id = await gptscript.create_workspace("directory")
-    dataset_name = str(os.urandom(8).hex())
+    os.environ["GPTSCRIPT_WORKSPACE_ID"] = await gptscript.create_workspace("directory")
+
+    new_client = GPTScript(GlobalOptions(
+        apiKey=os.getenv("OPENAI_API_KEY"),
+        env=[f"{k}={v}" for k, v in os.environ.items()],
+    ))
 
     # Create dataset
-    dataset = await gptscript.create_dataset(workspace_id, dataset_name, "this is a test dataset")
-    assert dataset.id != "", "Expected dataset id to be set"
-    assert dataset.name == dataset_name, "Expected dataset name to match"
-    assert dataset.description == "this is a test dataset", "Expected dataset description to match"
-    assert len(dataset.elements) == 0, "Expected dataset elements to be empty"
-
-    # Add an element
-    element_meta = await gptscript.add_dataset_element(workspace_id, dataset.id, "element1", b"element1 contents",
-                                                       "element1 description")
-    assert element_meta.name == "element1", "Expected element name to match"
-    assert element_meta.description == "element1 description", "Expected element description to match"
+    dataset_id = await new_client.add_dataset_elements([
+        DatasetElement(name="element1", contents="element1 contents", description="element1 description"),
+        DatasetElement(name="element2", binaryContents=b"element2 contents", description="element2 description"),
+    ], name="test-dataset", description="test dataset description")
 
     # Add two more elements
-    await gptscript.add_dataset_elements(workspace_id, dataset.id, [
-        DatasetElement(name="element2", contents=b"element2 contents", description="element2 description"),
-        DatasetElement(name="element3", contents=b"element3 contents", description="element3 description"),
-    ])
+    await new_client.add_dataset_elements([
+        DatasetElement(name="element3", contents="element3 contents", description="element3 description"),
+        DatasetElement(name="element4", contents="element3 contents", description="element4 description"),
+    ], datasetID=dataset_id)
 
     # Get the elements
-    e1 = await gptscript.get_dataset_element(workspace_id, dataset.id, "element1")
+    e1 = await new_client.get_dataset_element(dataset_id, "element1")
     assert e1.name == "element1", "Expected element name to match"
-    assert e1.contents == b"element1 contents", "Expected element contents to match"
+    assert e1.contents == "element1 contents", "Expected element contents to match"
     assert e1.description == "element1 description", "Expected element description to match"
-    e2 = await gptscript.get_dataset_element(workspace_id, dataset.id, "element2")
+    e2 = await new_client.get_dataset_element(dataset_id, "element2")
     assert e2.name == "element2", "Expected element name to match"
-    assert e2.contents == b"element2 contents", "Expected element contents to match"
+    assert e2.binaryContents == b"element2 contents", "Expected element contents to match"
     assert e2.description == "element2 description", "Expected element description to match"
-    e3 = await gptscript.get_dataset_element(workspace_id, dataset.id, "element3")
+    e3 = await new_client.get_dataset_element(dataset_id, "element3")
     assert e3.name == "element3", "Expected element name to match"
-    assert e3.contents == b"element3 contents", "Expected element contents to match"
+    assert e3.contents == "element3 contents", "Expected element contents to match"
     assert e3.description == "element3 description", "Expected element description to match"
 
     # List elements in the dataset
-    elements = await gptscript.list_dataset_elements(workspace_id, dataset.id)
-    assert len(elements) == 3, "Expected one element in the dataset"
+    elements = await new_client.list_dataset_elements(dataset_id)
+    assert len(elements) == 4, "Expected four elements in the dataset"
     assert elements[0].name == "element1", "Expected element name to match"
     assert elements[0].description == "element1 description", "Expected element description to match"
     assert elements[1].name == "element2", "Expected element name to match"
     assert elements[1].description == "element2 description", "Expected element description to match"
     assert elements[2].name == "element3", "Expected element name to match"
     assert elements[2].description == "element3 description", "Expected element description to match"
+    assert elements[3].name == "element4", "Expected element name to match"
+    assert elements[3].description == "element4 description", "Expected element description to match"
 
     # List datasets
-    datasets = await gptscript.list_datasets(workspace_id)
+    datasets = await new_client.list_datasets()
     assert len(datasets) > 0, "Expected at least one dataset"
-    assert datasets[0].id == dataset.id, "Expected dataset id to match"
-    assert datasets[0].name == dataset_name, "Expected dataset name to match"
-    assert datasets[0].description == "this is a test dataset", "Expected dataset description to match"
+    assert datasets[0].id == dataset_id, "Expected dataset id to match"
+    assert datasets[0].name == "test-dataset", "Expected dataset name to match"
+    assert datasets[0].description == "test dataset description", "Expected dataset description to match"
 
-    await gptscript.delete_workspace(workspace_id)
+    await gptscript.delete_workspace(os.environ["GPTSCRIPT_WORKSPACE_ID"])
 
 
 @pytest.mark.asyncio
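
For callers migrating existing code, the updated test also demonstrates the replacement for the old workspace_id parameter: export the workspace ID into the process environment, then build a fresh client whose env snapshot includes it. Extracted as a sketch (the helper name is mine, not from the commit):

import os

from gptscript.gptscript import GPTScript
from gptscript.opts import GlobalOptions

async def dataset_client(gptscript: GPTScript) -> GPTScript:
    # The dataset tool reads GPTSCRIPT_WORKSPACE_ID from the client env,
    # replacing the old per-call workspace_id parameter.
    os.environ["GPTSCRIPT_WORKSPACE_ID"] = await gptscript.create_workspace("directory")
    return GPTScript(GlobalOptions(env=[f"{k}={v}" for k, v in os.environ.items()]))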
