Skip to content

Commit 282da17

Browse files
committed
fix: incorrect merge sampler length
1 parent 9b1589f commit 282da17

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

datastream/samplers/merge_sampler.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, samplers, datasets, ns):
2626
samplers=samplers,
2727
datasets=datasets,
2828
ns=ns,
29-
length=MergeSampler.merged_samplers_length(samplers),
29+
length=MergeSampler.merged_samplers_length(samplers, ns),
3030
from_mapping=Dataset.create_from_concat_mapping(datasets),
3131
merged_samplers=MergeSampler.merge_samplers(
3232
samplers, datasets, ns
@@ -40,8 +40,11 @@ def __iter__(self):
4040
return islice(self.merged_samplers, self.length)
4141

4242
@staticmethod
43-
def merged_samplers_length(samplers):
44-
return max([len(sampler) for sampler in samplers])
43+
def merged_samplers_length(samplers, ns):
44+
return (
45+
min([len(sampler) / n for sampler, n in zip(samplers, ns)])
46+
* sum(ns)
47+
)
4548

4649
@staticmethod
4750
def merge_samplers(samplers, datasets, ns):

0 commit comments

Comments
 (0)