https://arxiv.org/abs/1911.09665 (AdvProp, "Adversarial Examples Improve Image Recognition")
In the paper, they propose computing two losses: one from a forward pass that uses the "clean" BN layers, and another from a forward pass that uses the auxiliary BN layers for adversarial examples. The two losses are then summed and backpropagated in a single backward pass through both BN paths (joint optimization).
Does the following look correct to you:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, stride=2)
        self.bnC = nn.BatchNorm2d(32)  # main BN branch for clean examples
        self.bnA = nn.BatchNorm2d(32)  # auxiliary BN branch for adversarial examples
        self.relu = nn.ReLU()
        self.linear = nn.Linear(32 * 14 * 14, 10)  # 14x14 assumes 3x32x32 inputs

    def forward(self, x, clean=True):
        x = self.conv1(x)
        # route the batch through the clean or the adversarial BN branch
        if clean:
            x = self.bnC(x)
        else:
            x = self.bnA(x)
        x = self.relu(x)
        x = x.flatten(1)  # flatten before the fully connected layer
        x = self.linear(x)
        return x
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for i in range(1000):
    batchC, targetC = get_clean_batch()
    batchA, targetA = get_adv_batch()

    # one forward pass per BN branch
    outputC = model(batchC, clean=True)
    outputA = model(batchA, clean=False)

    lossC = loss_fn(outputC, targetC)
    lossA = loss_fn(outputA, targetA)
    loss = lossC + lossA  # joint loss, single backward pass through both BN paths

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
If so, how would you propagate the clean argument to all the blocks, especially the ones built from nn.Sequential containers?
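One idea I had (just a sketch, not from the paper's code; DualBatchNorm2d and set_bn_mode are names I made up for illustration, using the same imports as above): wrap both BN branches in a small module that keeps its own mode flag, so blocks inside nn.Sequential never need to receive the clean argument explicitly.

class DualBatchNorm2d(nn.Module):
    def __init__(self, num_features):
        super(DualBatchNorm2d, self).__init__()
        self.bnC = nn.BatchNorm2d(num_features)  # clean branch
        self.bnA = nn.BatchNorm2d(num_features)  # auxiliary (adversarial) branch
        self.clean = True  # which branch the next forward pass uses

    def forward(self, x):
        return self.bnC(x) if self.clean else self.bnA(x)

def set_bn_mode(model, clean):
    # flip every DualBatchNorm2d in the model, including ones nested in nn.Sequential
    for m in model.modules():
        if isinstance(m, DualBatchNorm2d):
            m.clean = clean

The training loop would then call set_bn_mode(model, True) before the clean forward pass and set_bn_mode(model, False) before the adversarial one, leaving the nn.Sequential bodies unchanged.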
Is there some existing AdvProp code to look at?