
Derive logprob of Leaky ReLU transform #7543

Open
@ricardoV94

Description


import pymc as pm

a = 0.5
x = pm.Normal.dist()
y = pm.math.switch(x > 0, x, a * x)
pm.logp(y, 2.3).eval()  # NotImplementedError

We already have a logprob derivation for mixture switches where the condition is constant, but not for switches whose condition depends on the same measurable variable that appears in the branches. This is not a mixture but an invertible transform.
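For the leaky ReLU with `a > 0`, the positive branch only produces positive values and the negative branch only produces non-positive values, so the transform can be inverted branch by branch and the density follows from a piecewise change of variables. A sketch of the derivation, assuming `x` has a density `p_X` (as `pm.Normal.dist()` does); the point `y = 0` has measure zero and does not affect the density:

```latex
y = \begin{cases} x, & x > 0 \\ a\,x, & x \le 0 \end{cases}
\qquad\Longrightarrow\qquad
x = \begin{cases} y, & y > 0 \\ y / a, & y \le 0 \end{cases}
\qquad (a > 0)

\log p_Y(y) = \begin{cases} \log p_X(y), & y > 0 \\ \log p_X(y / a) - \log a, & y \le 0 \end{cases}
```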

We could support arbitrary functions on both branches, as long as each branch preserves the sign of its domain after the transformation (positive inputs map to positive outputs, non-positive inputs to non-positive outputs), so that the inverse is easy to find. To respect this, the leaky ReLU requires `a` to be positive, so a runtime check may be needed.
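The piecewise inverse for the leaky ReLU case can be checked numerically with a standalone plain-Python sketch (no PyMC involved). `leaky_relu_logp` and `normal_logpdf` are hypothetical helpers for illustration, not PyMC API; the `a > 0` guard stands in for the runtime check discussed above.

```python
import math


def normal_logpdf(x: float, mu: float = 0.0, sigma: float = 1.0) -> float:
    """Log-density of Normal(mu, sigma), written out so no extra libraries are needed."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def leaky_relu_logp(value: float, a: float, base_logpdf=normal_logpdf) -> float:
    """Hypothetical helper: logp of y = switch(x > 0, x, a * x) evaluated at `value`."""
    # Runtime check: with a <= 0 the negative branch no longer keeps its sign,
    # so the branch images overlap (or collapse) and the inverse is not unique.
    if not a > 0:
        raise ValueError(f"leaky ReLU slope must be positive, got a={a}")
    if value > 0:
        # Positive values can only come from the identity branch: x = y, |dx/dy| = 1
        return base_logpdf(value)
    # Non-positive values come from the scaled branch: x = y / a, |dx/dy| = 1 / a
    return base_logpdf(value / a) - math.log(a)
```

For example, `leaky_relu_logp(2.3, 0.5)` should equal the standard normal log-density at 2.3, since positive values come only from the identity branch, while negative values pick up the `-log(a)` Jacobian term.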

We could actually support arbitrary cutoff points, but it becomes increasingly tricky to figure out which branch to go down when inverting the graph.
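With an arbitrary cutoff `c`, the hard part is deciding which branch produced a given value. One way to see the difficulty is the plain-Python sketch below, which assumes both branch functions have known inverses: invert the value through each branch and keep only the candidates that satisfy the condition of the branch they came from. `switch_preimages` is a hypothetical helper, not PyMC API.

```python
from typing import Callable, List


def switch_preimages(
    value: float,
    c: float,
    f_inv: Callable[[float], float],
    g_inv: Callable[[float], float],
) -> List[float]:
    """Hypothetical helper for y = switch(x > c, f(x), g(x)) with invertible f and g.

    Each branch's inverse proposes a candidate x; a candidate is kept only if it
    satisfies the condition of the branch it was inverted through.
    """
    candidates = []
    x_f = f_inv(value)
    if x_f > c:  # "true" branch: the preimage must satisfy x > c
        candidates.append(x_f)
    x_g = g_inv(value)
    if not x_g > c:  # "false" branch: the preimage must satisfy x <= c
        candidates.append(x_g)
    return candidates
```

Exactly one candidate means the value can be inverted unambiguously. No candidates means the value lies in a gap of the support (for example `c = 1` with slope `0.5` leaves values in `(0.5, 1]` unreachable), and two candidates means the transform is not injective there (for example `c = 1` with slope `2`, where values in `(1, 2]` are reached by both branches), so a single change of variables no longer applies. With `c = 0` and `a > 0` neither problem occurs.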

In any case, because it is such a common transformation, it would be great to at least support the special case of the leaky ReLU.
