Description
I think that in the `lltm_backward` function in C++, the line

```cpp
auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
```

should be

```cpp
auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/false);
```
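For reference, here is a quick PyTorch sketch of the shape difference between the two variants (the sizes are made up for illustration; the ATen call in the C++ code behaves the same way):

```python
# Illustration only: how keepdim affects the shape of the summed bias gradient.
# The sizes (batch_size=4, 3 * state_size=15) are arbitrary.
import torch

d_gates = torch.randn(4, 15)

print(d_gates.sum(dim=0, keepdim=True).shape)   # torch.Size([1, 15])
print(d_gates.sum(dim=0, keepdim=False).shape)  # torch.Size([15])

# The gradient returned for bias has to match the shape of the bias
# parameter itself, which is what motivates the keepdim choice above.
```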
I also think that the `LLTMFunction` class, which inherits from `torch.autograd.Function`, contains two errors:
```python
class LLTMFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = lltm.forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights, old_cell]
        ctx.save_for_backward(*variables)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        outputs = lltm.backward(
            grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_variables)
        d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates = outputs
        return d_input, d_weights, d_bias, d_old_h, d_old_cell
```
should be
```python
class LLTMFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = lltm.forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights]
        ctx.save_for_backward(*variables)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        outputs = lltm.backward(
            grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors)
        d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs
        return d_input, d_weights, d_bias, d_old_h, d_old_cell
```
Essentially, this removes `old_cell` from the variables saved in `forward` for the backward pass, and drops `d_gates` from the gradients unpacked in `backward`.
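If it helps to verify the proposed version, here is a minimal `gradcheck` sketch. It assumes the compiled `lltm` extension is importable and that the corrected `LLTMFunction` above is in scope; the tensor sizes (and the 1-D bias shape) are assumptions made for illustration:

```python
import torch
from torch.autograd import gradcheck

batch_size, input_features, state_size = 3, 17, 5

# gradcheck needs double precision for reliable finite differences.
kwargs = {'dtype': torch.float64, 'requires_grad': True}

X = torch.randn(batch_size, input_features, **kwargs)
weights = torch.randn(3 * state_size, input_features + state_size, **kwargs)
bias = torch.randn(3 * state_size, **kwargs)            # assumed 1-D bias
old_h = torch.randn(batch_size, state_size, **kwargs)
old_cell = torch.randn(batch_size, state_size, **kwargs)

# Compares the analytical gradients from LLTMFunction.backward against
# numerical finite differences for every input that requires grad.
if gradcheck(LLTMFunction.apply, (X, weights, bias, old_h, old_cell)):
    print('gradcheck passed')
```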
I'm available to make a pull request with the fix.