-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathadversarial.py
33 lines (31 loc) · 1.11 KB
/
adversarial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
# Short alias for torch.nn.functional (the loss functions used by the attacks).
from torch.nn import functional as FUN
def FGSM(model, x, target, eps, device='cuda'):
    """Single-step targeted fast gradient sign method (FGSM).

    Moves ``x`` by ``eps`` against the sign of the gradient of
    ``nll_loss(model(x), target)``. This is a targeted step that lowers the loss
    for ``target``. The result is clamped to [0, 1].

    Args:
        model: Callable whose output is fed to ``torch.nn.functional.nll_loss``.
            It should therefore return log-probabilities, with classes on dim 1
            for batched input. Its parameters' ``requires_grad`` flags and
            ``.grad`` fields are left untouched.
        x: Input tensor. It is not modified in place.
        target: Class indices the attack steers the prediction towards.
        eps: Step size, i.e. the L-infinity magnitude of the perturbation.
        device: Device the computation runs on.

    Returns:
        A detached tensor on ``device``. If the model already predicts
        ``target`` for every sample, this is a copy of ``x``. Otherwise it is
        the perturbed input.
    """
    # Work on a private leaf copy. `.to()` returns the caller's tensor unchanged
    # when it is already on `device`, so flipping requires_grad on it directly
    # would mutate the caller's input (and fail for non-leaf tensors).
    x = x.to(device=device).clone().detach().requires_grad_(True)
    target = target.to(device=device)
    pred = model(x)
    # nll_loss puts classes on dim 1 for batched input and dim 0 for a single
    # unbatched sample; compare predicted classes, not raw scores, with target.
    class_dim = 1 if pred.dim() > 1 else 0
    if bool((pred.argmax(dim=class_dim) == target).all()):
        return x.detach()
    loss = FUN.nll_loss(pred, target)
    # Differentiate w.r.t. the input only: no parameter .grad is accumulated,
    # so the model does not need to be (permanently) frozen.
    (data_grad,) = torch.autograd.grad(loss, x)
    adv_ex = torch.clamp(x - eps * torch.sign(data_grad), 0, 1)
    return adv_ex.detach()
def PGD(model, x, target, eps, iters=30, rlr=1, device='cuda'):
    """Iterative targeted projected gradient descent (PGD) attack.

    Each of up to ``iters`` steps does the following:
    - Computes the gradient of ``nll_loss(model(adv), target)`` w.r.t. the
      current adversarial input.
    - Normalises it by its largest absolute entry over the whole tensor
      (batch included).
    - Subtracts ``eps * rlr`` times the normalised gradient.
    - Projects the result back into the L-infinity ball of radius ``eps``
      around ``x``.

    The loop stops early once the model predicts ``target`` for every sample.
    Unlike ``FGSM``, no clamp to [0, 1] is applied.

    Args:
        model: Callable whose output is fed to ``torch.nn.functional.nll_loss``.
            It should therefore return log-probabilities, with classes on dim 1
            for batched input. Its parameters' ``requires_grad`` flags and
            ``.grad`` fields are left untouched.
        x: Clean input tensor. It is not modified in place.
        target: Class indices the attack steers the prediction towards.
        eps: Radius of the L-infinity ball around ``x``. It also scales the
            step size.
        iters: Maximum number of gradient steps.
        rlr: Relative step size. Each step moves at most ``eps * rlr`` per entry.
        device: Device the computation runs on.

    Returns:
        A detached tensor on ``device`` with the adversarial example.
    """
    # Detached reference point for the projection; never mutated.
    x = x.to(device=device).detach()
    target = target.to(device=device)
    adv_ex = x.clone()
    for _ in range(iters):
        adv_ex.requires_grad_(True)
        pred = model(adv_ex)
        # nll_loss puts classes on dim 1 for batched input and dim 0 for a
        # single unbatched sample; compare predicted classes with target.
        class_dim = 1 if pred.dim() > 1 else 0
        if bool((pred.argmax(dim=class_dim) == target).all()):
            return adv_ex.detach()
        loss = FUN.nll_loss(pred, target)
        # Gradient w.r.t. the input only; parameters are not touched.
        (grad,) = torch.autograd.grad(loss, adv_ex)
        # Floor the normaliser so a vanishing gradient yields a zero step
        # instead of 0/0 = NaN. For any non-negligible gradient this is a no-op.
        scale = grad.abs().max().clamp_min(torch.finfo(grad.dtype).tiny)
        stepped = adv_ex.detach() - eps * rlr * grad / scale
        adv_ex = torch.clamp(stepped, x - eps, x + eps).detach()
    return adv_ex.detach()