
Conversation

vasqu
Contributor

@vasqu vasqu commented Oct 18, 2024

Fixes #594
Fixes #496
Fixes #517

#584 broke main for older torch versions, e.g. torch==2.2.0. There have been similar attempts at fixing this in #501 and #560.
The main difference here is to create a wrapper around the decorator that passes the new kwarg if necessary and changes the import depending on the torch version (based on whether amp has custom_(fwd|bwd)). I moved it to the utils folder, open to changing the structure.

P.S. Verified the fix with torch==2.2.0 (an older version that doesn't have custom_(fwd|bwd) in amp) and torch==2.5.0 (a newer version that does).
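
To illustrate the idea, here is a rough stdlib-only sketch of the wrapper pattern, with hypothetical stand-in decorators in place of torch's actual custom_fwd/custom_bwd (no torch required to run it): depending on the detected API, the extra device_type kwarg is either injected or omitted.

```python
from typing import Callable


def custom_amp_decorator(dec: Callable, needs_device_kwarg: bool):
    # Wrap an existing decorator so the device_type kwarg is only
    # passed when the newer API expects it.
    def decorator(func):
        extra = {"device_type": "cuda"} if needs_device_kwarg else {}
        return dec(func, **extra)
    return decorator


# Stand-in for the "new" torch.amp.custom_fwd, which requires device_type.
def new_custom_fwd(func, *, device_type):
    func.device_type = device_type
    return func


# Stand-in for the "old" torch.cuda.amp.custom_fwd, which takes no such kwarg.
def old_custom_fwd(func):
    return func


custom_fwd_new = custom_amp_decorator(new_custom_fwd, True)
custom_fwd_old = custom_amp_decorator(old_custom_fwd, False)


@custom_fwd_new
def forward_new(x):
    return x


@custom_fwd_old
def forward_old(x):
    return x
```

In the real fix, the `needs_device_kwarg` flag is derived from the installed torch version, so callers import one `custom_fwd`/`custom_bwd` pair that works on both old and new torch.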

@naromero77amd

If possible, it would be good to maintain backwards compatibility with the last two versions of PyTorch. I would like to see this PR land.

@tridao tridao merged commit 83a5c90 into state-spaces:main Oct 25, 2024
@vasqu vasqu deleted the fix-fwd-bwd-for-older-torch branch October 25, 2024 22:11
@KokeCacao
Contributor

KokeCacao commented Oct 26, 2024

In older versions, this expression is allowed: @custom_fwd(cast_inputs=torch.float32)

Would it be better to change it to the following? See #608

from typing import Callable

import torch


def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
    def decorator(*args, **kwargs):
        if cuda_amp_deprecated:
            kwargs["device_type"] = "cuda"
        return dec(*args, **kwargs)
    return decorator


if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
    deprecated = True
    from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
else:
    deprecated = False
    from torch.cuda.amp import custom_fwd, custom_bwd

custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
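
A stdlib-only illustration of why the *args/**kwargs form handles both call styles, bare @custom_fwd and parameterized @custom_fwd(cast_inputs=...). The fake_custom_fwd below is an assumed stand-in that mimics torch.amp.custom_fwd's dual calling convention; it is not the real torch API.

```python
from typing import Callable

calls = []


def fake_custom_fwd(fwd=None, *, cast_inputs=None, device_type=None):
    # Mimics the dual calling convention: called bare it receives the
    # function directly; called with kwargs it returns a decorator.
    if fwd is None:
        def inner(func):
            calls.append((func.__name__, cast_inputs, device_type))
            return func
        return inner
    calls.append((fwd.__name__, cast_inputs, device_type))
    return fwd


def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
    def decorator(*args, **kwargs):
        if cuda_amp_deprecated:
            kwargs["device_type"] = "cuda"
        return dec(*args, **kwargs)
    return decorator


custom_fwd = custom_amp_decorator(fake_custom_fwd, True)


@custom_fwd
def bare(x):
    return x


@custom_fwd(cast_inputs="float32")
def parameterized(x):
    return x
```

In the bare case the decorated function arrives as a positional arg and device_type is injected alongside it; in the parameterized case only kwargs flow through, so cast_inputs is preserved and device_type is still added. A wrapper that only accepted a single function argument would break the second style.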
