diff --git a/src/stlcg.py b/src/stlcg.py index 6b9ef94..da279d2 100644 --- a/src/stlcg.py +++ b/src/stlcg.py @@ -744,8 +744,8 @@ def robustness_trace(self, inputs, pscale=1, scale=-1, keepdim=True, agm=False, minish = Minish() maxish = Maxish() LHS = trace2.unsqueeze(-1).repeat([1, 1, 1,trace2.shape[1]]).permute(0, 3, 2, 1) - RHS = torch.ones_like(LHS)*-LARGE_NUMBER if interval == None: + RHS = torch.ones_like(LHS)*-LARGE_NUMBER for i in range(trace2.shape[1]): RHS[:,i:,:,i] = Alw(trace1[:,i:,:]) return maxish( @@ -758,7 +758,7 @@ def robustness_trace(self, inputs, pscale=1, scale=-1, keepdim=True, agm=False, for i in range(b,trace2.shape[1]): A = trace2[:,i-b:i-a+1,:].unsqueeze(-1) relevant = trace1[:,:i+1,:] - B = Alw(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:b+1,:].unsqueeze(-1) + B = Alw(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:b+1,:].flip(1).unsqueeze(-1) RHS.append(maxish(minish(torch.cat([A,B], dim=-1), dim=-1, scale=scale, keepdim=False, distributed=distributed), dim=1, scale=scale, keepdim=keepdim, distributed=distributed)) return torch.cat(RHS, dim=1); else: @@ -767,7 +767,7 @@ def robustness_trace(self, inputs, pscale=1, scale=-1, keepdim=True, agm=False, for i in range(a,trace2.shape[1]): A = trace2[:,:i-a+1,:].unsqueeze(-1) relevant = trace1[:,:i+1,:] - B = Alw(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:,:].unsqueeze(-1) + B = Alw(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:,:].flip(1).unsqueeze(-1) RHS.append(maxish(minish(torch.cat([A,B], dim=-1), dim=-1, scale=scale, keepdim=False, distributed=distributed), dim=1, scale=scale, keepdim=keepdim, distributed=distributed)) return torch.cat(RHS, dim=1); @@ -823,7 +823,7 @@ def robustness_trace(self, inputs, pscale=1, scale=-1, keepdim=True, agm=False, for i in range(b,trace2.shape[1]): A = trace2[:,i-b:i-a+1,:].unsqueeze(-1) relevant = trace1[:,:i+1,:] - B = Ev(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:b+1,:].unsqueeze(-1) + B = Ev(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:b+1,:].flip(1).unsqueeze(-1) RHS.append(maxish(minish(torch.cat([A,B], dim=-1), dim=-1, scale=scale, keepdim=False, distributed=distributed), dim=1, scale=scale, keepdim=keepdim, distributed=distributed)) return torch.cat(RHS, dim=1); else: @@ -832,7 +832,7 @@ def robustness_trace(self, inputs, pscale=1, scale=-1, keepdim=True, agm=False, for i in range(a,trace2.shape[1]): A = trace2[:,:i-a+1,:].unsqueeze(-1) relevant = trace1[:,:i+1,:] - B = Ev(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:,:].unsqueeze(-1) + B = Ev(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:,:].flip(1).unsqueeze(-1) RHS.append(maxish(minish(torch.cat([A,B], dim=-1), dim=-1, scale=scale, keepdim=False, distributed=distributed), dim=1, scale=scale, keepdim=keepdim, distributed=distributed)) return torch.cat(RHS, dim=1); # [batch_size, time_dim, x_dim]