Skip to content

Commit 57cb8f3

Browse files
committed
doc: improve datastream and merge description
1 parent 282da17 commit 57cb8f3

File tree

2 files changed

+33
-22
lines changed

2 files changed

+33
-22
lines changed

datastream/dataset.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class Dataset(BaseModel, Generic[T]):
3535
... cost * 2,
3636
... ))
3737
... )
38-
>>> print(dataset[2])
38+
>>> dataset[2]
3939
('banana', 28)
4040
'''
4141

@@ -51,10 +51,11 @@ class Config:
5151
def from_subscriptable(subscriptable) -> Dataset:
5252
'''
5353
Create ``Dataset`` based on a subscriptable, i.e. an object that implements
54-
``__getitem__`` and ``__len__``. Should only be used for simple
55-
examples as a ``Dataset`` created with this method does not support
56-
methods that require a source dataframe (i.e. :func:`Dataset.split`
57-
and :func:`Dataset.subset`)
54+
``__getitem__`` and ``__len__``.
55+
56+
Should only be used for simple examples as a ``Dataset`` created with
57+
this method does not support methods that require a source dataframe
58+
like :func:`Dataset.split` and :func:`Dataset.subset`.
5859
'''
5960

6061
return (
@@ -328,7 +329,6 @@ def group_split(
328329
).items()
329330
}
330331

331-
332332
def with_columns(
333333
self: Dataset[T], **kwargs: Callable[pd.Dataframe, pd.Series]
334334
) -> Dataset[T]:
@@ -405,8 +405,11 @@ def to_concat(dataset_index, inner_index):
405405
def concat(datasets: List[Dataset]) -> Dataset[R]:
406406
'''
407407
Concatenate multiple datasets together so that they behave like a
408-
single dataset. Consider using :func:`Datastream.merge` if you have
409-
multiple data sources.
408+
single dataset.
409+
410+
Consider using :func:`Datastream.merge` instead if you have
411+
multiple data sources, as it allows you to control the number
412+
of samples from each source in the training batches.
410413
'''
411414
from_concat_mapping = Dataset.create_from_concat_mapping(datasets)
412415

@@ -440,6 +443,7 @@ def from_combine(index):
440443
@staticmethod
441444
def create_to_combine_mapping(datasets):
442445
cumprod_lengths = np.cumprod(list(map(len, datasets)))
446+
443447
def to_concat(inner_indices):
444448
return inner_indices[0] + sum([
445449
inner_index * cumprod_lengths[i]
@@ -453,7 +457,7 @@ def combine(datasets: List[Dataset]) -> Dataset[Tuple]:
453457
Zip multiple datasets together so that all combinations of examples
454458
are possible (i.e. the product) creating tuples like
455459
``(example1, example2, ...)``.
456-
460+
457461
The created dataset will not have a dataframe because combined
458462
datasets are often very long and it is expensive to enumerate them.
459463
'''

datastream/datastream.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
class Datastream(BaseModel, Generic[T]):
3232
'''
3333
``Datastream[T]`` combines a ``Dataset[T]`` and a sampler into a stream of
34-
examples. By default the samples are drawn without replacement until the
34+
examples.
35+
36+
By default the samples are drawn without replacement until the
3537
full dataset is exhausted. The proportion of the dataset that should be
3638
drawn before allowing replacement can be changed with
3739
:func:`Datastream.sample_proportion`.
@@ -70,7 +72,7 @@ def __init__(
7072

7173
def __len__(self):
7274
return len(self.sampler)
73-
75+
7476
def __iter__(self):
7577
return map(self.dataset.__getitem__, iter(self.sampler))
7678

@@ -80,17 +82,22 @@ def merge(datastreams_and_ns: Tuple[Union[
8082
Tuple[Datastream[T], int]
8183
], ...]) -> Datastream[T]:
8284
'''
83-
Merge multiple datastreams by interleaving them. Optionally you can
84-
define different lengths per ``Datastream``.
85-
86-
.. highlight:: python
87-
.. code-block:: python
88-
89-
Datastream.merge([
90-
(datastream1, 2),
91-
(datastream2, 1),
92-
(datastream3, 1),
93-
])
85+
Creates a merged datastream where samples are drawn one at a time from
86+
each underlying datastream (also known as "interleave").
87+
88+
Optionally you can define the number of drawn samples per
89+
``Datastream``.
90+
91+
>>> datastream1 = Datastream(Dataset.from_subscriptable([1, 1]))
92+
>>> datastream2 = Datastream(Dataset.from_subscriptable([2, 2]))
93+
>>> datastream3 = Datastream(Dataset.from_subscriptable([3, 3, 3, 3]))
94+
>>> merged_datastream = Datastream.merge([
95+
... (datastream1, 1),
96+
... (datastream2, 1),
97+
... (datastream3, 2),
98+
... ])
99+
>>> list(merged_datastream)
100+
[1, 2, 3, 3, 1, 2, 3, 3]
94101
'''
95102
datastreams_and_ns = [
96103
x if type(x) is tuple else (x, 1)

0 commit comments

Comments
 (0)