5
5
6
6
import numpy
7
7
from loopy .types import OpaqueType
8
+ from pyop2 .global_kernel import (GlobalKernelArg , DatKernelArg , MixedDatKernelArg ,
9
+ MatKernelArg , MixedMatKernelArg , PermutedMapKernelArg )
8
10
from pyop2 .codegen .representation import (Accumulate , Argument , Comparison ,
9
11
DummyInstruction , Extent , FixedIndex ,
10
12
FunctionCall , Index , Indexed ,
16
18
When , Zero )
17
19
from pyop2 .datatypes import IntType
18
20
from pyop2 .op2 import (ALL , INC , MAX , MIN , ON_BOTTOM , ON_INTERIOR_FACETS ,
19
- ON_TOP , READ , RW , WRITE , Subset , PermutedMap )
21
+ ON_TOP , READ , RW , WRITE )
20
22
from pyop2 .utils import cached_property
21
23
22
24
@@ -32,18 +34,22 @@ class Map(object):
32
34
"variable" , "unroll" , "layer_bounds" ,
33
35
"prefetch" , "_pmap_count" )
34
36
35
- def __init__ (self , map_ , interior_horizontal , layer_bounds ,
36
- offset = None , unroll = False ):
37
- self .variable = map_ .iterset ._extruded and not map_ .iterset .constant_layers
37
+ def __init__ (self , interior_horizontal , layer_bounds ,
38
+ arity , dtype ,
39
+ offset = None , unroll = False ,
40
+ extruded = False , constant_layers = False ):
41
+ self .variable = extruded and not constant_layers
38
42
self .unroll = unroll
39
43
self .layer_bounds = layer_bounds
40
44
self .interior_horizontal = interior_horizontal
41
45
self .prefetch = {}
42
- offset = map_ . offset
43
- shape = (None , ) + map_ . shape [ 1 :]
44
- values = Argument (shape , dtype = map_ . dtype , pfx = "map" )
46
+
47
+ shape = (None , arity )
48
+ values = Argument (shape , dtype = dtype , pfx = "map" )
45
49
if offset is not None :
46
- if len (set (map_ .offset )) == 1 :
50
+ assert type (offset ) == tuple
51
+ offset = numpy .array (offset , dtype = numpy .int32 )
52
+ if len (set (offset )) == 1 :
47
53
offset = Literal (offset [0 ], casting = True )
48
54
else :
49
55
offset = NamedLiteral (offset , parent = values , suffix = "offset" )
@@ -616,15 +622,18 @@ def emit_unpack_instruction(self, *,
616
622
617
623
class WrapperBuilder (object ):
618
624
619
- def __init__ (self , * , kernel , iterset , iteration_region = None , single_cell = False ,
625
+ def __init__ (self , * , kernel , subset , extruded , constant_layers , iteration_region = None , single_cell = False ,
620
626
pass_layer_to_kernel = False , forward_arg_types = ()):
621
627
self .kernel = kernel
628
+ self .local_knl_args = iter (kernel .arguments )
622
629
self .arguments = []
623
630
self .argument_accesses = []
624
631
self .packed_args = []
625
632
self .indices = []
626
633
self .maps = OrderedDict ()
627
- self .iterset = iterset
634
+ self .subset = subset
635
+ self .extruded = extruded
636
+ self .constant_layers = constant_layers
628
637
if iteration_region is None :
629
638
self .iteration_region = ALL
630
639
else :
@@ -637,18 +646,6 @@ def __init__(self, *, kernel, iterset, iteration_region=None, single_cell=False,
637
646
def requires_zeroed_output_arguments (self ):
638
647
return self .kernel .requires_zeroed_output_arguments
639
648
640
- @property
641
- def subset (self ):
642
- return isinstance (self .iterset , Subset )
643
-
644
- @property
645
- def extruded (self ):
646
- return self .iterset ._extruded
647
-
648
- @property
649
- def constant_layers (self ):
650
- return self .extruded and self .iterset .constant_layers
651
-
652
649
@cached_property
653
650
def loop_extents (self ):
654
651
return (Argument ((), IntType , name = "start" ),
@@ -753,94 +750,98 @@ def loop_indices(self):
753
750
return (self .loop_index , None , self ._loop_index )
754
751
755
752
def add_argument (self , arg ):
753
+ local_arg = next (self .local_knl_args )
754
+ access = local_arg .access
755
+ dtype = local_arg .dtype
756
756
interior_horizontal = self .iteration_region == ON_INTERIOR_FACETS
757
- if arg . _is_dat :
758
- if arg . _is_mixed :
759
- packs = []
760
- for a in arg :
761
- shape = a . data . shape [ 1 :]
762
- if shape == ():
763
- shape = ( 1 , )
764
- shape = ( None , * shape )
765
- argument = Argument ( shape , a . data . dtype , pfx = "mdat" )
766
- packs . append ( a . data . pack ( argument , arg . access , self . map_ ( a . map , unroll = a . unroll_map ),
767
- interior_horizontal = interior_horizontal ,
768
- init_with_zero = self . requires_zeroed_output_arguments ) )
769
- self . arguments . append ( argument )
770
- pack = MixedDatPack ( packs , arg . access , arg . dtype , interior_horizontal = interior_horizontal )
771
- self . packed_args . append ( pack )
772
- self .argument_accesses . append (arg .access )
757
+
758
+ if isinstance ( arg , GlobalKernelArg ) :
759
+ argument = Argument ( arg . dim , dtype , pfx = "glob" )
760
+
761
+ pack = GlobalPack ( argument , access ,
762
+ init_with_zero = self . requires_zeroed_output_arguments )
763
+ self . arguments . append ( argument )
764
+ elif isinstance ( arg , DatKernelArg ):
765
+ if arg . dim == ():
766
+ shape = ( None , 1 )
767
+ else :
768
+ shape = ( None , * arg . dim )
769
+ argument = Argument ( shape , dtype , pfx = "dat" )
770
+
771
+ if arg . is_indirect :
772
+ map_ = self ._add_map (arg .map_ )
773
773
else :
774
- if arg ._is_dat_view :
775
- view_index = arg .data .index
776
- data = arg .data ._parent
774
+ map_ = None
775
+ pack = arg .pack (argument , access , map_ = map_ ,
776
+ interior_horizontal = interior_horizontal ,
777
+ view_index = arg .index ,
778
+ init_with_zero = self .requires_zeroed_output_arguments )
779
+ self .arguments .append (argument )
780
+ elif isinstance (arg , MixedDatKernelArg ):
781
+ packs = []
782
+ for a in arg :
783
+ if a .dim == ():
784
+ shape = (None , 1 )
785
+ else :
786
+ shape = (None , * a .dim )
787
+ argument = Argument (shape , dtype , pfx = "mdat" )
788
+
789
+ if a .is_indirect :
790
+ map_ = self ._add_map (a .map_ )
777
791
else :
778
- view_index = None
779
- data = arg .data
780
- shape = data .shape [1 :]
781
- if shape == ():
782
- shape = (1 ,)
783
- shape = (None , * shape )
784
- argument = Argument (shape ,
785
- arg .data .dtype ,
786
- pfx = "dat" )
787
- pack = arg .data .pack (argument , arg .access , self .map_ (arg .map , unroll = arg .unroll_map ),
788
- interior_horizontal = interior_horizontal ,
789
- view_index = view_index ,
790
- init_with_zero = self .requires_zeroed_output_arguments )
792
+ map_ = None
793
+
794
+ packs .append (arg .pack (argument , access , map_ ,
795
+ interior_horizontal = interior_horizontal ,
796
+ init_with_zero = self .requires_zeroed_output_arguments ))
791
797
self .arguments .append (argument )
792
- self .packed_args .append (pack )
793
- self .argument_accesses .append (arg .access )
794
- elif arg ._is_global :
795
- argument = Argument (arg .data .dim ,
796
- arg .data .dtype ,
797
- pfx = "glob" )
798
- pack = GlobalPack (argument , arg .access ,
799
- init_with_zero = self .requires_zeroed_output_arguments )
798
+ pack = MixedDatPack (packs , access , dtype ,
799
+ interior_horizontal = interior_horizontal )
800
+ elif isinstance (arg , MatKernelArg ):
801
+ argument = Argument ((), PetscMat (), pfx = "mat" )
802
+ maps = tuple (self ._add_map (m , arg .unroll )
803
+ for m in arg .maps )
804
+ pack = arg .pack (argument , access , maps ,
805
+ arg .dims , dtype ,
806
+ interior_horizontal = interior_horizontal )
800
807
self .arguments .append (argument )
801
- self .packed_args .append (pack )
802
- self .argument_accesses .append (arg .access )
803
- elif arg ._is_mat :
804
- if arg ._is_mixed :
805
- packs = []
806
- for a in arg :
807
- argument = Argument ((), PetscMat (), pfx = "mat" )
808
- map_ = tuple (self .map_ (m , unroll = arg .unroll_map ) for m in a .map )
809
- packs .append (arg .data .pack (argument , a .access , map_ ,
810
- a .data .dims , a .data .dtype ,
811
- interior_horizontal = interior_horizontal ))
812
- self .arguments .append (argument )
813
- pack = MixedMatPack (packs , arg .access , arg .dtype ,
814
- arg .data .sparsity .shape )
815
- self .packed_args .append (pack )
816
- self .argument_accesses .append (arg .access )
817
- else :
808
+ elif isinstance (arg , MixedMatKernelArg ):
809
+ packs = []
810
+ for a in arg :
818
811
argument = Argument ((), PetscMat (), pfx = "mat" )
819
- map_ = tuple (self .map_ (m , unroll = arg .unroll_map ) for m in arg .map )
820
- pack = arg .data .pack (argument , arg .access , map_ ,
821
- arg .data .dims , arg .data .dtype ,
822
- interior_horizontal = interior_horizontal )
812
+ maps = tuple (self ._add_map (m , a .unroll )
813
+ for m in a .maps )
814
+
815
+ packs .append (arg .pack (argument , access , maps ,
816
+ a .dims , dtype ,
817
+ interior_horizontal = interior_horizontal ))
823
818
self .arguments .append (argument )
824
- self . packed_args . append ( pack )
825
- self . argument_accesses . append ( arg .access )
819
+ pack = MixedMatPack ( packs , access , dtype ,
820
+ arg .shape )
826
821
else :
827
822
raise ValueError ("Unhandled argument type" )
828
823
829
- def map_ (self , map_ , unroll = False ):
824
+ self .packed_args .append (pack )
825
+ self .argument_accesses .append (access )
826
+
827
+ def _add_map (self , map_ , unroll = False ):
830
828
if map_ is None :
831
829
return None
832
830
interior_horizontal = self .iteration_region == ON_INTERIOR_FACETS
833
831
key = map_
834
832
try :
835
833
return self .maps [key ]
836
834
except KeyError :
837
- if isinstance (map_ , PermutedMap ):
838
- imap = self .map_ (map_ .map_ , unroll = unroll )
839
- map_ = PMap (imap , map_ .permutation )
835
+ if isinstance (map_ , PermutedMapKernelArg ):
836
+ imap = self ._add_map (map_ .base_map , unroll )
837
+ map_ = PMap (imap , numpy . asarray ( map_ .permutation , dtype = IntType ) )
840
838
else :
841
- map_ = Map (map_ , interior_horizontal ,
839
+ map_ = Map (interior_horizontal ,
842
840
(self .bottom_layer , self .top_layer ),
843
- unroll = unroll )
841
+ arity = map_ .arity , offset = map_ .offset , dtype = IntType ,
842
+ unroll = unroll ,
843
+ extruded = self .extruded ,
844
+ constant_layers = self .constant_layers )
844
845
self .maps [key ] = map_
845
846
return map_
846
847
0 commit comments