AdamP low GPU usage #1211

Answered by rwightman
hadarshavit asked this question in Q&A
Apr 6, 2022 · 1 comment · 1 reply

@hadarshavit

this,

```python
def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
    wd = 1.
    expand_size = (-1,) + (1,) * (len(p.shape) - 1)
    for view_func in [_channel_view, _layer_view]:
        param_view = view_func(p)
        grad_view = view_func(grad)
        cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_()

        # FIXME this is a problem for PyTorch XLA
        if cosine_sim.max() < delta / math.sqrt(param_view.size(1)):
            p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
            perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
            wd = wd_ratio
            retu…
```
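
The snippet above is cut off in this capture, but the line marked FIXME suggests one likely cause of the low GPU utilization: `if cosine_sim.max() < ...` uses a tensor value in a Python condition, so PyTorch has to copy the result to the host and wait for the device on every call. Below is a minimal sketch of a branch-free variant that keeps the decision on the device with `torch.where`. It is not timm's implementation: `projection_no_sync` is a hypothetical name, the two view helpers are assumed to flatten per channel and per layer, and the `(perturb, wd)` return value is an assumption because the original return statement is truncated.

```python
import math

import torch
import torch.nn.functional as F


def _channel_view(x):
    # Assumed helper: one row per output channel (first dimension).
    return x.reshape(x.size(0), -1)


def _layer_view(x):
    # Assumed helper: the whole tensor flattened into a single row.
    return x.reshape(1, -1)


def projection_no_sync(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
    """Hypothetical variant with no Python `if` on a tensor value."""
    expand_size = (-1,) + (1,) * (len(p.shape) - 1)
    # Keep wd and the "already projected" flag as 0-dim tensors on p's device.
    wd = torch.ones((), device=p.device, dtype=p.dtype)
    done = torch.zeros((), device=p.device, dtype=torch.bool)
    for view_func in [_channel_view, _layer_view]:
        param_view = view_func(p)
        grad_view = view_func(grad)
        cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_()

        # 0-dim bool tensor; it is never converted to a Python bool.
        hit = (cosine_sim.max() < delta / math.sqrt(param_view.size(1))) & ~done

        # Always compute the projection, then select it only where `hit` is true.
        p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
        projected = perturb - p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
        perturb = torch.where(hit, projected, perturb)
        wd = torch.where(hit, torch.full_like(wd, wd_ratio), wd)
        done = done | hit  # mimics the early return of the quoted code
    return perturb, wd
```

Unlike the quoted code, this sketch always computes both projections and returns a new tensor rather than updating `perturb` in place, and `wd` comes back as a 0-dim tensor instead of a float. It trades extra arithmetic for the absence of a host synchronization, so whether it is actually faster depends on the model and the device.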

Answer selected by hadarshavit
2 participants