Errors in Custom C++ and CUDA Extensions #341

Open
@iacolippo

I think that in the C++ lltm_backward function, the line

auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);

should be

auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/false);
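The two variants differ only in the shape of the result. Here is a minimal Python sketch of that difference, using the equivalent torch.Tensor.sum call. The sizes are hypothetical, and the layout of d_gates as (batch_size, 3 * state_size) is an assumption made for illustration; which shape is right depends on how the bias parameter is declared in the module.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch_size, state_size = 4, 5

# Assumed layout: one row per batch element, three gates per state unit.
d_gates = torch.randn(batch_size, 3 * state_size)

# keepdim=True keeps the reduced batch dimension as a size-1 axis.
d_bias_keep = d_gates.sum(dim=0, keepdim=True)

# keepdim=False drops the reduced dimension entirely.
d_bias_flat = d_gates.sum(dim=0, keepdim=False)

print(d_bias_keep.shape)  # torch.Size([1, 15])
print(d_bias_flat.shape)  # torch.Size([15])
```

If the bias is a 1-D tensor of length 3 * state_size, only the keepdim=false result has the same shape as the parameter it is the gradient of.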

I also think the LLTMFunction class, which inherits from torch.autograd.Function, contains two errors.

class LLTMFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = lltm.forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights, old_cell]
        ctx.save_for_backward(*variables)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        outputs = lltm.backward(
            grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_variables)
        d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates = outputs
        return d_input, d_weights, d_bias, d_old_h, d_old_cell

should be

class LLTMFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = lltm.forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights]
        ctx.save_for_backward(*variables)
        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        outputs = lltm.backward(
            grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors)
        d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs
        return d_input, d_weights, d_bias, d_old_h, d_old_cell

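In short, the fix removes old_cell from the variables saved in forward for use in backward, and removes d_gates from the gradients unpacked from the result of lltm.backward. The corrected version also reads ctx.saved_tensors instead of ctx.saved_variables.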
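For reference, here is a minimal sketch of the save_for_backward / saved_tensors contract that the corrected class relies on. ScaledProduct is a hypothetical toy function, not part of the tutorial; it only illustrates that saved tensors come back in the order they were saved and that backward returns one gradient per forward input.

```python
import torch


class ScaledProduct(torch.autograd.Function):
    """Hypothetical toy op (y = x * w), used only to show the contract."""

    @staticmethod
    def forward(ctx, x, w):
        # Save exactly the tensors that backward will need.
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, grad_out):
        # saved_tensors returns them in the order they were saved.
        x, w = ctx.saved_tensors
        # One gradient per forward input, in the same order: (x, w).
        return grad_out * w, grad_out * x


x = torch.randn(3, dtype=torch.double, requires_grad=True)
w = torch.randn(3, dtype=torch.double, requires_grad=True)

# Backpropagate through the custom function.
ScaledProduct.apply(x, w).sum().backward()
```

After the backward pass, x.grad should equal w and w.grad should equal x. The same bookkeeping applies to LLTMFunction: the list passed to save_for_backward has to line up with the arguments lltm.backward expects, and the tuple returned from backward has to line up with the inputs of forward.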

I'm available to open a pull request with the fix.

Labels: C++ (Issues relating to C++ tutorials), CUDA (Issues relating to CUDA)