Skip to content

Commit

Permalink
updates with more comments and descriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
karenl7 committed Aug 3, 2021
2 parents b60d4bd + ce882d8 commit 883aa8f
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/stlcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,8 @@ def robustness_trace(self, inputs, pscale=1, scale=-1, keepdim=True, agm=False,
minish = Minish()
maxish = Maxish()
LHS = trace2.unsqueeze(-1).repeat([1, 1, 1,trace2.shape[1]]).permute(0, 3, 2, 1)
RHS = torch.ones_like(LHS)*-LARGE_NUMBER
if interval == None:
RHS = torch.ones_like(LHS)*-LARGE_NUMBER
for i in range(trace2.shape[1]):
RHS[:,i:,:,i] = Alw(trace1[:,i:,:])
return maxish(
Expand All @@ -758,7 +758,7 @@ def robustness_trace(self, inputs, pscale=1, scale=-1, keepdim=True, agm=False,
for i in range(b,trace2.shape[1]):
A = trace2[:,i-b:i-a+1,:].unsqueeze(-1)
relevant = trace1[:,:i+1,:]
B = Alw(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:b+1,:].unsqueeze(-1)
B = Alw(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:b+1,:].flip(1).unsqueeze(-1)
RHS.append(maxish(minish(torch.cat([A,B], dim=-1), dim=-1, scale=scale, keepdim=False, distributed=distributed), dim=1, scale=scale, keepdim=keepdim, distributed=distributed))
return torch.cat(RHS, dim=1);
else:
Expand All @@ -767,7 +767,7 @@ def robustness_trace(self, inputs, pscale=1, scale=-1, keepdim=True, agm=False,
for i in range(a,trace2.shape[1]):
A = trace2[:,:i-a+1,:].unsqueeze(-1)
relevant = trace1[:,:i+1,:]
B = Alw(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:,:].unsqueeze(-1)
B = Alw(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:,:].flip(1).unsqueeze(-1)
RHS.append(maxish(minish(torch.cat([A,B], dim=-1), dim=-1, scale=scale, keepdim=False, distributed=distributed), dim=1, scale=scale, keepdim=keepdim, distributed=distributed))
return torch.cat(RHS, dim=1);

Expand Down Expand Up @@ -823,7 +823,7 @@ def robustness_trace(self, inputs, pscale=1, scale=-1, keepdim=True, agm=False,
for i in range(b,trace2.shape[1]):
A = trace2[:,i-b:i-a+1,:].unsqueeze(-1)
relevant = trace1[:,:i+1,:]
B = Ev(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:b+1,:].unsqueeze(-1)
B = Ev(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:b+1,:].flip(1).unsqueeze(-1)
RHS.append(maxish(minish(torch.cat([A,B], dim=-1), dim=-1, scale=scale, keepdim=False, distributed=distributed), dim=1, scale=scale, keepdim=keepdim, distributed=distributed))
return torch.cat(RHS, dim=1);
else:
Expand All @@ -832,7 +832,7 @@ def robustness_trace(self, inputs, pscale=1, scale=-1, keepdim=True, agm=False,
for i in range(a,trace2.shape[1]):
A = trace2[:,:i-a+1,:].unsqueeze(-1)
relevant = trace1[:,:i+1,:]
B = Ev(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:,:].unsqueeze(-1)
B = Ev(relevant.flip(1), scale=scale, keepdim=keepdim, distributed=distributed)[:,a:,:].flip(1).unsqueeze(-1)
RHS.append(maxish(minish(torch.cat([A,B], dim=-1), dim=-1, scale=scale, keepdim=False, distributed=distributed), dim=1, scale=scale, keepdim=keepdim, distributed=distributed))
return torch.cat(RHS, dim=1); # [batch_size, time_dim, x_dim]

Expand Down

0 comments on commit 883aa8f

Please sign in to comment.