
Commit bd52a96

ilongin, rlamy, and ivan
authored
Refactor DataChain.from_storage() to use new listing generator (iterative#294)
* added generator to list a bucket
* added test for async to sync generator and fixing listing
* fix tests
* fix comments
* first version of from_storage without deprecated listing
* first version of from_storage without deprecated listing
* fixing tests and removing prints, refactoring
* fix listing generator output type
* fix linter
* fix docs
* fixing test
* fix list bucket args
* refactoring listing static methods
* fixing non recursive queries
* refactoring
* fixing listing generator tests
* added partial test
* using ctc in test session
* moved listing functions to separated file
* added listing unit tests
* fixing json
* fixing examples
* fix file signal type from storage
* fixing example
* refactoring ls function
* added more tests and fixed comments
* fixing test
* fix test name
* fixing windows tests
* returning to all tests
* removed constants from dc.py
* added ticket number
* couple of fixes from PR review
* added new method is_dataset_listing and assertions
* refactoring listing code
* added session on cloud test catalog and refactoring tests
* added uses glob util
* extracted partial with update to separate test
* returning Column from db_signals method
* import directly from datachain
* changed boolean functions with prefix is_
* removed kwargs from from_storage
* removed kwargs from datasets method
* refactoring parsing listing dataset name
* Update src/datachain/lib/file.py (Co-authored-by: Ronan Lamy <[email protected]>)
* removed client config
* removed kwargs from from_records
* fixing comment
* fixing new test
* fixing listing

---------

Co-authored-by: Ronan Lamy <[email protected]>
Co-authored-by: ivan <[email protected]>
1 parent 12ddf7b commit bd52a96

22 files changed, +501 −113 lines changed
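
Before the per-file diffs, a minimal usage sketch of the entry point as refactored in this commit: the first argument is now named `uri` (the old `path` keyword and the `**kwargs` pass-through are gone), and an explicit `anon` flag replaces passing a client config for public buckets. The bucket URI is the one used in the updated examples; the snippet is illustrative only.

```python
from datachain.lib.dc import DataChain

# Refactored from_storage() signature: URI first, keyword-only flags, no **kwargs.
chain = DataChain.from_storage(
    "gs://datachain-demo/dogs-and-cats/",  # bucket from the updated examples
    type="binary",    # or "text" / "image"; resolved to a File subclass via get_file_type()
    recursive=True,   # list the prefix recursively
    update=False,     # when False, a fresh cached listing dataset is reused if available
    anon=True,        # treat the cloud bucket as a public one
)
```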

examples/get_started/udfs/parallel.py

+1 −1

@@ -31,7 +31,7 @@ def path_len_benchmark(path):

     # Run in chain
     DataChain.from_storage(
-        path="gs://datachain-demo/dogs-and-cats/",
+        "gs://datachain-demo/dogs-and-cats/",
     ).settings(parallel=-1).map(
         path_len_benchmark,
         params=["file.path"],

examples/get_started/udfs/simple.py

+1 −1

@@ -11,7 +11,7 @@ def path_len(path):
 if __name__ == "__main__":
     # Run in chain
     DataChain.from_storage(
-        path="gs://datachain-demo/dogs-and-cats/",
+        uri="gs://datachain-demo/dogs-and-cats/",
     ).map(
         path_len,
         params=["file.path"],

src/datachain/catalog/catalog.py

+4 −3

@@ -1422,17 +1422,18 @@ def get_dataset_dependencies(

         return direct_dependencies

-    def ls_datasets(self) -> Iterator[DatasetRecord]:
+    def ls_datasets(self, include_listing: bool = False) -> Iterator[DatasetRecord]:
         datasets = self.metastore.list_datasets()
         for d in datasets:
-            if not d.is_bucket_listing:
+            if not d.is_bucket_listing or include_listing:
                 yield d

     def list_datasets_versions(
         self,
+        include_listing: bool = False,
     ) -> Iterator[tuple[DatasetRecord, "DatasetVersion", Optional["Job"]]]:
         """Iterate over all dataset versions with related jobs."""
-        datasets = list(self.ls_datasets())
+        datasets = list(self.ls_datasets(include_listing=include_listing))

         # preselect dataset versions jobs from db to avoid multiple queries
         jobs_ids: set[str] = {
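
The new flag only relaxes the filter shown above. As a tiny illustration of the rule (a standalone predicate, not part of the real `Catalog` API): listing datasets stay hidden unless the caller opts in.

```python
# Illustrative restatement of the ls_datasets() condition above.
def should_yield(is_bucket_listing: bool, include_listing: bool) -> bool:
    return not is_bucket_listing or include_listing

assert should_yield(False, include_listing=False)       # regular dataset: always yielded
assert not should_yield(True, include_listing=False)    # listing dataset: hidden by default
assert should_yield(True, include_listing=True)         # listing dataset: yielded on request
```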

src/datachain/dataset.py

+6 −1

@@ -25,6 +25,7 @@

 DATASET_PREFIX = "ds://"
 QUERY_DATASET_PREFIX = "ds_query_"
+LISTING_PREFIX = "lst__"


 def parse_dataset_uri(uri: str) -> tuple[str, Optional[int]]:
@@ -443,7 +444,11 @@ def is_bucket_listing(self) -> bool:
         For bucket listing we implicitly create underlying dataset to hold data. This
         method is checking if this is one of those datasets.
         """
-        return Client.is_data_source_uri(self.name)
+        # TODO refactor and maybe remove method in
+        # https://github.com/iterative/datachain/issues/318
+        return Client.is_data_source_uri(self.name) or self.name.startswith(
+            LISTING_PREFIX
+        )

     @property
     def versions_values(self) -> list[int]:
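
With this change a dataset counts as a bucket listing either when its name is itself a data-source URI (the old rule) or when it carries the new internal `lst__` prefix. A hedged sketch of the check as a standalone helper (not the real `DatasetRecord.is_bucket_listing`; the storage prefixes are the ones listed in the `from_storage` docstring elsewhere in this commit, and the exact `lst__...` name format produced by `parse_listing_uri()` is not shown in this hunk):

```python
LISTING_PREFIX = "lst__"
STORAGE_PREFIXES = ("s3://", "gs://", "az://", "file://")  # assumed set, per the from_storage docstring

def looks_like_bucket_listing(name: str) -> bool:
    # Old rule: the dataset name is a data-source URI.
    # New rule: the dataset is an internal listing dataset named with the lst__ prefix.
    return name.startswith(STORAGE_PREFIXES) or name.startswith(LISTING_PREFIX)

print(looks_like_bucket_listing("gs://datachain-demo/dogs-and-cats/"))  # True (old rule)
print(looks_like_bucket_listing("lst__some-listing-dataset"))           # True (new rule, hypothetical name)
print(looks_like_bucket_listing("my_cats"))                             # False (regular user dataset)
```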

src/datachain/lib/dc.py

+77 −25

@@ -27,7 +27,15 @@
 from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
 from datachain.lib.dataset_info import DatasetInfo
 from datachain.lib.file import ExportPlacement as FileExportPlacement
-from datachain.lib.file import File, IndexedFile, get_file
+from datachain.lib.file import File, IndexedFile, get_file_type
+from datachain.lib.listing import (
+    is_listing_dataset,
+    is_listing_expired,
+    is_listing_subset,
+    list_bucket,
+    ls,
+    parse_listing_uri,
+)
 from datachain.lib.meta_formats import read_meta, read_schema
 from datachain.lib.model_store import ModelStore
 from datachain.lib.settings import Settings
@@ -311,7 +319,7 @@ def add_schema(self, signals_schema: SignalSchema) -> "Self":  # noqa: D102
     @classmethod
     def from_storage(
         cls,
-        path,
+        uri,
         *,
         type: Literal["binary", "text", "image"] = "binary",
         session: Optional[Session] = None,
@@ -320,41 +328,79 @@ def from_storage(
         recursive: Optional[bool] = True,
         object_name: str = "file",
         update: bool = False,
-        **kwargs,
+        anon: bool = False,
     ) -> "Self":
         """Get data from a storage as a list of file with all file attributes.
         It returns the chain itself as usual.

         Parameters:
-            path : storage URI with directory. URI must start with storage prefix such
+            uri : storage URI with directory. URI must start with storage prefix such
                 as `s3://`, `gs://`, `az://` or "file:///"
             type : read file as "binary", "text", or "image" data. Default is "binary".
             recursive : search recursively for the given path.
             object_name : Created object column name.
             update : force storage reindexing. Default is False.
+            anon : If True, we will treat cloud bucket as public one

         Example:
             ```py
             chain = DataChain.from_storage("s3://my-bucket/my-dir")
             ```
         """
-        func = get_file(type)
-        return (
-            cls(
-                path,
-                session=session,
-                settings=settings,
-                recursive=recursive,
-                update=update,
-                in_memory=in_memory,
-                **kwargs,
-            )
-            .map(**{object_name: func})
-            .select(object_name)
+        file_type = get_file_type(type)
+
+        if anon:
+            client_config = {"anon": True}
+        else:
+            client_config = None
+
+        session = Session.get(session, client_config=client_config, in_memory=in_memory)
+
+        list_dataset_name, list_uri, list_path = parse_listing_uri(
+            uri, session.catalog.cache, session.catalog.client_config
         )
+        need_listing = True
+
+        for ds in cls.datasets(
+            session=session, in_memory=in_memory, include_listing=True
+        ).collect("dataset"):
+            if (
+                not is_listing_expired(ds.created_at)  # type: ignore[union-attr]
+                and is_listing_dataset(ds.name)  # type: ignore[union-attr]
+                and is_listing_subset(ds.name, list_dataset_name)  # type: ignore[union-attr]
+                and not update
+            ):
+                need_listing = False
+                list_dataset_name = ds.name  # type: ignore[union-attr]
+
+        if need_listing:
+            # caching new listing to special listing dataset
+            (
+                cls.from_records(
+                    DataChain.DEFAULT_FILE_RECORD,
+                    session=session,
+                    settings=settings,
+                    in_memory=in_memory,
+                )
+                .gen(
+                    list_bucket(list_uri, client_config=session.catalog.client_config),
+                    output={f"{object_name}": File},
+                )
+                .save(list_dataset_name, listing=True)
+            )
+
+        dc = cls.from_dataset(list_dataset_name, session=session)
+        dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
+
+        return ls(dc, list_path, recursive=recursive, object_name=object_name)

     @classmethod
-    def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain":
+    def from_dataset(
+        cls,
+        name: str,
+        version: Optional[int] = None,
+        session: Optional[Session] = None,
+    ) -> "DataChain":
         """Get data from a saved Dataset. It returns the chain itself.

         Parameters:
@@ -366,7 +412,7 @@ def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain":
             chain = DataChain.from_dataset("my_cats")
             ```
         """
-        return DataChain(name=name, version=version)
+        return DataChain(name=name, version=version, session=session)

     @classmethod
     def from_json(
@@ -419,7 +465,7 @@ def jmespath_to_name(s: str):
         object_name = jmespath_to_name(jmespath)
         if not object_name:
             object_name = meta_type
-        chain = DataChain.from_storage(path=path, type=type, **kwargs)
+        chain = DataChain.from_storage(uri=path, type=type, **kwargs)
         signal_dict = {
             object_name: read_meta(
                 schema_from=schema_from,
@@ -479,7 +525,7 @@ def jmespath_to_name(s: str):
         object_name = jmespath_to_name(jmespath)
         if not object_name:
             object_name = meta_type
-        chain = DataChain.from_storage(path=path, type=type, **kwargs)
+        chain = DataChain.from_storage(uri=path, type=type, **kwargs)
         signal_dict = {
             object_name: read_meta(
                 schema_from=schema_from,
@@ -500,6 +546,7 @@ def datasets(
         settings: Optional[dict] = None,
         in_memory: bool = False,
         object_name: str = "dataset",
+        include_listing: bool = False,
     ) -> "DataChain":
         """Generate chain with list of registered datasets.

@@ -517,7 +564,9 @@ def datasets(

         datasets = [
             DatasetInfo.from_models(d, v, j)
-            for d, v, j in catalog.list_datasets_versions()
+            for d, v, j in catalog.list_datasets_versions(
+                include_listing=include_listing
+            )
         ]

         return cls.from_values(
@@ -570,7 +619,7 @@ def print_jsonl_schema(  # type: ignore[override]
         )

     def save(  # type: ignore[override]
-        self, name: Optional[str] = None, version: Optional[int] = None
+        self, name: Optional[str] = None, version: Optional[int] = None, **kwargs
     ) -> "Self":
         """Save to a Dataset. It returns the chain itself.

@@ -580,7 +629,7 @@ def save(  # type: ignore[override]
             version : version of a dataset. Default - the last version that exist.
         """
         schema = self.signals_schema.clone_without_sys_signals().serialize()
-        return super().save(name=name, version=version, feature_schema=schema)
+        return super().save(name=name, version=version, feature_schema=schema, **kwargs)

     def apply(self, func, *args, **kwargs):
         """Apply any function to the chain.
@@ -1665,7 +1714,10 @@ def from_records(

         if schema:
             signal_schema = SignalSchema(schema)
-            columns = signal_schema.db_signals(as_columns=True)  # type: ignore[assignment]
+            columns = [
+                sqlalchemy.Column(c.name, c.type)  # type: ignore[union-attr]
+                for c in signal_schema.db_signals(as_columns=True)  # type: ignore[assignment]
+            ]
         else:
             columns = [
                 sqlalchemy.Column(name, typ)
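
The heart of the refactor is the reuse check inside `from_storage()`: an existing listing dataset is reused only when it has not expired, is recognized as a listing dataset, covers the requested listing, and no forced `update` was requested; otherwise the bucket is re-listed with `list_bucket()` and saved via `save(..., listing=True)`. A condensed, hedged restatement of that decision (the helper below is illustrative; `datasets` stands in for the objects yielded by `cls.datasets(..., include_listing=True)`):

```python
from typing import Iterable, Optional

# These helpers are imported in the dc.py hunk above from datachain.lib.listing.
from datachain.lib.listing import is_listing_dataset, is_listing_expired, is_listing_subset


def pick_cached_listing(datasets: Iterable, requested_name: str, update: bool) -> Optional[str]:
    """Return the name of a reusable cached listing dataset, or None if re-listing is needed."""
    if update:
        return None  # forced reindexing always re-lists
    for ds in datasets:  # DatasetInfo-like objects with .name and .created_at
        if (
            not is_listing_expired(ds.created_at)
            and is_listing_dataset(ds.name)
            and is_listing_subset(ds.name, requested_name)
        ):
            return ds.name
    return None
```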

src/datachain/lib/file.py

+10 −33

@@ -349,39 +349,6 @@ def save(self, destination: str):
         self.read().save(destination)


-def get_file(type_: Literal["binary", "text", "image"] = "binary"):
-    file: type[File] = File
-    if type_ == "text":
-        file = TextFile
-    elif type_ == "image":
-        file = ImageFile  # type: ignore[assignment]
-
-    def get_file_type(
-        source: str,
-        path: str,
-        size: int,
-        version: str,
-        etag: str,
-        is_latest: bool,
-        last_modified: datetime,
-        location: Optional[Union[dict, list[dict]]],
-        vtype: str,
-    ) -> file:  # type: ignore[valid-type]
-        return file(
-            source=source,
-            path=path,
-            size=size,
-            version=version,
-            etag=etag,
-            is_latest=is_latest,
-            last_modified=last_modified,
-            location=location,
-            vtype=vtype,
-        )
-
-    return get_file_type
-
-
 class IndexedFile(DataModel):
     """Metadata indexed from tabular files.

@@ -390,3 +357,13 @@ class IndexedFile(DataModel):

     file: File
     index: int
+
+
+def get_file_type(type_: Literal["binary", "text", "image"] = "binary") -> type[File]:
+    file: type[File] = File
+    if type_ == "text":
+        file = TextFile
+    elif type_ == "image":
+        file = ImageFile  # type: ignore[assignment]
+
+    return file
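
After the move, `get_file_type()` simply maps the requested `type` to a `File` subclass; `from_storage()` then mutates the signal schema with that class instead of wrapping a row-level constructor. A quick usage sketch, with imports assumed to resolve from `datachain.lib.file` since the hunk above references all three classes there:

```python
from datachain.lib.file import File, ImageFile, TextFile, get_file_type

assert get_file_type() is File            # default is "binary"
assert get_file_type("text") is TextFile
assert get_file_type("image") is ImageFile
```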
