Skip to content

Commit be621e9

Browse files
committed
fix tee pickling
1 parent 3608635 commit be621e9

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

graalpython/lib-graalpython/itertools.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,7 +1035,7 @@ def __reduce__(self):
10351035

10361036

10371037
class _tee_dataobject:
1038-
LINKCELLS = 32
1038+
LINKCELLS = 128
10391039

10401040
@__graalpython__.builtin_method
10411041
def __init__(self, it, values=None, nxt=None):
@@ -1045,12 +1045,14 @@ def __init__(self, it, values=None, nxt=None):
10451045
self.numread = len(values)
10461046
if self.numread == _tee_dataobject.LINKCELLS:
10471047
self.nextlink = nxt
1048+
if nxt and not isinstance(nxt, _tee_dataobject):
1049+
raise ValueError("_tee_dataobject next link must be a _tee_dataobject")
10481050
elif self.numread > _tee_dataobject.LINKCELLS:
10491051
raise ValueError(f"_tee_dataobject should nove have more than {_tee_dataobject.LINKCELLS} links")
10501052
elif nxt is not None:
10511053
raise ValueError("_tee_dataobject shouldn't have a next if not full")
10521054
else:
1053-
self.values = [None] * _tee_dataobject.LINKCELLS
1055+
self.values = []
10541056
self.numread = 0
10551057
self.running = False
10561058
self.nextlink = nxt
@@ -1062,10 +1064,12 @@ def _jumplink(self):
10621064
return self.nextlink
10631065

10641066
@__graalpython__.builtin_method
1065-
def __getitem__(self, i):
1067+
def _getitem(self, i):
1068+
assert i < _tee_dataobject.LINKCELLS
10661069
if i < self.numread:
10671070
return self.values[i]
10681071
else:
1072+
assert i == self.numread
10691073
if self.running:
10701074
raise RuntimeError("cannot re-enter the tee iterator")
10711075
self.running = True
@@ -1074,12 +1078,13 @@ def __getitem__(self, i):
10741078
finally:
10751079
self.running = False
10761080
self.numread += 1
1077-
self.values[i] = value
1081+
self.values.append(value)
10781082
return value
10791083

10801084
@__graalpython__.builtin_method
10811085
def __reduce__(self):
1082-
return type(self), (self.it, self.values, self.nextlink)
1086+
values = self.values[:self.numread]
1087+
return type(self), (self.it, values, self.nextlink)
10831088

10841089

10851090
class _tee:
@@ -1108,7 +1113,7 @@ def __next__(self):
11081113
if self.index >= _tee_dataobject.LINKCELLS:
11091114
self.dataobj = self.dataobj._jumplink()
11101115
self.index = 0
1111-
value = self.dataobj[self.index]
1116+
value = self.dataobj._getitem(self.index)
11121117
self.index += 1
11131118
return value
11141119

0 commit comments

Comments
 (0)