
Gradient of scan fails when it involves a shared variable #555

Description

jessegrabowski (Member)

Before

Currently, this graph has valid gradients with respect to mu and sigma:

import pytensor.tensor as pt

mu = pt.dscalar('mu')
sigma = pt.dscalar('sigma')

epsilon = pt.random.normal(0, 1)
z = mu + sigma * epsilon

pt.grad(z, sigma).eval({mu: 1, sigma: 1})
# Out: a random draw from N(0, 1)

But this graph does not:

def step(x, mu, sigma, rng):
    epsilon = pt.random.normal(0, 1, rng=rng)
    next_x = x + mu + sigma * epsilon
    return next_x, {rng:new_rng}

traj, updates = pytensor.scan(step, outputs_info=[x0], non_sequences=[mu, sigma, rng], n_steps=10)
pt.grad(traj[-1], sigma).eval({mu:1, sigma:1, x0:0})
# Out: Error, graph depends on a shared variable

After

I imagine that in cases where the "reparameterization trick" is used, stochastic gradients can be computed for scan graphs.

Context for the issue:

The "reparameterization trick" is well known in the machine learning literature as a way to get stochastic gradients from graphs with sampling operations. It seems like we already support this, because this graph can be differentiated:

epsilon = pt.random.normal(0, 1)
z = mu + sigma * epsilon

pt.grad(z, sigma).eval({mu:1, sigma:1})

But this graph cannot:

z = pt.random.normal(mu, sigma)
pt.grad(z, sigma).eval({mu:1, sigma:1})

The fact that even the "good" version breaks down inside scan is, I suppose, a bug? Or a missing feature? Or neither? In the recursion:

$$x_{t+1} = x_t + \mu + \sigma \epsilon_t$$

with $x_0$ given, it seems like:

$$\frac{\partial x_2}{\partial \sigma} = \frac{\partial}{\partial \sigma}\left(x_0 + \mu + \sigma \epsilon_1 + \mu + \sigma \epsilon_2\right) = \epsilon_1 + \epsilon_2$$

I should get back the sum of the random draws for the sequence.
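A pure-NumPy finite-difference check of this derivation, holding the draws fixed (a sketch; the helper x2 and the seed are illustrative only):

import numpy as np

np_rng = np.random.default_rng(0)
eps = np_rng.normal(size=2)  # fixed draws standing in for epsilon_1, epsilon_2

def x2(sigma, mu=1.0, x0=0.0):
    # unroll x_{t+1} = x_t + mu + sigma * epsilon_t for two steps
    x = x0
    for e in eps:
        x = x + mu + sigma * e
    return x

h = 1e-6
fd = (x2(1.0 + h) - x2(1.0 - h)) / (2 * h)  # central finite difference at sigma = 1
np.testing.assert_allclose(fd, eps.sum(), rtol=1e-5)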

Context: I'm trying to use pytensor to compute greeks for options, which involves taking the derivative of sampled trajectories.

Activity

jessegrabowski (Member, Author) commented on Dec 13, 2023

This works:

import numpy as np
import pytensor
import pytensor.tensor as pt

# mu, sigma, x0 are the symbolic scalars defined in the issue body

def step(epsilon, x, mu, sigma):
    # scan passes sequence elements first, then the recurrent state, then non-sequences
    next_x = x + mu + sigma * epsilon
    return next_x

rng = pytensor.shared(np.random.default_rng())
# the normal node's outputs are (next_rng, draws); unpack both
new_rng, epsilons = pt.random.normal(size=10, rng=rng).owner.outputs
traj, updates = pytensor.scan(step, outputs_info=[x0], non_sequences=[mu, sigma], sequences=[epsilons])

df_ds = pt.grad(traj[-1], sigma)
f = pytensor.function([x0, mu, sigma], df_ds, updates={rng: new_rng})

So as long as I can write the noise as a conditionally independent sequence, it works? It seems like it should be possible to get the gradient without doing that, though.
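As a quick sanity check against the derivation in the issue body (for this linear recursion, df_ds should equal the sum of the drawn epsilons), the gradient and the draws can be compiled together so they share the same rng update. A sketch, reusing the names from the snippet above:

# both outputs are computed from the same draw, so they are consistent
f_check = pytensor.function([x0, mu, sigma], [df_ds, epsilons], updates={rng: new_rng})
grad_val, eps_val = f_check(0.0, 1.0, 1.0)
np.testing.assert_allclose(grad_val, eps_val.sum())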

ricardoV94 (Member) commented on Dec 14, 2023

This is a case where I am not sure we want to be too clever, for the user's sake. If you're taking gradients of stochastic graphs, perhaps you should know exactly what you're doing and do the reparametrization yourself (you can create your own suite of rewrites to change the graph before calling grad).

Note that we never do any rewrites of random variables by default (other than when certain distributions are missing in a backend), because depending on how the random generator routine is used, rewriting can alter the results. This is a decision the Theano devs took for reproducibility and ease of debugging; we can revisit it, but should do so consciously.
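For illustration, a minimal sketch of doing that reparametrization by hand before calling grad, swapping the location-scale RV for mu + sigma * N(0, 1) with graph_replace (the variable names here are hypothetical, not from the thread):

import pytensor.tensor as pt
from pytensor.graph.replace import graph_replace

mu = pt.dscalar("mu")
sigma = pt.dscalar("sigma")

# non-differentiable formulation: sigma enters only through the RV's parameters
z = pt.random.normal(mu, sigma)
cost = z ** 2

# substitute the reparameterized draw before taking the gradient
epsilon = pt.random.normal(0.0, 1.0)
cost_reparam = graph_replace(cost, {z: mu + sigma * epsilon})
dcost_dsigma = pt.grad(cost_reparam, sigma)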

jessegrabowski (Member, Author) commented on Dec 14, 2023

I guess the tags I chose for this issue are quite bad, because I don't think I want any kind of special automatic handling here. It's more that when scan constructs its gradient, it fails because it asks the random generator for a gradient, which it (obviously) doesn't have. Shouldn't generators have a pass-through? If there's another complication (because an actual random variable -- NOT a generator -- is on the backwards graph) it can and should still error, I agree.

ricardoV94 (Member) commented on Dec 14, 2023

Can you provide a full example? Your original one has new_rng, but that's not defined anywhere. strict=True may help.

The only issue with scan gradients that I know of is that they must be passed explicitly: #6

jessegrabowski (Member, Author) commented on Dec 14, 2023
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

mu = pt.dscalar('mu')
sigma = pt.dscalar('sigma')
x0 = pt.dscalar('x0')
rng = pytensor.shared(np.random.default_rng(), 'rng')

def step(x, mu, sigma, rng):
    new_rng, epsilon = pm.Normal.dist(0, 1, rng=rng).owner.outputs
    next_x = x + mu + sigma * epsilon
    return next_x, {rng: new_rng}

traj, updates = pytensor.scan(step, outputs_info=[x0], non_sequences=[mu, sigma, rng], n_steps=10)
pt.grad(traj[-1], sigma)

Gives:

Traceback:

NullTypeGradError                         Traceback (most recent call last)
Cell In[132], line 10
      7     return next_x, {rng:new_rng}
      9 traj, updates = pytensor.scan(step, outputs_info=[x0], non_sequences=[mu, sigma, rng], n_steps=10)
---> 10 pt.grad(traj[-1], sigma).eval({mu:1, sigma:1, x0:0})

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py:616, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
    614 if isinstance(_rval[i].type, NullType):
    615     if null_gradients == "raise":
--> 616         raise NullTypeGradError(
    617             f"grad encountered a NaN. {_rval[i].type.why_null}"
    618         )
    619     else:
    620         assert null_gradients == "return"

NullTypeGradError: grad encountered a NaN. This variable is Null because the grad method for input 3 (mu) of the Scan{scan_fn, while_loop=False, inplace=none} op is not implemented. Depends on a shared variable

ricardoV94 (Member) commented on Dec 15, 2023

This seems to be an old known bug/limitation of Scan: https://groups.google.com/g/theano-users/c/dAwr1j8-QOY/m/8fmDmQPkPJkJ

Maybe something we can also address better in the Scan refactor, since we don't treat shared variables as magical entities anymore.

The title was changed from "ENH: Stochastic gradients of `scan` graphs" to "Gradient of scan fails when it involves a shared variable" on Dec 15, 2023
jessegrabowski (Member, Author) commented on Dec 15, 2023

First refactor all of pytensor (and pymc) to remove shared variables? :D

ricardoV94 (Member) commented on Dec 15, 2023

Shared variables are fine-ish; it's treating them differently in PyTensor internals that's a source of unnecessary complexity.

This special treatment is also a thorn in OpFromGraph: #473

Which led me to basically reimplement it completely to be usable in PyMC: pymc-devs/pymc#6947

jessegrabowski (Member, Author) commented on Dec 15, 2023

I just read #473, so your thoughts on OpFromGraph and shared variables are fresh in my mind.

ricardoV94 (Member) commented on Dec 15, 2023

In sum, I think shared variables should be explicit inputs everywhere except in the outer PyTensor function, where there is ambiguity about whether the call signature should require them or not.

ricardoV94 (Member) commented on Jan 6, 2024

I reopened this; until we have a solution, I think it's good to track the issue.

ricardoV94 (Member) commented on Feb 14, 2025

This issue remains open after #1207

Here is a minimal reproducible example:

import pytensor
import pytensor.tensor as pt
import numpy as np

sigma = pt.dscalar('sigma')
x0 = pt.dscalar('x0')
rng = pytensor.shared(np.random.default_rng(), 'rng')

def step(x, sigma, rng):
    new_rng, epsilon = pt.random.normal(0, 1, rng=rng).owner.outputs
    next_x = x + sigma * epsilon
    return next_x, {rng:new_rng}

traj, updates = pytensor.scan(step, outputs_info=[x0], non_sequences=[sigma, rng], n_steps=10)
pt.grad(traj[-1], sigma)

Scan is too worried about shared variables. It doesn't scan them (it just returns the last value), and in some cases, when it sees gradients that touch shared variables however slightly, it returns Null. In this case, since the RV could legally be split out of the Scan that computes x and passed in as a sequence, it's definitely legal to compute its gradient.

I suspect we need to tweak something more in the 770-line-long implementation of Scan.L_op :)
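For anyone hitting this in the meantime, a sketch of that splitting done by hand on the example above: draw the noise outside the scan and feed it in as a sequence (the same workaround as in the earlier comment):

import numpy as np
import pytensor
import pytensor.tensor as pt

sigma = pt.dscalar('sigma')
x0 = pt.dscalar('x0')
rng = pytensor.shared(np.random.default_rng(), 'rng')

# draw all ten epsilons up front; the node's outputs are (next_rng, draws)
next_rng, epsilons = pt.random.normal(0, 1, size=(10,), rng=rng).owner.outputs

def step(epsilon, x, sigma):
    # sequences come first, then the recurrent state, then non-sequences
    return x + sigma * epsilon

traj, _ = pytensor.scan(step, sequences=[epsilons], outputs_info=[x0], non_sequences=[sigma])
df_dsigma = pt.grad(traj[-1], sigma)  # works: no shared variable inside the inner graph
f = pytensor.function([x0, sigma], df_dsigma, updates={rng: next_rng})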
