Skip to content

Commit

Permalink
Added Expression class
Browse files Browse the repository at this point in the history
  • Loading branch information
nikos-tri committed Aug 12, 2019
1 parent 1d9d94d commit aee7570
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 185 deletions.
54 changes: 43 additions & 11 deletions src/stlcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import IPython
# Assume inputs are already reversed.

# Nikos TODO:
# TODO:
# - Edit the "forward" methods to account for the possibility that they are receiving an input of type Expression
# - Run tests to ensure that "Expression" correctly overrides operators
# - Make a test for each temporal operator, and make sure that they all produce the expected output for at least one example trace
# - Implement log-barrier

LARGE_NUMBER = 1E4
Expand Down Expand Up @@ -178,7 +179,7 @@ def _next_function(self):
return [self.subformula]

def forward(self, x, scale=0):
if isintance(x, Expression):
if isinstance(x, Expression):
return self.robustness_trace(x.value, scale)
else:
return self.robustness_trace(x, scale)
Expand Down Expand Up @@ -285,7 +286,7 @@ def _next_function(self):
return [self.name, self.c]

def forward(self, x, scale=1):
if isinstance(x, expression):
if isinstance(x, Expression):
return self.robustness_trace(x.value, scale)
else:
return self.robustness_trace(x, scale)
Expand Down Expand Up @@ -356,7 +357,10 @@ def _next_function(self):
return [self.subformula]

def forward(self, x, scale=1):
return self.robustness_trace(x, scale)
if isinstance(x, Expression):
return self.robustness_trace(x.value, scale)
else:
return self.robustness_trace(x, scale)

def __str__(self):
return "¬(" + str(self.subformula) + ")"
Expand Down Expand Up @@ -399,7 +403,14 @@ def _next_function(self):
return [self.subformula1, self.subformula2]

def forward(self, trace1, trace2, scale=0):
return self.robustness_trace(trace1, trace2, scale)
if isinstance(trace1, Expression) and isinstance(trace2, Expression):
return self.robustness_trace(trace1.value, trace2.value, scale)
elif isinstance(trace1, Expression):
return self.robustness_trace(trace1.value, trace2, scale)
elif isinstance(trace2, Expression):
return self.robustness_trace(trace1, trace2.value, scale)
else:
return self.robustness_trace(trace1, trace2, scale)

def __str__(self):
return "(" + str(self.subformula1) + ") ∧ (" + str(self.subformula2) + ")"
Expand Down Expand Up @@ -442,7 +453,14 @@ def _next_function(self):
return [self.subformula1, self.subformula2]

def forward(self, trace1, trace2, scale=0):
return self.robustness_trace(trace1, trace2, scale)
if isinstance(trace1, Expression) and isinstance(trace2, Expression):
return self.robustness_trace(trace1.value, trace2.value, scale)
elif isinstance(trace1, Expression):
return self.robustness_trace(trace1.value, trace2, scale)
elif isinstance(trace2, Expression):
return self.robustness_trace(trace1, trace2.value, scale)
else:
return self.robustness_trace(trace1, trace2, scale)

def __str__(self):
return "(" + str(self.subformula1) + ") ∨ (" + str(self.subformula2) + ")"
Expand Down Expand Up @@ -494,7 +512,14 @@ def _next_function(self):
return [self.subformula1, self.subformula2]

def forward(self, trace1, trace2, scale=0):
return self.robustness_trace(trace1, trace2, scale)
if isinstance(trace1, Expression) and isinstance(trace2, Expression):
return self.robustness_trace(trace1.value, trace2.value, scale)
elif isinstance(trace1, Expression):
return self.robustness_trace(trace1.value, trace2, scale)
elif isinstance(trace2, Expression):
return self.robustness_trace(trace1, trace2.value, scale)
else:
return self.robustness_trace(trace1, trace2, scale)

def __str__(self):
return "(" + str(self.subformula1) + ")" + " U " + "(" + str(self.subformula2) + ")"
Expand Down Expand Up @@ -545,7 +570,14 @@ def _next_function(self):
return [self.subformula1, self.subformula2]

def forward(self, trace1, trace2, scale=0):
return self.robustness_trace(trace1, trace2, scale)
if isinstance(trace1, Expression) and isinstance(trace2, Expression):
return self.robustness_trace(trace1.value, trace2.value, scale)
elif isinstance(trace1, Expression):
return self.robustness_trace(trace1.value, trace2, scale)
elif isinstance(trace2, Expression):
return self.robustness_trace(trace1, trace2.value, scale)
else:
return self.robustness_trace(trace1, trace2, scale)

def __str__(self):
return "(" + str(self.subformula1) + ")" + " T " + "(" + str(self.subformula2) + ")"
Expand Down Expand Up @@ -607,7 +639,7 @@ def __truediv__(a, b):
def __lt__(lhs, rhs):
if isinstance(lhs, Expression) and isinstance(rhs, Expression):
return LessThan(lhs.value, rhs.value)
elif (not isinstance(lhs, Expression)) and (not isintance(rhs, Expression)):
elif (not isinstance(lhs, Expression)) and (not isinstance(rhs, Expression)):
# This case cannot occur. If neither is an Expression, why are you calling this method?
raise Exception('What are you doing?')
elif not isinstance(rhs, Expression):
Expand All @@ -621,7 +653,7 @@ def __le__(lhs, rhs):
def __gt__(lhs, rhs):
if isinstance(lhs, Expression) and isinstance(rhs, Expression):
return GreaterThan(lhs.value, rhs.value)
elif (not isinstance(lhs, Expression)) and (not isintance(rhs, Expression)):
elif (not isinstance(lhs, Expression)) and (not isinstance(rhs, Expression)):
# This case cannot occur. If neither is an Expression, why are you calling this method?
raise Exception('What are you doing?')
elif not isinstance(rhs, Expression):
Expand All @@ -635,7 +667,7 @@ def __ge__(lhs, rhs):
def __eq__(lhs, rhs):
if isinstance(lhs, Expression) and isinstance(rhs, Expression):
return Equal(lhs.value, rhs.value)
elif (not isinstance(lhs, Expression)) and (not isintance(rhs, Expression)):
elif (not isinstance(lhs, Expression)) and (not isinstance(rhs, Expression)):
# This case cannot occur. If neither is an Expression, why are you calling this method?
raise Exception('What are you doing?')
elif not isinstance(rhs, Expression):
Expand Down
Loading

0 comments on commit aee7570

Please sign in to comment.