Skip to content

Base expression node types on dataclasses #125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ on:
schedule:
- cron: '17 3 * * 0'

concurrency:
group: ${{ github.head_ref || github.ref_name }}
cancel-in-progress: true

jobs:
ruff:
name: Ruff
Expand Down Expand Up @@ -44,6 +48,23 @@ jobs:

run_pylint pymbolic test/test_*.py

mypy:
name: Mypy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
-
uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: "Main Script"
run: |
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_venv
python -m pip install mypy numpy
./run-mypy.sh

pytest:
name: Pytest on Py${{ matrix.python-version }}
runs-on: ubuntu-latest
Expand Down
12 changes: 12 additions & 0 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ Pylint:
except:
- tags

Mypy:
script: |
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_venv
python -m pip install mypy
./run-mypy.sh
tags:
- python3
except:
- tags

Documentation:
script:
- EXTRA_INSTALL="numpy sympy"
Expand Down
2 changes: 2 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@
"numpy": ("https://numpy.org/doc/stable/", None),
"python": ("https://docs.python.org/3", None),
"sympy": ("https://docs.sympy.org/dev/", None),
"typing_extensions":
("https://typing-extensions.readthedocs.io/en/latest/", None),
}
16 changes: 7 additions & 9 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,11 @@ You can also easily define your own objects to use inside an expression:

.. doctest::

>>> from pymbolic.primitives import Expression
>>> class FancyOperator(Expression):
... def __init__(self, operand):
... self.operand = operand
...
... def __getinitargs__(self):
... return (self.operand,)
...
... mapper_method = "map_fancy_operator"
>>> from pymbolic.primitives import Expression, expr_dataclass
>>>
>>> @expr_dataclass()
... class FancyOperator(Expression):
... operand: Expression
...
>>> u
Power(Sum((Variable('x'), 1)), 5)
Expand All @@ -89,6 +85,8 @@ As a final example, we can now derive from *MyMapper* to multiply all

.. doctest::

>>> FancyOperator.mapper_method
'map_fancy_operator'
>>> class MyMapper2(MyMapper):
... def map_fancy_operator(self, expr):
... return 2*FancyOperator(self.rec(expr.operand))
Expand Down
12 changes: 7 additions & 5 deletions experiments/traversal-benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import sys

from pymbolic import parse
from pymbolic.primitives import Variable
from pymbolic.mapper import CachedIdentityMapper
from pymbolic.mapper.optimize import optimize_mapper
from pymbolic.primitives import Variable


code = ("(-1)*((cse_577[_pt_data_48[((iface_ensm15*1075540 + iel_ensm15*10 + idof_ensm15) % 4302160) // 10, 0],"
Expand Down Expand Up @@ -128,9 +128,11 @@ def main():
t_end = time()
print(f"Took: {t_end-t_start} secs.")
else:
import vmprof
with open("test.prof", "w+b") as fd:
vmprof.enable(fd.fileno())
import pyinstrument
from pyinstrument.renderers import SpeedscopeRenderer
prof = pyinstrument.Profiler()
with prof:
for _ in range(10_000):
main()
vmprof.disable()
with open("ss.json", "w") as outf:
outf.write(prof.output(SpeedscopeRenderer(show_all=True)))
107 changes: 72 additions & 35 deletions pymbolic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations


__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"

__license__ = """
Expand All @@ -23,38 +26,72 @@

from pymbolic.version import VERSION_TEXT as __version__ # noqa

import pymbolic.parser
import pymbolic.compiler

import pymbolic.mapper.evaluator
import pymbolic.mapper.stringifier
import pymbolic.mapper.dependency
import pymbolic.mapper.substitutor
import pymbolic.mapper.differentiator
import pymbolic.mapper.distributor
import pymbolic.mapper.flattener
import pymbolic.primitives

from pymbolic.polynomial import Polynomial # noqa

var = pymbolic.primitives.Variable
variables = pymbolic.primitives.variables
flattened_sum = pymbolic.primitives.flattened_sum
subscript = pymbolic.primitives.subscript
flattened_product = pymbolic.primitives.flattened_product
quotient = pymbolic.primitives.quotient
linear_combination = pymbolic.primitives.linear_combination
cse = pymbolic.primitives.make_common_subexpression
make_sym_vector = pymbolic.primitives.make_sym_vector

disable_subscript_by_getitem = pymbolic.primitives.disable_subscript_by_getitem

parse = pymbolic.parser.parse
evaluate = pymbolic.mapper.evaluator.evaluate
evaluate_kw = pymbolic.mapper.evaluator.evaluate_kw
compile = pymbolic.compiler.compile
substitute = pymbolic.mapper.substitutor.substitute
diff = differentiate = pymbolic.mapper.differentiator.differentiate
expand = pymbolic.mapper.distributor.distribute
distribute = pymbolic.mapper.distributor.distribute
flatten = pymbolic.mapper.flattener.flatten
from . import parser
from . import compiler

from .mapper import evaluator
from .mapper import stringifier
from .mapper import dependency
from .mapper import substitutor
from .mapper import differentiator
from .mapper import distributor
from .mapper import flattener
from . import primitives

from .polynomial import Polynomial

from .primitives import Variable as var # noqa: N813
from .primitives import variables
from .primitives import flattened_sum
from .primitives import subscript
from .primitives import flattened_product
from .primitives import quotient
from .primitives import linear_combination
from .primitives import make_common_subexpression as cse
from .primitives import make_sym_vector
from .primitives import disable_subscript_by_getitem
from .parser import parse
from .mapper.evaluator import evaluate
from .mapper.evaluator import evaluate_kw
from .compiler import compile
from .mapper.substitutor import substitute
from .mapper.differentiator import differentiate as diff
from .mapper.differentiator import differentiate
from .mapper.distributor import distribute as expand
from .mapper.distributor import distribute
from .mapper.flattener import flatten


__all__ = (
"Polynomial",
"compile",
"compiler",
"cse",
"dependency",
"diff",
"differentiate",
"differentiator",
"disable_subscript_by_getitem",
"distribute",
"distributor",
"evaluate",
"evaluate_kw",
"evaluator",
"expand",
"flatten",
"flattened_product",
"flattened_sum",
"flattener",
"linear_combination",
"make_sym_vector",
"parse",
"parser",
"primitives",
"quotient",
"stringifier",
"subscript",
"substitute",
"substitutor",
"var",
"variables",
)
5 changes: 4 additions & 1 deletion pymbolic/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations


__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"

__license__ = """
Expand Down Expand Up @@ -223,7 +226,7 @@ def sym_fft(x, sign=1):
wrappers at opportune points.
"""

from pymbolic.mapper import IdentityMapper, CSECachingMapperMixin
from pymbolic.mapper import CSECachingMapperMixin, IdentityMapper

class NearZeroKiller(CSECachingMapperMixin, IdentityMapper):
map_common_subexpression_uncached = \
Expand Down
8 changes: 5 additions & 3 deletions pymbolic/compiler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations


__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"

__license__ = """
Expand All @@ -21,10 +24,9 @@
"""

import math

import pymbolic
from pymbolic.mapper.stringifier import (
StringifyMapper, PREC_NONE,
PREC_SUM, PREC_POWER)
from pymbolic.mapper.stringifier import PREC_NONE, PREC_POWER, PREC_SUM, StringifyMapper


class CompileMapper(StringifyMapper):
Expand Down
4 changes: 4 additions & 0 deletions pymbolic/cse.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations


__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"

__license__ = """
Expand All @@ -23,6 +26,7 @@
import pymbolic.primitives as prim
from pymbolic.mapper import IdentityMapper, WalkMapper


COMMUTATIVE_CLASSES = (prim.Sum, prim.Product)


Expand Down
3 changes: 3 additions & 0 deletions pymbolic/functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations


__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"

__license__ = """
Expand Down
7 changes: 6 additions & 1 deletion pymbolic/geometric_algebra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations


__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"

__license__ = """
Expand All @@ -20,9 +23,10 @@
THE SOFTWARE.
"""

from pytools import memoize, memoize_method
import numpy as np

from pytools import memoize, memoize_method


__doc__ = """
See `Wikipedia <https://en.wikipedia.org/wiki/Geometric_algebra>`__ for an idea
Expand Down Expand Up @@ -536,6 +540,7 @@ def __init__(self, data, space=None):
# {{{ normalize data to bitmaps, if needed

from pytools import single_valued

from pymbolic.primitives import is_zero
if data and single_valued(isinstance(k, tuple) for k in data.keys()):
# data is in non-normalized non-bits tuple form
Expand Down
36 changes: 19 additions & 17 deletions pymbolic/geometric_algebra/mapper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations


__copyright__ = "Copyright (C) 2014 Andreas Kloeckner"

__license__ = """
Expand All @@ -23,27 +26,26 @@
# This is experimental, undocumented, and could go away any second.
# Consider yourself warned.

from typing import ClassVar, Dict
from typing import ClassVar

from pymbolic.geometric_algebra import MultiVector
import pymbolic.geometric_algebra.primitives as prim
from pymbolic.geometric_algebra import MultiVector
from pymbolic.mapper import (
CombineMapper as CombineMapperBase,
Collector as CollectorBase,
IdentityMapper as IdentityMapperBase,
WalkMapper as WalkMapperBase,
CachedMapper,
)
CachedMapper,
Collector as CollectorBase,
CombineMapper as CombineMapperBase,
IdentityMapper as IdentityMapperBase,
WalkMapper as WalkMapperBase,
)
from pymbolic.mapper.constant_folder import (
ConstantFoldingMapper as ConstantFoldingMapperBase)
from pymbolic.mapper.graphviz import (
GraphvizMapper as GraphvizMapperBase)
ConstantFoldingMapper as ConstantFoldingMapperBase,
)
from pymbolic.mapper.evaluator import EvaluationMapper as EvaluationMapperBase
from pymbolic.mapper.graphviz import GraphvizMapper as GraphvizMapperBase
from pymbolic.mapper.stringifier import (
StringifyMapper as StringifyMapperBase,
PREC_NONE
)
from pymbolic.mapper.evaluator import (
EvaluationMapper as EvaluationMapperBase)
PREC_NONE,
StringifyMapper as StringifyMapperBase,
)


class IdentityMapper(IdentityMapperBase):
Expand Down Expand Up @@ -105,7 +107,7 @@ def map_derivative_source(self, expr):


class StringifyMapper(StringifyMapperBase):
AXES: ClassVar[Dict[int, str]] = {0: "x", 1: "y", 2: "z"}
AXES: ClassVar[dict[int, str]] = {0: "x", 1: "y", 2: "z"}

def map_nabla(self, expr, enclosing_prec):
return f"∇[{expr.nabla_id}]"
Expand Down
Loading
Loading