Description
import pymc as pm
a = 0.5
x = pm.Normal.dist()
y = pm.math.switch(x > 0, x, a * x)
pm.logp(y, 2.3).eval()  # NotImplementedError
We already have a logprob derivation for mixture switches where the condition is constant, but not when the condition depends on the same measurable variable that appears in the branches. That case is not a mixture but an invertible transform.
We could support arbitrary functions on both branches as long as each domain retains the same sign after the transformation (so that it's easy to invert). To respect this, the leaky ReLU actually requires `a` to be positive, so a runtime check may be needed.
We could actually support arbitrary cutoff points, but it becomes increasingly tricky to figure out which branch to go down when inverting the graph.
In any case, because it is such a common transformation, it would be great to at least support the special case of the leaky ReLU.
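For reference, the logp we would want PyMC to derive can be sketched by hand with a change of variables: invert each branch and add the log-Jacobian of the scaled branch. This is a standalone NumPy/SciPy sketch (not PyMC API); `leaky_relu_logp` is a hypothetical helper, assuming `x ~ Normal(0, 1)` and `a > 0` as discussed above.

```python
import numpy as np
from scipy import stats

def leaky_relu_logp(y, a=0.5):
    # Invert the transform y = switch(x > 0, x, a * x), assuming a > 0
    # so each branch preserves the sign of its domain:
    #   y > 0  came from x = y      (identity branch, log-Jacobian 0)
    #   y <= 0 came from x = y / a  (scaled branch, log-Jacobian -log(a))
    x = np.where(y > 0, y, y / a)
    log_jac = np.where(y > 0, 0.0, -np.log(a))
    return stats.norm.logpdf(x) + log_jac
```

The same two-branch inversion is what a graph rewrite would have to do symbolically: recognize the switch on `x > 0`, check that each branch maps its half-line onto itself, and emit the branch-wise inverse plus Jacobian correction.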