Work in progress

nikos-tri committed Aug 12, 2019
1 parent 8d0213e commit 1d9d94d
Showing 2 changed files with 1,331 additions and 22 deletions.
153 changes: 142 additions & 11 deletions src/stlcg.py
@@ -5,6 +5,11 @@
 import IPython
 # Assume inputs are already reversed.
 
+# Nikos TODO:
+# - Edit the "forward" methods to account for the possibility that they are receiving an input of type Expression
+# - Run tests to ensure that "Expression" correctly overrides operators
+# - Implement log-barrier
+
 LARGE_NUMBER = 1E4
 
 class Maxish(torch.nn.Module):
@@ -17,11 +22,17 @@ def forward(self, x, scale, dim=1):
         The default is
         x is of size [batch_size, max_dim, x_dim]
         if scale <= 0, then the true max is used, otherwise, the softmax is used.
-        '''
-        if scale > 0:
-            return (torch.softmax(x*scale, dim=dim)*x).sum(dim, keepdim=True)
+        '''
+        if isinstance(x, Expression):
+            if scale > 0:
+                return (torch.softmax(x.value*scale, dim=dim)*x.value).sum(dim, keepdim=True)
+            else:
+                return x.value.max(dim, keepdim=True)[0]
         else:
-            return x.max(dim, keepdim=True)[0]
+            if scale > 0:
+                return (torch.softmax(x*scale, dim=dim)*x).sum(dim, keepdim=True)
+            else:
+                return x.max(dim, keepdim=True)[0]
 
     def _next_function(self):
         # next function is actually input (traverses the graph backwards)
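As a side note for readers of this hunk, here is a hedged sketch (not part of the commit) of what the two branches of Maxish compute. With scale > 0 the result is the convex combination sum(softmax(scale*x) * x), which lies just below the true max and approaches it as the scale grows. It assumes Maxish() takes no constructor arguments, which this hunk does not show.

    import torch

    x = torch.tensor([[[0.0], [1.0], [3.0]]])   # [batch=1, max_dim=3, x_dim=1]
    maxish = Maxish()                            # assumed no-argument constructor

    exact = maxish(x, scale=0)                   # true max over dim=1 -> 3.0
    soft = maxish(x, scale=5)                    # soft max -> just below 3.0
    print(exact.item(), soft.item())             # 3.0, ~2.9999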
@@ -38,10 +49,16 @@ def forward(self, x, scale, dim=1):
         x is of size [batch_size, max_dim, ...]
         if scale <= 0, then the true min is used, otherwise, the softmin is used.
         '''
-        if scale > 0:
-            return (torch.softmax(-x*scale, dim=dim)*x).sum(dim, keepdim=True)
+        if isinstance(x, Expression):
+            if scale > 0:
+                return (torch.softmax(-x.value*scale, dim=dim)*x.value).sum(dim, keepdim=True)
+            else:
+                return x.value.min(dim, keepdim=True)[0]
         else:
-            return x.min(dim, keepdim=True)[0]
+            if scale > 0:
+                return (torch.softmax(-x*scale, dim=dim)*x).sum(dim, keepdim=True)
+            else:
+                return x.min(dim, keepdim=True)[0]
 
     def _next_function(self):
         # next function is actually input (traverses the graph backwards)
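The soft branches exist so that gradients flow to every time step rather than only to the single argmin/argmax entry. A minimal sketch, under the same no-argument-constructor assumption as above:

    import torch

    x = torch.tensor([[[2.0], [0.5], [1.0]]], requires_grad=True)
    minish = Minish()                    # assumed no-argument constructor

    soft_min = minish(x, scale=3)        # convex combination, >= the true min
    soft_min.sum().backward()
    print(x.grad.squeeze())              # every time step gets a nonzero gradient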
@@ -161,7 +178,10 @@ def _next_function(self):
         return [self.subformula]
 
     def forward(self, x, scale=0):
-        return self.robustness_trace(x, scale)
+        if isinstance(x, Expression):
+            return self.robustness_trace(x.value, scale)
+        else:
+            return self.robustness_trace(x, scale)
 
 
 class Always(Temporal_Operator):
@@ -228,7 +248,10 @@ def _next_function(self):
         return [self.name, self.c]
 
     def forward(self, x, scale=1):
-        return self.robustness_trace(x, scale)
+        if isinstance(x, Expression):
+            return self.robustness_trace(x.value, scale)
+        else:
+            return self.robustness_trace(x, scale)
 
     def __str__(self):
         return self.name + " <= " + tensor_to_str(self.c)
@@ -262,7 +285,10 @@ def _next_function(self):
         return [self.name, self.c]
 
     def forward(self, x, scale=1):
-        return self.robustness_trace(x, scale)
+        if isinstance(x, Expression):
+            return self.robustness_trace(x.value, scale)
+        else:
+            return self.robustness_trace(x, scale)
 
     def __str__(self):
         return self.name + " >= " + tensor_to_str(self.c)
@@ -295,7 +321,10 @@ def _next_function(self):
         return [self.name, self.c]
 
     def forward(self, x, scale=1):
-        return self.robustness_trace(x, scale)
+        if isinstance(x, Expression):
+            return self.robustness_trace(x.value, scale)
+        else:
+            return self.robustness_trace(x, scale)
 
     def __str__(self):
         return self.name + " = " + tensor_to_str(self.c)
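All four forward edits above apply the same pattern: unwrap an Expression to its underlying tensor, then dispatch to robustness_trace as before. A hedged usage sketch; the LessThan('x', c) constructor form is inferred from the _next_function/__str__ snippets rather than shown in this diff:

    import torch

    phi = LessThan('x', torch.tensor(0.5))   # assumed constructor form
    signal = torch.randn(1, 10, 1)           # [batch, time, dim], time-reversed

    # Wrapped and raw inputs now take the same path through forward
    assert torch.equal(phi(Expression(signal)), phi(signal))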
@@ -521,6 +550,108 @@ def forward(self, trace1, trace2, scale=0):
     def __str__(self):
         return "(" + str(self.subformula1) + ")" + " T " + "(" + str(self.subformula2) + ")"
 
+class Expression(torch.nn.Module):
+    '''
+    Wraps a pytorch arithmetic operation, so that we can intercept and overload comparison operators.
+    '''
+    def __init__(self, value):
+        super(Expression, self).__init__()
+        self.value = value
+
+    def __neg__(self):
+        return Expression(-self.value)
+
+    def __add__(self, other):
+        if isinstance(other, Expression):
+            return Expression(self.value + other.value)
+        else:
+            return Expression(self.value + other)
+
+    def __radd__(self, other):
+        return self.__add__(other)
+        # No need for the case when "other" is an Expression, since that
+        # case will be handled by the regular add
+
+    def __sub__(self, other):
+        if isinstance(other, Expression):
+            return Expression(self.value - other.value)
+        else:
+            return Expression(self.value - other)
+
+    def __rsub__(self, other):
+        return Expression(other - self.value)
+        # No need for the case when "other" is an Expression, since that
+        # case will be handled by the regular sub
+
+    def __mul__(self, other):
+        if isinstance(other, Expression):
+            return Expression(self.value * other.value)
+        else:
+            return Expression(self.value * other)
+
+    def __rmul__(self, other):
+        return self.__mul__(other)
+
+    def __truediv__(a, b):
+        # Python 3 dispatches the / operator here
+        numerator = a
+        denominator = b
+        if isinstance(numerator, Expression):
+            numerator = numerator.value
+        if isinstance(denominator, Expression):
+            denominator = denominator.value
+        return Expression(numerator/denominator)
+
+    # Comparators
+    def __lt__(lhs, rhs):
+        if isinstance(lhs, Expression) and isinstance(rhs, Expression):
+            return LessThan(lhs.value, rhs.value)
+        elif (not isinstance(lhs, Expression)) and (not isinstance(rhs, Expression)):
+            # This case cannot occur. If neither is an Expression, why are you calling this method?
+            raise Exception('What are you doing?')
+        elif not isinstance(rhs, Expression):
+            return LessThan(lhs.value, rhs)
+        elif not isinstance(lhs, Expression):
+            return LessThan(lhs, rhs.value)
+
+    def __le__(lhs, rhs):
+        raise NotImplementedError("Not supported yet")
+
+    def __gt__(lhs, rhs):
+        if isinstance(lhs, Expression) and isinstance(rhs, Expression):
+            return GreaterThan(lhs.value, rhs.value)
+        elif (not isinstance(lhs, Expression)) and (not isinstance(rhs, Expression)):
+            # This case cannot occur. If neither is an Expression, why are you calling this method?
+            raise Exception('What are you doing?')
+        elif not isinstance(rhs, Expression):
+            return GreaterThan(lhs.value, rhs)
+        elif not isinstance(lhs, Expression):
+            return GreaterThan(lhs, rhs.value)
+
+    def __ge__(lhs, rhs):
+        raise NotImplementedError("Not supported yet")
+
+    def __eq__(lhs, rhs):
+        if isinstance(lhs, Expression) and isinstance(rhs, Expression):
+            return Equal(lhs.value, rhs.value)
+        elif (not isinstance(lhs, Expression)) and (not isinstance(rhs, Expression)):
+            # This case cannot occur. If neither is an Expression, why are you calling this method?
+            raise Exception('What are you doing?')
+        elif not isinstance(rhs, Expression):
+            return Equal(lhs.value, rhs)
+        elif not isinstance(lhs, Expression):
+            return Equal(lhs, rhs.value)
+
+    def __ne__(lhs, rhs):
+        raise NotImplementedError("Not supported yet")
+
+    def __str__(self):
+        return str(self.value)
+
+    def __repr__(self):
+        return repr(self.value)
+
+
 # class STLModel(torch.nn.Module):
 #     def __init__(self, inner, outer):
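To see how the pieces fit together: arithmetic on an Expression stays an Expression, while a comparison hands off to a formula class. A minimal sketch, not from the commit; the formula constructors' expected arguments are defined elsewhere in stlcg.py, so the last line is illustrative only.

    import torch

    speed = Expression(torch.tensor([[[1.0], [2.0], [3.0]]]))  # wrapped signal

    scaled = (2 * speed + 1) / 3   # __rmul__, __add__, __truediv__ keep it an Expression
    print(scaled)                  # __str__/__repr__ delegate to the wrapped tensor

    phi = scaled > 0.5             # __gt__ dispatches to GreaterThan(scaled.value, 0.5)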