Exact FLOP costs for batched/broadcast contractions; unify on FMA=2 #114
Merged
…elper Route matmul/dot/inner's 2-D/1-D contraction paths through one shared _einsum_routed_binary helper that charges the FMA=2 symmetry-aware einsum accumulation cost, preserves operand aliasing, and wraps symmetric results as SymmetricTensor. Behavior-preserving: no FLOP-cost changes. The size*size fallback for batched/mixed-ndim cases is left unchanged. Add the helper to the test_overhead_coverage AST-lint exemption set, since it is a shared cost-routing helper invoked by already-decorated wrappers and must not itself carry @_counted_wrapper.
…l N-D Route the broadcast contraction ops through the shared einsum cost path so batch/broadcast axes on either operand are counted exactly, matching the equivalent fnp.einsum. Also unifies these ops on the FMA=2 convention. BREAKING CHANGE: FLOP costs change for vecmat, matvec, vecdot, and N-D/mixed matmul. Consumers that pin or budget on absolute FLOP counts should re-baseline.
N-D dot/inner contract one axis and outer-product the rest; replace the a.size*b.size / a.size*b.shape[-1] fallbacks with generated distinct-label einsum subscripts so the cost is exact. inner now wraps tracked inputs consistently with dot/matmul. BREAKING CHANGE: FLOP costs change for dot/inner with >2-D operands.
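The distinct-label subscripts described above can be sketched as follows. This is a minimal illustration, not the PR's implementation: `dot_subscripts` and `inner_subscripts` are hypothetical names, and the sketch only covers the NumPy rule that `dot` contracts the last axis of `a` with the second-to-last axis of `b`, while `inner` contracts the last axis of both.

```python
import string


def dot_subscripts(a_ndim: int, b_ndim: int) -> str:
    """Einsum subscripts equivalent to np.dot for a_ndim >= 1, b_ndim >= 2."""
    if a_ndim < 1 or b_ndim < 2:
        raise ValueError("sketch covers a_ndim >= 1 and b_ndim >= 2 only")
    labels = iter(string.ascii_lowercase)
    k = next(labels)                                  # the single contracted axis
    a_free = [next(labels) for _ in range(a_ndim - 1)]
    b_free = [next(labels) for _ in range(b_ndim - 1)]
    a_sub = "".join(a_free) + k                       # a's last axis is contracted
    b_sub = "".join(b_free[:-1]) + k + b_free[-1]     # b's second-to-last axis
    return f"{a_sub},{b_sub}->{''.join(a_free + b_free)}"


def inner_subscripts(a_ndim: int, b_ndim: int) -> str:
    """Einsum subscripts equivalent to np.inner for operands with ndim >= 1."""
    if a_ndim < 1 or b_ndim < 1:
        raise ValueError("sketch covers ndim >= 1 only")
    labels = iter(string.ascii_lowercase)
    k = next(labels)                                  # shared last axis
    a_free = [next(labels) for _ in range(a_ndim - 1)]
    b_free = [next(labels) for _ in range(b_ndim - 1)]
    return f"{''.join(a_free)}{k},{''.join(b_free)}{k}->{''.join(a_free + b_free)}"


print(dot_subscripts(3, 3))    # every non-contracted axis gets its own label
print(inner_subscripts(3, 2))
```

Because every free axis receives its own label, the outer-product part of the result is visible in the output subscripts, which is what lets an einsum-based cost count it exactly.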
The contraction-order search called flop_count with the legacy index-set signature (FMA=1); thread per-step subscripts/shapes so it uses the same FMA=2 accumulation cost as billing, and remove the dead fallback. One cost model end-to-end. Binary einsums are unaffected (single step). BREAKING CHANGE: multi-operand einsum path selection and billed totals may change where FMA=2 vs FMA=1 flips the cheapest order.
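To illustrate what "the same cost for search and billing" means, here is a small sketch that ranks the two pairwise orders of a three-matrix chain with one per-step cost function. The helper names are hypothetical, the step cost is the dense 2*M*K*N - M*N formula quoted elsewhere in this PR, and symmetry savings are ignored.

```python
def step_cost(m: int, k: int, n: int) -> int:
    """Dense FMA=2 cost of an (m, k) @ (k, n) step: multiplies plus adds."""
    return 2 * m * k * n - m * n


def chain_costs(m: int, k: int, n: int, p: int) -> dict:
    """Total cost of both orders for A(m,k) @ B(k,n) @ C(n,p)."""
    left_first = step_cost(m, k, n) + step_cost(m, n, p)    # (A @ B) @ C
    right_first = step_cost(k, n, p) + step_cost(m, k, p)   # A @ (B @ C)
    return {"(AB)C": left_first, "A(BC)": right_first}


costs = chain_costs(10, 100, 5, 50)
best = min(costs, key=costs.get)
print(costs, best)
```

The total reported for the chosen order is the sum of the same per-step costs that ranked it, so the search and the bill cannot disagree.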
The deletion-safety tests intentionally import the local _paths/_path_random submodules to assert they exist. That registers them in the flopscope._opt_einsum package __dict__, permanently shadowing __init__.py's lazy __getattr__ hook that maps oe._paths/oe._path_random to the upstream opt_einsum modules. The leak broke tests run later in the same process — the custom-optimizer tests in test_opt_einsum_paths.py rely on oe._path_random being upstream (otherwise isinstance(optimize, PathOptimizer) flips to False, the optimizer is forwarded to upstream, and it raises "TypeError: 'RandomGreedy' object is not iterable"). xdist masked this (tests land on different workers); a serial run (-n 0) deterministically exposed 5 failures. Add an autouse teardown that drops the _paths/_path_random shadows after each test (sys.modules left intact to preserve class identity; _helpers is intentionally kept local), plus a regression test for the restoration invariant.
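The shadowing mechanism can be reproduced with the standard library alone. The sketch below uses stand-in modules (`fake_pkg`, `upstream._paths`) and a hypothetical `drop_shadows` helper rather than the real flopscope package or its test fixture; it only shows why an attribute in the package `__dict__` hides a module-level `__getattr__`, and why removing it restores the lazy lookup.

```python
import types

upstream_paths = types.ModuleType("upstream._paths")   # stand-in for the upstream module
pkg = types.ModuleType("fake_pkg")                      # stand-in for the package


def _lazy_getattr(name):
    # Module-level __getattr__ (PEP 562): consulted only when normal lookup fails.
    if name == "_paths":
        return upstream_paths
    raise AttributeError(name)


pkg.__getattr__ = _lazy_getattr
before = pkg._paths                      # resolved lazily to the upstream stand-in

# Importing a local submodule binds it as an attribute on the parent package;
# setattr imitates that side effect here.
local_paths = types.ModuleType("fake_pkg._paths")
setattr(pkg, "_paths", local_paths)
during = pkg._paths                      # shadowed: __getattr__ is no longer reached


def drop_shadows(package, names):
    """Remove leaked submodule attributes so the lazy hook applies again."""
    for name in names:
        package.__dict__.pop(name, None)


drop_shadows(pkg, ["_paths", "_path_random"])
after = pkg._paths                       # lazy lookup works again
```

In a test suite the same cleanup would typically run in an autouse fixture's teardown, which is the approach this commit describes.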
These were committed in the FMA=2 path-search change without ruff format; no logic changes (whitespace only). Restores a clean `ruff format --check`.
Billing has used the FMA=2 textbook convention (multiplies and adds counted separately) for a while; these labels still said FMA=1. Correct them and drop references to the removed fma_cost setting. No FLOP numbers change.
- _flops.py matmul_cost, _polynomial.py polyval, _pointwise.py convolve/correlate docstrings: FMA=1 -> FMA=2.
- _opt_einsum __init__/NOTICE/_contract docstrings: FMA=2 via the accumulation model; state there is no fma_cost setting (opt_cost labelled as the upstream opt_einsum convention).
- data/weights.csv contraction notes: FMA=2 with the exact billed formulas (2*M*K*N - M*N for dot/matmul, 2*N - 1 for inner/vdot).
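Both billed formulas above are instances of one dense counting rule: one multiply per point of the full index space, plus one add for every product beyond the first in each output element. The sketch below assumes that rule, ignores symmetry savings, and requires explicit `->` subscripts with agreeing label sizes; `dense_fma2_cost` is a hypothetical name, not a flopscope function.

```python
from math import prod


def dense_fma2_cost(subscripts: str, *shapes: tuple) -> int:
    """Dense FMA=2 cost: multiplies + adds = 2 * |all labels| - |output labels|."""
    inputs, output = subscripts.split("->")
    sizes = {}
    for labels, shape in zip(inputs.split(","), shapes):
        for label, size in zip(labels, shape):
            sizes[label] = size
    total = prod(sizes.values())                     # one multiply per index tuple
    out = prod(sizes[label] for label in output)     # prod of nothing is 1 (scalar)
    return 2 * total - out                           # adds = total - out


M, K, N = 3, 4, 5
print(dense_fma2_cost("ij,jk->ik", (M, K), (K, N)), 2 * M * K * N - M * N)
print(dense_fma2_cost("i,i->", (7,), (7,)), 2 * 7 - 1)
```

For an op with no contracted axis, such as an outer product, the rule reduces to the multiplies alone.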
Routing matmul/dot/inner/vecmat/matvec/vecdot through the einsum
accumulation cost surfaced that _build_size_map rejected a shared label
appearing with sizes {1, N} — but that is NumPy broadcasting (the size-1
axis broadcasts to N). The same gap affected fnp.einsum with an explicit
size-1 broadcast batch axis.
Treat a size-1 axis as broadcastable: a label's size is the broadcast
extent (the non-1 value); only a mismatch where neither size is 1 is a
genuine inconsistency. The change is additive — it only converts
previously-raised broadcast errors into the correct (broadcasted) cost;
inputs whose label sizes already agreed are unaffected, and the
off-by-one output-orbit credit is applied per broadcasted output.
Fixes numpy-compat failures: numpy's own test_ufunc::test_output_argument
and ::test_axis_argument exercise np.vecdot with a size-1 batch operand.
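The broadcast rule described in this commit can be sketched as below. `build_size_map_sketch` is a hypothetical stand-in written for illustration, not the project's `_build_size_map`; it assumes one label per axis and no ellipsis.

```python
def build_size_map_sketch(subscripts: str, shapes: list) -> dict:
    """Map each einsum label to its broadcast extent, treating size 1 as broadcastable."""
    inputs = subscripts.split("->")[0].split(",")
    size_map = {}
    for labels, shape in zip(inputs, shapes):
        if len(labels) != len(shape):
            raise ValueError(f"subscripts {labels!r} do not match shape {shape}")
        for label, size in zip(labels, shape):
            known = size_map.get(label, 1)
            if known == 1:
                size_map[label] = size               # first sighting, or upgrade from 1
            elif size not in (1, known):
                # Neither size is 1 and they differ: a genuine inconsistency.
                raise ValueError(f"label {label!r} has sizes {known} and {size}")
    return size_map


# A vecdot-like call whose first operand has a size-1 batch axis.
print(build_size_map_sketch("bi,bi->b", [(1, 3), (4, 3)]))
```

The result does not depend on operand order, since a size-1 axis never overrides a larger extent already recorded for the same label.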
Summary
Routes the binary-contraction family (matmul, dot, inner, vecmat, matvec, vecdot) through a single
shared cost path built on the symmetry-aware einsum accumulation model, so per-op FLOP counts are
exact for all operand layouts (batched, broadcast, mixed-rank, 1-D-promoted), not just the 2-D case.
Previously these ops used per-op fallback formulas for non-2-D inputs that did not account for all
batch/broadcast axes; counting now derives from the operation's einsum contraction structure, so it
matches fnp.einsum(<equivalent subscripts>) exactly. The contraction-order path search is unified
onto the same FMA=2 convention used for billing, and the legacy cost fallback is removed: one cost
model end-to-end.
What changes
- _einsum_routed_binary: builds the op's einsum subscripts and routes cost + output-symmetry inference through _resolve_cost_and_output_symmetry (the path matmul/dot 2-D already used). All six wrappers use it.
- vecmat/matvec/vecdot now count batch/broadcast axes exactly and bill in FMA=2.
- matmul/dot N-D & mixed-rank and inner N-D route through einsum instead of the a.size * b.size fallback.
- Contraction-order path search uses the same FMA=2 accumulation cost as billing; the legacy FMA=1 fallback is removed.
- linalg.lstsq uses matmul_cost directly (its 2-D×1-D workaround is no longer needed).
- Docstrings and notes now say FMA=2 and no longer reference the removed fma_cost setting.
Gaming-resistance
Complements the existing no-gaming property (symmetric cost ≤ dense) with the dual guard: contraction
cost cannot be under-counted by re-expressing a matmul as a batched vector op. New parity tests
assert every op equals its einsum equivalent across batched/broadcast/mixed-rank shapes and that
cost scales with the batch dimension.
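The batch-scaling property can be checked by hand for a batched matmul. The sketch below assumes the dense FMA=2 count (one multiply per (batch, m, k, n) tuple and k - 1 adds per output element) and contrasts it with the a.size * b.size fallback mentioned in this PR; both function names are hypothetical.

```python
from math import prod


def batched_matmul_cost(batch: int, m: int, k: int, n: int) -> int:
    """Dense FMA=2 cost of (batch, m, k) @ (batch, k, n)."""
    multiplies = batch * m * k * n
    adds = batch * m * n * (k - 1)        # k products per output element need k - 1 adds
    return multiplies + adds              # equals batch * (2*m*k*n - m*n)


def size_product_fallback(a_shape: tuple, b_shape: tuple) -> int:
    """The a.size * b.size style estimate, which is not linear in the batch size."""
    return prod(a_shape) * prod(b_shape)


single = batched_matmul_cost(1, 3, 4, 5)
batched = batched_matmul_cost(8, 3, 4, 5)
print(single, batched, size_product_fallback((8, 3, 4), (8, 4, 5)))
```

The exact count for a batch of 8 is 8 times the single-matmul count, which is the kind of linear scaling the parity tests are described as asserting; the size-product estimate grows quadratically with the batch size instead.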
Breaking change
FLOP costs change for the affected ops (exactness + FMA=2 unification). Consumers that pin or budget
on absolute FLOP counts should re-baseline.
A @ A symmetric behavior (symmetry-aware cost + SymmetricTensor output) and all 2-D costs already routed through einsum are unchanged.
Out of scope
- tensordot partial-contraction symmetry path (keeps direct_product_groups).
- A @ A.T symmetry detection for non-symmetric A (write as einsum).
- outer/kron/vdot (no contracted axis; already exact).
Test plan
- A @ A symmetry preserved.
- test_cost_formula_vs_code.py, test_issue_69_cost_parity.py[lstsq], test_fma_unification.py, PathInfo snapshots.
Test-suite robustness (no product code change)
tests/accumulation/test_deletion_safety.py now restores the package's lazy upstream __getattr__
shim after it imports the local _paths/_path_random submodules. Without this, those imports leaked
into the package __dict__ and shadowed the shim for the rest of the process, causing the
custom/random-optimizer tests in test_opt_einsum_paths.py to fail under serial (-n 0) execution
(they passed under xdist, which distributes the tests across workers).