1919OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
2020THE SOFTWARE.
2121"""
22+
23+ import sys
2224from typing import (Any , Dict , Mapping , Tuple , TypeAlias , Iterable ,
2325 FrozenSet , Union , Set , List , Optional , Callable )
2426from pytato .array import (Array , InputArgumentBase , DictOfNamedArrays ,
3840from immutables import Map
3941from pytato .utils import are_shape_components_equal
4042
43+ if sys .version >= (3 , 11 ):
44+ zip_equal = lambda * _args : zip (* _args , strict = True )
45+ else :
46+ from more_itertools import zip_equal
47+
4148_ComposedIndirectionT : TypeAlias = Tuple [Array , ...]
4249IndexT : TypeAlias = Union [Array , NormalizedSlice ]
43- IndexStackT : TypeAlias = Tuple [IndexT , ...]
4450
4551
4652def _is_materialized (expr : Array ) -> bool :
@@ -53,15 +59,15 @@ def _is_materialized(expr: Array) -> bool:
5359 or bool (expr .tags_of_type (ImplStored )))
5460
5561
56- def _is_trivial_slice (dim : ShapeComponent , slice_ : NormalizedSlice ) -> bool :
62+ def _is_trivial_slice (dim : ShapeComponent , slice_ : IndexT ) -> bool :
5763 """
5864 Returns *True* only if *slice_* indexes an entire axis of shape *dim* with
5965 a step of 1.
6066 """
61- return (slice_ .step == 1
67+ return (isinstance (slice_ , NormalizedSlice )
68+ and slice_ .step == 1
6269 and are_shape_components_equal (slice_ .start , 0 )
63- and are_shape_components_equal (slice_ .stop , dim )
64- )
70+ and are_shape_components_equal (slice_ .stop , dim ))
6571
6672
6773def _take_along_axis (ary : Array , iaxis : int , idxs : IndexStackT ) -> Array :
@@ -427,35 +433,35 @@ class _IndirectionPusher(Mapper):
427433
428434 def __init__ (self ) -> None :
429435 self .get_reordarable_axes = _LegallyAxisReorderingFinder ()
430- self ._cache : Dict [Tuple [ArrayOrNames , Map [int , IndexStackT ]],
436+ self ._cache : Dict [Tuple [ArrayOrNames , Map [int , IndexT ]],
431437 ArrayOrNames ] = {}
432438 super ().__init__ ()
433439
434440 def rec (self , # type: ignore[override]
435441 expr : MappedT ,
436- index_stacks : Map [int , IndexStackT ]) -> MappedT :
437- key = (expr , index_stacks )
442+ indices : Tuple [IndexT , ...]) -> MappedT :
443+ assert len (indices ) == expr .ndim
444+ key = (expr , indices )
438445 try :
439446 # type-ignore-reason: parametric mapping types aren't a thing in 'typing'
440447 return self ._cache [key ] # type: ignore[return-value]
441448 except KeyError :
442- result = Mapper .rec (self , expr , index_stacks )
449+ result = Mapper .rec (self , expr , indices )
443450 self ._cache [key ] = result
444451 return result # type: ignore[no-any-return]
445452
446453 def __call__ (self , # type: ignore[override]
447454 expr : MappedT ,
448- index_stacks : Map [int , IndexStackT ]) -> MappedT :
449- return self .rec (expr , index_stacks )
455+ indices : Map [int , IndexT ]) -> MappedT :
456+ return self .rec (expr , indices )
450457
451458 def _map_materialized (self ,
452459 expr : Array ,
453- index_stacks : Map [int , IndexStackT ]) -> Array :
454- result = expr
455- for iaxis , idxs in index_stacks .items ():
456- result = _take_along_axis (result , iaxis , idxs )
457-
458- return result
460+ indices : Tuple [IndexT , ...]) -> Array :
461+ if all (_is_trivial_slice (dim , idx )
462+ for dim , idx in zip (expr .shape , indices )):
463+ return expr
464+ return expr [* indices ]
459465
460466 def map_dict_of_named_arrays (self ,
461467 expr : DictOfNamedArrays ,
@@ -467,9 +473,12 @@ def map_dict_of_named_arrays(self,
467473
468474 def map_index_lambda (self ,
469475 expr : IndexLambda ,
470- index_stacks : Map [ int , IndexStackT ]
476+ indices : Tuple [ IndexT , ...],
471477 ) -> Array :
472478 if _is_materialized (expr ):
479+ # FIXME: Move this logic to .rec (Why on earth do we need)
480+ # to copy the damn node???
481+
473482 # do not propagate the indexings to the bindings.
474483 expr = IndexLambda (expr .expr ,
475484 expr .shape ,
@@ -478,9 +487,13 @@ def map_index_lambda(self,
478487 for name , bnd in expr .bindings .items ()}),
479488 expr .var_to_reduction_descr ,
480489 tags = expr .tags ,
481- axes = expr .axes ,
482- )
483- return self ._map_materialized (expr , index_stacks )
490+ axes = expr .axes ,)
491+ return self ._map_materialized (expr , indices )
492+
493+ # FIXME:
494+ # This is the money shot. Over here we need to figure out the index
495+ # propagation logic.
496+
484497
485498 iout_axis_to_bnd_axis = _get_iout_axis_to_binding_axis (expr )
486499
@@ -886,128 +899,13 @@ def push_axis_indirections_towards_materialized_nodes(expr: MappedT
886899 ) -> MappedT :
887900 """
888901 Returns a copy of *expr* with the indirections propagated closer to the
889- materialized nodes. We propagate an indirections only if the indirection in
890- an :class:`~pytato.array.AdvancedIndexInContiguousAxes` or
891- :class:`~pytato.array.AdvancedIndexInNoncontiguousAxes` is an indirection
892- over a single axis.
902+ materialized nodes.
893903 """
894904 mapper = _IndirectionPusher ()
895905
896906 return mapper (expr , Map ())
897907
898908
899- def _get_unbroadcasted_axis_in_indirections (
900- expr : AdvancedIndexInContiguousAxes ) -> Optional [Mapping [int , int ]]:
901- """
902- Returns a mapping from the index of an indirection to its *only*
903- unbroadcasted axis as required by the logic. Returns *None* if no such
904- mapping exists.
905- """
906- from pytato .utils import partition , get_shape_after_broadcasting
907- adv_indices , _ = partition (lambda i : isinstance (expr .indices [i ],
908- NormalizedSlice ),
909- range (expr .array .ndim ))
910- i_ary_indices = [i_idx
911- for i_idx , idx in enumerate (expr .indices )
912- if isinstance (idx , Array )]
913-
914- adv_idx_shape = get_shape_after_broadcasting ([expr .indices [i_idx ]
915- for i_idx in adv_indices ])
916-
917- if len (adv_idx_shape ) != len (i_ary_indices ):
918- return None
919-
920- i_adv_out_axis_to_candidate_i_arys : Dict [int , Set [int ]] = {
921- idim : set ()
922- for idim , _ in enumerate (adv_idx_shape )
923- }
924-
925- for i_ary_idx in i_ary_indices :
926- ary = expr .indices [i_ary_idx ]
927- assert isinstance (ary , Array )
928- for iadv_out_axis , i_ary_axis in zip (range (len (adv_idx_shape )- 1 , - 1 , - 1 ),
929- range (ary .ndim - 1 , - 1 , - 1 )):
930- if are_shape_components_equal (adv_idx_shape [iadv_out_axis ],
931- ary .shape [i_ary_axis ]):
932- i_adv_out_axis_to_candidate_i_arys [iadv_out_axis ].add (i_ary_idx )
933-
934- from itertools import permutations
935- # FIXME: O(expr.ndim!) complexity, typically ndim <= 4 so this should be fine.
936- for guess_i_adv_out_axis_to_i_ary in permutations (range (len (i_ary_indices ))):
937- if all (i_ary in i_adv_out_axis_to_candidate_i_arys [i_adv_out ]
938- for i_adv_out , i_ary in enumerate (guess_i_adv_out_axis_to_i_ary )):
939- # TODO: Return the mapping here...
940- i_ary_to_unbroadcasted_axis : Dict [int , int ] = {}
941- for guess_i_adv_out_axis , i_ary_idx in enumerate (
942- guess_i_adv_out_axis_to_i_ary ):
943- ary = expr .indices [i_ary_idx ]
944- assert isinstance (ary , Array )
945- iunbroadcasted_axis , = [
946- i_ary_axis
947- for i_adv_out_axis , i_ary_axis in zip (
948- range (len (adv_idx_shape )- 1 , - 1 , - 1 ),
949- range (ary .ndim - 1 , - 1 , - 1 ))
950- if i_adv_out_axis == guess_i_adv_out_axis
951- ]
952- i_ary_to_unbroadcasted_axis [i_ary_idx ] = iunbroadcasted_axis
953-
954- return Map (i_ary_to_unbroadcasted_axis )
955-
956- return None
957-
958-
959- class MultiAxisIndirectionsDecoupler (CopyMapper ):
960- def map_contiguous_advanced_index (self ,
961- expr : AdvancedIndexInContiguousAxes
962- ) -> Array :
963- i_ary_idx_to_unbroadcasted_axis = _get_unbroadcasted_axis_in_indirections (
964- expr )
965-
966- if i_ary_idx_to_unbroadcasted_axis is not None :
967- from pytato .utils import partition
968- i_adv_indices , _ = partition (lambda idx : isinstance (expr .indices [idx ],
969- NormalizedSlice ),
970- range (len (expr .indices )))
971-
972- result = self .rec (expr .array )
973-
974- for iaxis , idx in enumerate (expr .indices ):
975- if isinstance (idx , Array ):
976- from pytato .array import squeeze
977- axes_to_squeeze = [
978- idim
979- for idim in range (expr
980- .indices [iaxis ] # type: ignore[union-attr]
981- .ndim )
982- if idim != i_ary_idx_to_unbroadcasted_axis [iaxis ]]
983- if axes_to_squeeze :
984- idx = squeeze (idx , axis = axes_to_squeeze )
985- if not (isinstance (idx , NormalizedSlice )
986- and _is_trivial_slice (expr .array .shape [iaxis ], idx )):
987- result = result [
988- (slice (None ),) * iaxis + (idx , )] # type: ignore[operator]
989-
990- return result
991- else :
992- return super ().map_contiguous_advanced_index (expr )
993-
994-
995- def decouple_multi_axis_indirections_into_single_axis_indirections (
996- expr : MappedT ) -> MappedT :
997- """
998- Returns a copy of *expr* with multiple indirections in an
999- :class:`~pytato.array.AdvancedIndexInContiguousAxes` decoupled as a
1000- composition of indexing nodes with single-axis indirections.
1001-
1002- .. note::
1003-
1004- This is a dependency preserving transformation. If a decoupling an
1005- advanced indexing node is not legal, we leave the node unmodified.
1006- """
1007- mapper = MultiAxisIndirectionsDecoupler ()
1008- return mapper (expr )
1009-
1010-
1011909# {{{ fold indirection constants
1012910
1013911class _ConstantIndirectionArrayCollector (CombineMapper [FrozenSet [Array ]]):
0 commit comments