Skip to content

Commit 5f6bc96

Browse files
committed
Utilize shared finalizer for MemoryView
1 parent d0010fd commit 5f6bc96

File tree

7 files changed

+40
-95
lines changed

7 files changed

+40
-95
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinConstructors.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@
110110
import com.oracle.graal.python.builtins.PythonBuiltins;
111111
import com.oracle.graal.python.builtins.modules.WarningsModuleBuiltins.WarnNode;
112112
import com.oracle.graal.python.builtins.modules.WeakRefModuleBuiltins.GetWeakRefsNode;
113-
import com.oracle.graal.python.builtins.objects.ellipsis.PEllipsis;
114113
import com.oracle.graal.python.builtins.objects.PNone;
115114
import com.oracle.graal.python.builtins.objects.PNotImplemented;
116115
import com.oracle.graal.python.builtins.objects.array.PArray;
@@ -139,6 +138,7 @@
139138
import com.oracle.graal.python.builtins.objects.complex.PComplex;
140139
import com.oracle.graal.python.builtins.objects.dict.DictBuiltins;
141140
import com.oracle.graal.python.builtins.objects.dict.PDict;
141+
import com.oracle.graal.python.builtins.objects.ellipsis.PEllipsis;
142142
import com.oracle.graal.python.builtins.objects.enumerate.PEnumerate;
143143
import com.oracle.graal.python.builtins.objects.floats.FloatBuiltins;
144144
import com.oracle.graal.python.builtins.objects.floats.FloatBuiltinsFactory;
@@ -156,7 +156,6 @@
156156
import com.oracle.graal.python.builtins.objects.iterator.PZip;
157157
import com.oracle.graal.python.builtins.objects.list.PList;
158158
import com.oracle.graal.python.builtins.objects.map.PMap;
159-
import com.oracle.graal.python.builtins.objects.memoryview.MemoryViewNodes;
160159
import com.oracle.graal.python.builtins.objects.memoryview.PBuffer;
161160
import com.oracle.graal.python.builtins.objects.memoryview.PMemoryView;
162161
import com.oracle.graal.python.builtins.objects.module.PythonModule;
@@ -245,6 +244,7 @@
245244
import com.oracle.truffle.api.Truffle;
246245
import com.oracle.truffle.api.dsl.Cached;
247246
import com.oracle.truffle.api.dsl.Cached.Shared;
247+
import com.oracle.truffle.api.dsl.CachedContext;
248248
import com.oracle.truffle.api.dsl.Fallback;
249249
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
250250
import com.oracle.truffle.api.dsl.ReportPolymorphism;
@@ -3415,9 +3415,9 @@ PMemoryView fromArray(@SuppressWarnings("unused") Object cls, PArray object) {
34153415

34163416
@Specialization
34173417
PMemoryView fromMemoryView(@SuppressWarnings("unused") Object cls, PMemoryView object,
3418-
@Cached MemoryViewNodes.GetBufferReferences getQueue) {
3418+
@Shared("c") @CachedContext(PythonLanguage.class) PythonContext context) {
34193419
object.checkReleased(this);
3420-
return factory().createMemoryView(getQueue.execute(), object.getManagedBuffer(), object.getOwner(), object.getLength(),
3420+
return factory().createMemoryView(context, object.getManagedBuffer(), object.getOwner(), object.getLength(),
34213421
object.isReadOnly(), object.getItemSize(), object.getFormat(), object.getFormatString(), object.getDimensions(),
34223422
object.getBufferPointer(), object.getOffset(), object.getBufferShape(), object.getBufferStrides(),
34233423
object.getBufferSuboffsets(), object.getFlags());

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/PythonCextBuiltins.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1585,7 +1585,7 @@ Object wrap(VirtualFrame frame, Object bufferStructPointer, Object ownerObj, Obj
15851585
@Cached AsPythonObjectNode asPythonObjectNode,
15861586
@Cached ToNewRefNode toNewRefNode,
15871587
@Cached GetNativeNullNode getNativeNullNode,
1588-
@Cached MemoryViewNodes.GetBufferReferences getQueue) {
1588+
@CachedContext(PythonLanguage.class) PythonContext context) {
15891589
try {
15901590
int ndim = castToIntNode.execute(ndimObj);
15911591
int itemsize = castToIntNode.execute(itemsizeObj);
@@ -1627,7 +1627,7 @@ Object wrap(VirtualFrame frame, Object bufferStructPointer, Object ownerObj, Obj
16271627
if (!lib.isNull(bufferStructPointer)) {
16281628
managedBuffer = new ManagedBuffer(bufferStructPointer);
16291629
}
1630-
PMemoryView memoryview = factory().createMemoryView(getQueue.execute(), managedBuffer, owner, len, readonly, itemsize,
1630+
PMemoryView memoryview = factory().createMemoryView(context, managedBuffer, owner, len, readonly, itemsize,
16311631
BufferFormat.forMemoryView(format),
16321632
format, ndim, bufPointer, 0, shape, strides, suboffsets, flags);
16331633
return toNewRefNode.execute(memoryview);

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/memoryview/BufferReference.java

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,29 +40,26 @@
4040
*/
4141
package com.oracle.graal.python.builtins.objects.memoryview;
4242

43-
import java.lang.ref.PhantomReference;
44-
import java.lang.ref.ReferenceQueue;
43+
import com.oracle.graal.python.runtime.AsyncHandler;
4544

46-
class BufferReference extends PhantomReference<PMemoryView> {
47-
private final ManagedBuffer managedBuffer;
48-
private boolean released;
45+
class BufferReference extends AsyncHandler.SharedFinalizer.FinalizableReference {
4946

50-
public BufferReference(PMemoryView referent, ManagedBuffer managedBuffer, ReferenceQueue<PMemoryView> q) {
51-
super(referent, q);
47+
public BufferReference(PMemoryView referent, ManagedBuffer managedBuffer, AsyncHandler.SharedFinalizer sharedFinalizer) {
48+
super(referent, managedBuffer, sharedFinalizer);
5249
assert managedBuffer != null;
5350
managedBuffer.incrementExports();
54-
this.managedBuffer = managedBuffer;
5551
}
5652

5753
public ManagedBuffer getManagedBuffer() {
58-
return managedBuffer;
54+
return (ManagedBuffer) getReference();
5955
}
6056

61-
public boolean isReleased() {
62-
return released;
63-
}
64-
65-
public void markReleased() {
66-
this.released = true;
57+
@Override
58+
public AsyncHandler.AsyncAction release() {
59+
ManagedBuffer buffer = getManagedBuffer();
60+
if (buffer.decrementExports() == 0) {
61+
return new MemoryViewBuiltins.NativeBufferReleaseCallback(this);
62+
}
63+
return null;
6764
}
6865
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/memoryview/MemoryViewBuiltins.java

Lines changed: 12 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@
5454
import static com.oracle.graal.python.nodes.SpecialMethodNames.__REPR__;
5555
import static com.oracle.graal.python.nodes.SpecialMethodNames.__SETITEM__;
5656

57-
import java.lang.ref.Reference;
5857
import java.util.Arrays;
5958
import java.util.List;
6059

60+
import com.oracle.graal.python.PythonLanguage;
6161
import com.oracle.graal.python.annotations.ArgumentClinic;
6262
import com.oracle.graal.python.builtins.Builtin;
6363
import com.oracle.graal.python.builtins.CoreFunctions;
@@ -91,7 +91,6 @@
9191
import com.oracle.graal.python.runtime.AsyncHandler;
9292
import com.oracle.graal.python.runtime.ExecutionContext;
9393
import com.oracle.graal.python.runtime.PythonContext;
94-
import com.oracle.graal.python.runtime.PythonCore;
9594
import com.oracle.graal.python.runtime.exception.PException;
9695
import com.oracle.graal.python.runtime.sequence.storage.IntSequenceStorage;
9796
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage;
@@ -101,6 +100,7 @@
101100
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
102101
import com.oracle.truffle.api.dsl.Cached;
103102
import com.oracle.truffle.api.dsl.Cached.Shared;
103+
import com.oracle.truffle.api.dsl.CachedContext;
104104
import com.oracle.truffle.api.dsl.Fallback;
105105
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
106106
import com.oracle.truffle.api.dsl.NodeFactory;
@@ -146,33 +146,6 @@ public void execute(PythonContext context) {
146146
}
147147
}
148148

149-
@Override
150-
public void postInitialize(PythonCore core) {
151-
super.postInitialize(core);
152-
MemoryViewNodes.BufferReferences bufferReferences = new MemoryViewNodes.BufferReferences();
153-
core.lookupType(PythonBuiltinClassType.PMemoryView).setAttribute(bufferReferencesKey, bufferReferences);
154-
core.getContext().registerAsyncAction(() -> {
155-
Reference<? extends PMemoryView> reference = null;
156-
try {
157-
reference = bufferReferences.queue.remove();
158-
} catch (InterruptedException e) {
159-
Thread.currentThread().interrupt();
160-
}
161-
if (reference instanceof BufferReference) {
162-
BufferReference bufferReference = (BufferReference) reference;
163-
bufferReferences.set.remove(bufferReference);
164-
if (bufferReference.isReleased()) {
165-
return null;
166-
}
167-
ManagedBuffer buffer = bufferReference.getManagedBuffer();
168-
if (buffer.decrementExports() == 0) {
169-
return new NativeBufferReleaseCallback(bufferReference);
170-
}
171-
}
172-
return null;
173-
});
174-
}
175-
176149
@Builtin(name = __GETITEM__, minNumOfPositionalArgs = 2)
177150
@GenerateNodeFactory
178151
abstract static class GetItemNode extends PythonBinaryBuiltinNode {
@@ -192,7 +165,7 @@ Object getitemSlice(PMemoryView self, PSlice slice,
192165
@Cached SliceLiteralNode.SliceUnpack sliceUnpack,
193166
@Cached SliceLiteralNode.AdjustIndices adjustIndices,
194167
@Cached MemoryViewNodes.InitFlagsNode initFlagsNode,
195-
@Cached MemoryViewNodes.GetBufferReferences getQueue) {
168+
@CachedContext(PythonLanguage.class) PythonContext context) {
196169
self.checkReleased(this);
197170
if (zeroDimProfile.profile(self.getDimensions() == 0)) {
198171
throw raise(TypeError, ErrorMessages.INVALID_INDEXING_OF_0_DIM_MEMORY);
@@ -209,7 +182,7 @@ Object getitemSlice(PMemoryView self, PSlice slice,
209182
int[] suboffsets = self.getBufferSuboffsets();
210183
int length = self.getLength() - (shape[0] - newShape[0]) * self.getItemSize();
211184
int flags = initFlagsNode.execute(self.getDimensions(), self.getItemSize(), newShape, newStrides, suboffsets);
212-
return factory().createMemoryView(getQueue.execute(), self.getManagedBuffer(), self.getOwner(), length, self.isReadOnly(),
185+
return factory().createMemoryView(context, self.getManagedBuffer(), self.getOwner(), length, self.isReadOnly(),
213186
self.getItemSize(), self.getFormat(), self.getFormatString(), self.getDimensions(), self.getBufferPointer(),
214187
self.getOffset() + sliceInfo.start * strides[0], newShape, newStrides, suboffsets, flags);
215188
}
@@ -565,9 +538,9 @@ protected ArgumentClinicProvider getArgumentClinic() {
565538
public abstract static class ToReadonlyNode extends PythonUnaryBuiltinNode {
566539
@Specialization
567540
PMemoryView toreadonly(PMemoryView self,
568-
@Cached MemoryViewNodes.GetBufferReferences getQueue) {
541+
@CachedContext(PythonLanguage.class) PythonContext context) {
569542
self.checkReleased(this);
570-
return factory().createMemoryView(getQueue.execute(), self.getManagedBuffer(), self.getOwner(), self.getLength(), true,
543+
return factory().createMemoryView(context, self.getManagedBuffer(), self.getOwner(), self.getLength(), true,
571544
self.getItemSize(), self.getFormat(), self.getFormatString(), self.getDimensions(), self.getBufferPointer(),
572545
self.getOffset(), self.getBufferShape(), self.getBufferStrides(), self.getBufferSuboffsets(), self.getFlags());
573546
}
@@ -580,14 +553,14 @@ public abstract static class CastNode extends PythonTernaryClinicBuiltinNode {
580553

581554
@Specialization
582555
PMemoryView cast(PMemoryView self, String formatString, @SuppressWarnings("unused") PNone none,
583-
@Shared("getQueue") @Cached MemoryViewNodes.GetBufferReferences getQueue) {
556+
@Shared("c") @CachedContext(PythonLanguage.class) PythonContext context) {
584557
self.checkReleased(this);
585-
return doCast(self, formatString, 1, null, getQueue.execute());
558+
return doCast(self, formatString, 1, null, context);
586559
}
587560

588561
@Specialization(guards = "isPTuple(shapeObj) || isList(shapeObj)")
589562
PMemoryView cast(PMemoryView self, String formatString, Object shapeObj,
590-
@Shared("getQueue") @Cached MemoryViewNodes.GetBufferReferences getQueue,
563+
@Shared("c") @CachedContext(PythonLanguage.class) PythonContext context,
591564
@Cached SequenceNodes.GetSequenceStorageNode getSequenceStorageNode,
592565
@Cached SequenceStorageNodes.LenNode lenNode,
593566
@Cached SequenceStorageNodes.GetItemScalarNode getItemScalarNode,
@@ -602,7 +575,7 @@ PMemoryView cast(PMemoryView self, String formatString, Object shapeObj,
602575
throw raise(TypeError, ErrorMessages.MEMORYVIEW_CAST_ELEMENTS_MUST_BE_POSITIVE_INTEGERS);
603576
}
604577
}
605-
return doCast(self, formatString, ndim, shape, getQueue.execute());
578+
return doCast(self, formatString, ndim, shape, context);
606579
}
607580

608581
@Specialization(guards = {"!isPTuple(shape)", "!isList(shape)", "!isPNone(shape)"})
@@ -611,7 +584,7 @@ PMemoryView error(PMemoryView self, String format, Object shape) {
611584
throw raise(TypeError, ErrorMessages.ARG_S_MUST_BE_A_LIST_OR_TUPLE, "shape");
612585
}
613586

614-
private PMemoryView doCast(PMemoryView self, String formatString, int ndim, int[] shape, MemoryViewNodes.BufferReferences refQueue) {
587+
private PMemoryView doCast(PMemoryView self, String formatString, int ndim, int[] shape, PythonContext context) {
615588
if (!self.isCContiguous()) {
616589
throw raise(TypeError, ErrorMessages.MEMORYVIEW_CASTS_RESTRICTED_TO_C_CONTIGUOUS);
617590
}
@@ -664,7 +637,7 @@ private PMemoryView doCast(PMemoryView self, String formatString, int ndim, int[
664637
}
665638
newStrides = PMemoryView.initStridesFromShape(ndim, itemsize, shape);
666639
}
667-
return factory().createMemoryView(refQueue, self.getManagedBuffer(), self.getOwner(), self.getLength(), self.isReadOnly(),
640+
return factory().createMemoryView(context, self.getManagedBuffer(), self.getOwner(), self.getLength(), self.isReadOnly(),
668641
itemsize, format, formatString, ndim, self.getBufferPointer(),
669642
self.getOffset(), newShape, newStrides, null, flags);
670643
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/memoryview/MemoryViewNodes.java

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,6 @@
4646
import static com.oracle.graal.python.builtins.PythonBuiltinClassType.TypeError;
4747
import static com.oracle.graal.python.builtins.PythonBuiltinClassType.ValueError;
4848

49-
import java.lang.ref.ReferenceQueue;
50-
import java.util.HashSet;
51-
import java.util.Set;
52-
53-
import com.oracle.graal.python.PythonLanguage;
54-
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
5549
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes;
5650
import com.oracle.graal.python.builtins.objects.cext.capi.NativeCAPISymbols;
5751
import com.oracle.graal.python.builtins.objects.common.BufferStorageNodes;
@@ -63,9 +57,7 @@
6357
import com.oracle.graal.python.nodes.PGuards;
6458
import com.oracle.graal.python.nodes.PNodeWithRaise;
6559
import com.oracle.graal.python.nodes.PRaiseNode;
66-
import com.oracle.graal.python.nodes.attributes.ReadAttributeFromObjectNode;
6760
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
68-
import com.oracle.graal.python.runtime.PythonContext;
6961
import com.oracle.graal.python.runtime.exception.PException;
7062
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage;
7163
import com.oracle.graal.python.util.BufferFormat;
@@ -74,7 +66,6 @@
7466
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
7567
import com.oracle.truffle.api.CompilerDirectives.ValueType;
7668
import com.oracle.truffle.api.dsl.Cached;
77-
import com.oracle.truffle.api.dsl.CachedContext;
7869
import com.oracle.truffle.api.dsl.Fallback;
7970
import com.oracle.truffle.api.dsl.GenerateUncached;
8071
import com.oracle.truffle.api.dsl.ImportStatic;
@@ -642,20 +633,4 @@ public static ToJavaBytesFortranOrderNode create() {
642633
return MemoryViewNodesFactory.ToJavaBytesFortranOrderNodeGen.create();
643634
}
644635
}
645-
646-
public abstract static class GetBufferReferences extends Node {
647-
public abstract BufferReferences execute();
648-
649-
@Specialization
650-
@SuppressWarnings("unchecked")
651-
static BufferReferences getRefs(@CachedContext(PythonLanguage.class) PythonContext context,
652-
@Cached ReadAttributeFromObjectNode readNode) {
653-
return (BufferReferences) readNode.execute(context.getCore().lookupType(PythonBuiltinClassType.PMemoryView), MemoryViewBuiltins.bufferReferencesKey);
654-
}
655-
}
656-
657-
public static class BufferReferences {
658-
public final ReferenceQueue<PMemoryView> queue = new ReferenceQueue<>();
659-
public final Set<BufferReference> set = new HashSet<>();
660-
}
661636
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/memoryview/PMemoryView.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
import com.oracle.graal.python.nodes.PNodeWithRaise;
5151
import com.oracle.graal.python.nodes.PRaiseNode;
5252
import com.oracle.graal.python.util.BufferFormat;
53+
import com.oracle.graal.python.runtime.AsyncHandler;
54+
import com.oracle.graal.python.runtime.PythonContext;
5355
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
5456
import com.oracle.truffle.api.dsl.Cached;
5557
import com.oracle.truffle.api.library.ExportLibrary;
@@ -91,7 +93,7 @@ public final class PMemoryView extends PythonBuiltinObject {
9193
// Cached hash value, required to compy with CPython's semantics
9294
private int cachedHash = -1;
9395

94-
public PMemoryView(Object cls, Shape instanceShape, MemoryViewNodes.BufferReferences references, ManagedBuffer managedBuffer, Object owner,
96+
public PMemoryView(Object cls, Shape instanceShape, PythonContext context, ManagedBuffer managedBuffer, Object owner,
9597
int len, boolean readonly, int itemsize, BufferFormat format, String formatString, int ndim, Object bufPointer,
9698
int offset, int[] shape, int[] strides, int[] suboffsets, int flags) {
9799
super(cls, instanceShape);
@@ -109,14 +111,13 @@ public PMemoryView(Object cls, Shape instanceShape, MemoryViewNodes.BufferRefere
109111
this.suboffsets = suboffsets;
110112
this.flags = flags;
111113
if (managedBuffer != null) {
112-
createReference(references, managedBuffer);
114+
createReference(context.getSharedFinalizer(), managedBuffer);
113115
}
114116
}
115117

116118
@TruffleBoundary
117-
private void createReference(MemoryViewNodes.BufferReferences references, ManagedBuffer managedBuffer) {
118-
this.reference = new BufferReference(this, managedBuffer, references.queue);
119-
references.set.add(this.reference);
119+
private void createReference(AsyncHandler.SharedFinalizer sharedFinalizer, ManagedBuffer managedBuffer) {
120+
this.reference = new BufferReference(this, managedBuffer, sharedFinalizer);
120121
}
121122

122123
// From CPython init_strides_from_shape

0 commit comments

Comments
 (0)