Skip to content

Commit 07b3f8e

Browse files
committed
Try to run full test suite in Numba backend
1 parent 1935809 commit 07b3f8e

File tree

3 files changed

+31
-16
lines changed

3 files changed

+31
-16
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ jobs:
141141
shell: bash -l {0}
142142
run: |
143143
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
144-
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
144+
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"
145145
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
146146
pip install -e ./
147147
mamba list && pip freeze

pytensor/compile/mode.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,8 @@ def register_linker(name, linker):
6161
# If a string is passed as the optimizer argument in the constructor
6262
# for Mode, it will be used as the key to retrieve the real optimizer
6363
# in this dictionary
64-
exclude = []
65-
if not config.cxx:
66-
exclude = ["cxx_only"]
64+
65+
exclude = ["cxx_only", "BlasOpt"]
6766
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
6867
# Even if multiple merge optimizer call will be there, this shouldn't
6968
# impact performance.
@@ -437,16 +436,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
437436
# string as the key
438437
# Use VM_linker to allow lazy evaluation by default.
439438
FAST_COMPILE = Mode(
440-
VMLinker(use_cloop=False, c_thunks=False),
441-
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
439+
NumbaLinker(),
440+
# TODO: Fast_compile should just use python code, CHANGE ME!
441+
RewriteDatabaseQuery(
442+
include=["fast_compile", "numba"],
443+
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
444+
),
445+
)
446+
FAST_RUN = Mode(
447+
NumbaLinker(),
448+
RewriteDatabaseQuery(
449+
include=["fast_run", "numba"],
450+
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
451+
),
442452
)
443-
if config.cxx:
444-
FAST_RUN = Mode("cvm", "fast_run")
445-
else:
446-
FAST_RUN = Mode(
447-
"vm",
448-
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
449-
)
450453

451454
JAX = Mode(
452455
JAXLinker(),
@@ -512,7 +515,7 @@ def get_mode(orig_string):
512515
# NanGuardMode use its own linker.
513516
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
514517
else:
515-
# TODO: Can't we look up the name and invoke it rather than using eval here?
518+
# TODO: Get rid of this? Or refactor?
516519
ret = eval(string + "(linker=config.linker, optimizer=config.optimizer)")
517520
elif string in predefined_modes:
518521
ret = predefined_modes[string]
@@ -541,6 +544,7 @@ def register_mode(name, mode):
541544
Add a `Mode` which can be referred to by `name` in `function`.
542545
543546
"""
547+
# TODO: Remove me
544548
if name in predefined_modes:
545549
raise ValueError(f"Mode name already taken: {name}")
546550
predefined_modes[name] = mode

pytensor/configdefaults.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,18 @@ def add_compile_configvars():
461461
"linker",
462462
"Default linker used if the pytensor flags mode is Mode",
463463
EnumStr(
464-
"cvm", ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"]
464+
"numba",
465+
[
466+
"cvm",
467+
"jax",
468+
"c|py",
469+
"py",
470+
"c",
471+
"c|py_nogc",
472+
"vm",
473+
"vm_nogc",
474+
"cvm_nogc",
475+
],
465476
),
466477
in_c_key=False,
467478
)
@@ -471,7 +482,7 @@ def add_compile_configvars():
471482
config.add(
472483
"linker",
473484
"Default linker used if the pytensor flags mode is Mode",
474-
EnumStr("vm", ["py", "vm_nogc"]),
485+
EnumStr("numba", ["vm", "jax", "py", "vm_nogc"]),
475486
in_c_key=False,
476487
)
477488
if type(config).cxx.is_default:

0 commit comments

Comments
 (0)