6
6
from mlir .ir import *
7
7
import numpy as np
8
8
import weakref
9
+ import ctypes
9
10
10
11
11
12
def run (f ):
@@ -617,3 +618,119 @@ def test_attribute(context, mview):
617
618
# CHECK: BACKING MEMORY DELETED
618
619
# CHECK: EXIT FUNCTION
619
620
print ("EXIT FUNCTION" )
621
+
622
+
623
+ # CHECK-LABEL: TEST: testGetDenseResourceElementsAttrNdarrayI32
624
+ @run
625
+ def testGetDenseResourceElementsAttrNdarrayI32 ():
626
+ class DLPackWrapper :
627
+ def __init__ (self , array : np .ndarray ):
628
+ self .dlpack_capsule = array .__dlpack__ ()
629
+
630
+ def __del__ (self ):
631
+ print ("BACKING MEMORY DELETED" )
632
+
633
+ def get_capsule (self ):
634
+ return self .dlpack_capsule
635
+
636
+ context = Context ()
637
+ mview_int32 = DLPackWrapper (np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = np .int32 ))
638
+
639
+ def test_attribute_int32 (context , mview_int32 ):
640
+ with context , Location .unknown ():
641
+ element_type = IntegerType .get_signless (32 )
642
+ tensor_type = RankedTensorType .get ((2 , 3 ), element_type )
643
+ resource = DenseResourceElementsAttr .get_from_ndarray (
644
+ mview_int32 .get_capsule (), "from_py" , tensor_type
645
+ )
646
+ module = Module .parse ("module {}" )
647
+ module .operation .attributes ["test.resource" ] = resource
648
+ # CHECK: test.resource = dense_resource<from_py> : tensor<2x3xi32>
649
+ # CHECK: from_py: "0x01000000010000000200"
650
+ print (module )
651
+
652
+ # Verifies type casting.
653
+ # CHECK: dense_resource<from_py> : tensor<2x3xi32>
654
+ print (
655
+ DenseResourceElementsAttr (module .operation .attributes ["test.resource" ])
656
+ )
657
+
658
+ test_attribute_int32 (context , mview_int32 )
659
+ del mview_int32
660
+ gc .collect ()
661
+ # CHECK: FREEING CONTEXT
662
+ print ("FREEING CONTEXT" )
663
+ context = None
664
+ gc .collect ()
665
+ # CHECK: EXIT FUNCTION
666
+ print ("EXIT FUNCTION" )
667
+
668
+
669
+ # CHECK-LABEL: TEST: testGetDenseResourceElementsAttrNdarrayF32
670
+ @run
671
+ def testGetDenseResourceElementsAttrNdarrayF32 ():
672
+ class DLPackWrapper :
673
+ def __init__ (self , array : np .ndarray ):
674
+ self .dlpack_capsule = array .__dlpack__ ()
675
+
676
+ def __del__ (self ):
677
+ print ("BACKING MEMORY DELETED" )
678
+
679
+ def get_capsule (self ):
680
+ return self .dlpack_capsule
681
+
682
+ context = Context ()
683
+ mview_float32 = DLPackWrapper (np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = np .float32 ))
684
+
685
+ def test_attribute_float32 (context , mview_float32 ):
686
+ with context , Location .unknown ():
687
+ element_type = FloatAttr .get_f32 (32.0 )
688
+ tensor_type = RankedTensorType .get ((2 , 3 ), element_type .type )
689
+ resource = DenseResourceElementsAttr .get_from_ndarray (
690
+ mview_float32 .get_capsule (), "from_py" , tensor_type
691
+ )
692
+ module = Module .parse ("module {}" )
693
+ module .operation .attributes ["test.resource" ] = resource
694
+ # CHECK: test.resource = dense_resource<from_py> : tensor<2x3xf32>
695
+ # CHECK: from_py: "0x010000000000803F0000"
696
+ print (module )
697
+
698
+ # Verifies type casting.
699
+ # CHECK: dense_resource<from_py> : tensor<2x3xf32>
700
+ print (
701
+ DenseResourceElementsAttr (module .operation .attributes ["test.resource" ])
702
+ )
703
+
704
+ test_attribute_float32 (context , mview_float32 )
705
+ del mview_float32
706
+ gc .collect ()
707
+ # CHECK: FREEING CONTEXT
708
+ print ("FREEING CONTEXT" )
709
+ context = None
710
+ gc .collect ()
711
+ # CHECK: EXIT FUNCTION
712
+ print ("EXIT FUNCTION" )
713
+
714
+
715
+ # CHECK-LABEL: TEST: testGetDenseResourceElementsAttrNonShapedType
716
+ @run
717
+ def testGetDenseResourceElementsAttrNonShapedType ():
718
+ class DLPackWrapper :
719
+ def __init__ (self , array : np .ndarray ):
720
+ self .dlpack_capsule = array .__dlpack__ ()
721
+
722
+ def __del__ (self ):
723
+ print ("BACKING MEMORY DELETED" )
724
+
725
+ def get_capsule (self ):
726
+ return self .dlpack_capsule
727
+
728
+ with Context (), Location .unknown ():
729
+ mview = DLPackWrapper (np .array ([1 ], dtype = np .int32 ))
730
+ t = F32Type .get ()
731
+
732
+ try :
733
+ attr = DenseResourceElementsAttr .get_from_ndarray (mview .get_capsule (), "from_py" , t )
734
+ except ValueError as e :
735
+ # CHECK: Constructing a DenseResourceElementsAttr requires a ShapedType.
736
+ print (e )
0 commit comments