Skip to content
9,860 changes: 2,314 additions & 7,546 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

64 changes: 51 additions & 13 deletions sumpy/expansion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, overload
from warnings import warn

from typing_extensions import Self, override
Expand Down Expand Up @@ -388,9 +388,24 @@ def _split_coeffs_into_hyperplanes(


class FullExpansionTermsWrangler(ExpansionTermsWrangler):
def get_storage_index(self, mi: MultiIndex, order: int | None = None) -> int:
@overload
def get_storage_index(self,
mi: MultiIndex,
order: int | None = None) -> int: ...

@overload
def get_storage_index(self,
mi: tuple[prim.ExpressionNode, ...],
order: int | prim.ExpressionNode | None = None,
) -> prim.ExpressionNode: ...

def get_storage_index(self,
mi: MultiIndex | tuple[prim.ExpressionNode, ...],
order: int | prim.ExpressionNode | None = None,
) -> int | prim.ExpressionNode:
if not order:
order = sum(mi)

if self.dim == 3:
return (order*(order + 1)*(order + 2))//6 + \
(order + 2)*mi[2] - (mi[2]*(mi[2] + 1))//2 + mi[1]
Expand Down Expand Up @@ -670,8 +685,22 @@ def get_full_coefficient_identifiers(self) -> Sequence[MultiIndex]:
key, _ = self._get_mi_ordering_key_and_axis_permutation()
return sorted(identifiers, key=key)

def get_storage_index(self, mi: MultiIndex, order: int | None = None):
if not order:
@overload
def get_storage_index(self,
mi: MultiIndex,
order: int | None = None) -> int: ...

@overload
def get_storage_index(self,
mi: tuple[prim.ExpressionNode, ...],
order: int | prim.ExpressionNode | None = None,
) -> prim.ExpressionNode: ...

def get_storage_index(self,
mi: MultiIndex | tuple[prim.ExpressionNode, ...],
order: int | prim.ExpressionNode | None = None,
) -> int | prim.ExpressionNode:
if order is None:
order = sum(mi)

ordering_key, axis_permutation = \
Expand All @@ -684,23 +713,32 @@ def get_storage_index(self, mi: MultiIndex, order: int | None = None):

c = max_mi[axis_permutation[0]]

mi = list(mi)
mi[axis_permutation[0]], mi[0] = mi[0], mi[axis_permutation[0]]
new_mi = list(mi)
new_mi[axis_permutation[0]], new_mi[0] = mi[0], mi[axis_permutation[0]]
mi = tuple(new_mi)

if self.dim == 3:
if all(isinstance(axis, int) for axis in mi):
if order < c - 1:
return (order*(order + 1)*(order + 2))//6 + \
(order + 2)*mi[0] - (mi[0]*(mi[0] + 1))//2 + mi[1]
return (
(order*(order + 1)*(order + 2))//6
+ (order + 2)*mi[0]
- (mi[0]*(mi[0] + 1))//2
+ mi[1])
else:
return (c*(c-1)*(c-2))//6 + (c * order * (2 + order - c)
+ mi[0]*(3 - mi[0]+2*order))//2 + mi[1]
return (
(c*(c-1)*(c-2))//6
+ (c * order * (2 + order - c) + mi[0]*(3 - mi[0]+2*order))//2
+ mi[1])
else:
return prim.If(prim.Comparison(order, "<", c - 1),
(order*(order + 1)*(order + 2))//6
+ (order + 2)*mi[0] - (mi[0]*(mi[0] + 1))//2 + mi[1],
(c*(c-1)*(c-2))//6 + (c * order * (2 + order - c)
+ mi[0]*(3 - mi[0]+2*order))//2 + mi[1]
+ (order + 2)*mi[0]
- (mi[0]*(mi[0] + 1))//2
+ mi[1],
(c*(c-1)*(c-2))//6
+ (c * order * (2 + order - c) + mi[0]*(3 - mi[0]+2*order))//2
+ mi[1]
)
elif self.dim == 2:
if all(isinstance(axis, int) for axis in mi):
Expand Down
Loading
Loading