@deerlord-tsuno commented Jan 10, 2025

This PR changes `_pow` and `Pow.backward` to fix some bugs and expand their functionality. I've verified these changes work at least for my use case, but any refinements are welcome.

### `_pow(a, n)`

The current implementation only supports positive integer exponents and behaves unexpectedly when the value of `n` passed in is not a positive integer. This PR changes `_pow` to work for all real values of `n` and aligns the function's code with the library's other recursive functions.
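For reference, here is a minimal sketch of the shape this change takes. It assumes the library's other recursive functions walk nested lists elementwise; the structure here is illustrative, not the actual repository code.

```python
def _pow(a, n):
    # Recurse into nested lists, mirroring the library's other
    # recursive elementwise helpers (assumed structure).
    if isinstance(a, list):
        return [_pow(x, n) for x in a]
    # Python's ** operator supports any real exponent n,
    # not just positive integers.
    return a ** n
```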

### `Pow.backward(dz, z)`

This function handles the backprop gradients for `Pow.forward(a, n)`, which returns $z = a^n$ for a constant exponent $n$.

Currently, `Pow.backward(dz, z)` calculates $\frac{dz}{da}=2a$ regardless of the value of $n$. This is correct when the input tensor is squared, as in `a.pow(2)`, but for any other exponent the function computes the wrong derivative and propagates incorrect gradients. This PR changes `Pow.backward` to calculate $\frac{dz}{da}=na^{n-1}$ as intended. To achieve this, `Pow.forward(a, n)` now stores `n` in the cache during the forward pass, and the backward pass retrieves it from there.
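A minimal sketch of the forward/backward pair after this change, assuming a simple class-level cache; the real library's cache mechanism and calling convention may differ.

```python
class Pow:
    cache = {}  # hypothetical cache; the actual cache mechanism may differ

    @staticmethod
    def forward(a, n):
        # Store a and n so backward can recover the exponent.
        Pow.cache["a"], Pow.cache["n"] = a, n
        return a ** n

    @staticmethod
    def backward(dz, z):
        a, n = Pow.cache["a"], Pow.cache["n"]
        # dz/da = n * a^(n-1); the chain rule multiplies by the
        # upstream gradient dz.
        return dz * n * a ** (n - 1)
```

For $n = 2$ this reduces to the old behavior ($2a$), so squaring still works, while other exponents now get the correct gradient.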
