diff --git a/src/stlcg.py b/src/stlcg.py index 504065f..22c7538 100644 --- a/src/stlcg.py +++ b/src/stlcg.py @@ -5,9 +5,10 @@ import IPython # Assume inputs are already reversed. -# Nikos TODO: +# TODO: # - Edit the "forward" methods to account for the possibility that they are receiving an input of type Expression # - Run tests to ensure that "Expression" correctly overrides operators +# - Make a test for each temporal operator, and make sure that they all produce the expected output for at least one example trace # - Implement log-barrier LARGE_NUMBER = 1E4 @@ -178,7 +179,7 @@ def _next_function(self): return [self.subformula] def forward(self, x, scale=0): - if isintance(x, Expression): + if isinstance(x, Expression): return self.robustness_trace(x.value, scale) else: return self.robustness_trace(x, scale) @@ -285,7 +286,7 @@ def _next_function(self): return [self.name, self.c] def forward(self, x, scale=1): - if isinstance(x, expression): + if isinstance(x, Expression): return self.robustness_trace(x.value, scale) else: return self.robustness_trace(x, scale) @@ -356,7 +357,10 @@ def _next_function(self): return [self.subformula] def forward(self, x, scale=1): - return self.robustness_trace(x, scale) + if isinstance(x, Expression): + return self.robustness_trace(x.value, scale) + else: + return self.robustness_trace(x, scale) def __str__(self): return "¬(" + str(self.subformula) + ")" @@ -399,7 +403,14 @@ def _next_function(self): return [self.subformula1, self.subformula2] def forward(self, trace1, trace2, scale=0): - return self.robustness_trace(trace1, trace2, scale) + if isinstance(trace1, Expression) and isinstance(trace2, Expression): + return self.robustness_trace(trace1.value, trace2.value, scale) + elif isinstance(trace1, Expression): + return self.robustness_trace(trace1.value, trace2, scale) + elif isinstance(trace2, Expression): + return self.robustness_trace(trace1, trace2.value, scale) + else: + return self.robustness_trace(trace1, trace2, scale) def __str__(self): return "(" + str(self.subformula1) + ") ∧ (" + str(self.subformula2) + ")" @@ -442,7 +453,14 @@ def _next_function(self): return [self.subformula1, self.subformula2] def forward(self, trace1, trace2, scale=0): - return self.robustness_trace(trace1, trace2, scale) + if isinstance(trace1, Expression) and isinstance(trace2, Expression): + return self.robustness_trace(trace1.value, trace2.value, scale) + elif isinstance(trace1, Expression): + return self.robustness_trace(trace1.value, trace2, scale) + elif isinstance(trace2, Expression): + return self.robustness_trace(trace1, trace2.value, scale) + else: + return self.robustness_trace(trace1, trace2, scale) def __str__(self): return "(" + str(self.subformula1) + ") ∨ (" + str(self.subformula2) + ")" @@ -494,7 +512,14 @@ def _next_function(self): return [self.subformula1, self.subformula2] def forward(self, trace1, trace2, scale=0): - return self.robustness_trace(trace1, trace2, scale) + if isinstance(trace1, Expression) and isinstance(trace2, Expression): + return self.robustness_trace(trace1.value, trace2.value, scale) + elif isinstance(trace1, Expression): + return self.robustness_trace(trace1.value, trace2, scale) + elif isinstance(trace2, Expression): + return self.robustness_trace(trace1, trace2.value, scale) + else: + return self.robustness_trace(trace1, trace2, scale) def __str__(self): return "(" + str(self.subformula1) + ")" + " U " + "(" + str(self.subformula2) + ")" @@ -545,7 +570,14 @@ def _next_function(self): return [self.subformula1, self.subformula2] def forward(self, trace1, trace2, scale=0): - return self.robustness_trace(trace1, trace2, scale) + if isinstance(trace1, Expression) and isinstance(trace2, Expression): + return self.robustness_trace(trace1.value, trace2.value, scale) + elif isinstance(trace1, Expression): + return self.robustness_trace(trace1.value, trace2, scale) + elif isinstance(trace2, Expression): + return self.robustness_trace(trace1, trace2.value, scale) + else: + return self.robustness_trace(trace1, trace2, scale) def __str__(self): return "(" + str(self.subformula1) + ")" + " T " + "(" + str(self.subformula2) + ")" @@ -607,7 +639,7 @@ def __truediv__(a, b): def __lt__(lhs, rhs): if isinstance(lhs, Expression) and isinstance(rhs, Expression): return LessThan(lhs.value, rhs.value) - elif (not isinstance(lhs, Expression)) and (not isintance(rhs, Expression)): + elif (not isinstance(lhs, Expression)) and (not isinstance(rhs, Expression)): # This case cannot occur. If neither is an Expression, why are you calling this method? raise Exception('What are you doing?') elif not isinstance(rhs, Expression): @@ -621,7 +653,7 @@ def __le__(lhs, rhs): def __gt__(lhs, rhs): if isinstance(lhs, Expression) and isinstance(rhs, Expression): return GreaterThan(lhs.value, rhs.value) - elif (not isinstance(lhs, Expression)) and (not isintance(rhs, Expression)): + elif (not isinstance(lhs, Expression)) and (not isinstance(rhs, Expression)): # This case cannot occur. If neither is an Expression, why are you calling this method? raise Exception('What are you doing?') elif not isinstance(rhs, Expression): @@ -635,7 +667,7 @@ def __ge__(lhs, rhs): def __eq__(lhs, rhs): if isinstance(lhs, Expression) and isinstance(rhs, Expression): return Equal(lhs.value, rhs.value) - elif (not isinstance(lhs, Expression)) and (not isintance(rhs, Expression)): + elif (not isinstance(lhs, Expression)) and (not isinstance(rhs, Expression)): # This case cannot occur. If neither is an Expression, why are you calling this method? raise Exception('What are you doing?') elif not isinstance(rhs, Expression): diff --git a/stlcg demo.ipynb b/stlcg demo.ipynb index 194291f..e6ab2c8 100644 --- a/stlcg demo.ipynb +++ b/stlcg demo.ipynb @@ -116,34 +116,34 @@ "\n", "%3\n", "\n", - "\n", + "\n", "\n", - "4904731984\n", + "4755286112\n", "\n", "GreaterThan\n", "x >= 4.0\n", "\n", - "\n", + "\n", "\n", - "4508003216\n", + "4358179728\n", "\n", "x\n", "\n", - "\n", + "\n", "\n", - "4508003216->4904731984\n", + "4358179728->4755286112\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904772160\n", + "4755294248\n", "\n", "4.0\n", "\n", - "\n", + "\n", "\n", - "4904772160->4904731984\n", + "4755294248->4755286112\n", "\n", "\n", "\n", @@ -151,7 +151,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 4, @@ -182,34 +182,34 @@ "\n", "%3\n", "\n", - "\n", + "\n", "\n", - "4904732040\n", + "4755286168\n", "\n", "LessThan\n", "w <= 4.0\n", "\n", - "\n", + "\n", "\n", - "4508683096\n", + "4358885480\n", "\n", "w\n", "\n", - "\n", + "\n", "\n", - "4508683096->4904732040\n", + "4358885480->4755286168\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904772232\n", + "4755294320\n", "\n", "4.0\n", "\n", - "\n", + "\n", "\n", - "4904772232->4904732040\n", + "4755294320->4755286168\n", "\n", "\n", "\n", @@ -217,7 +217,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 5, @@ -248,34 +248,34 @@ "\n", "%3\n", "\n", - "\n", + "\n", "\n", - "4904732320\n", + "4755286224\n", "\n", "Equal\n", "x = 4.0\n", "\n", - "\n", + "\n", "\n", - "4508003216\n", + "4358179728\n", "\n", "x\n", "\n", - "\n", + "\n", "\n", - "4508003216->4904732320\n", + "4358179728->4755286224\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904772232\n", + "4755294320\n", "\n", "4.0\n", "\n", - "\n", + "\n", "\n", - "4904772232->4904732320\n", + "4755294320->4755286224\n", "\n", "\n", "\n", @@ -283,7 +283,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -314,97 +314,97 @@ "\n", "%3\n", "\n", - "\n", + "\n", "\n", - "4904732376\n", + "4755286560\n", "\n", "Always\n", "◻ [0, inf]( (w <= 4.0) ∧ (x >= 4.0) )\n", "\n", - "\n", + "\n", "\n", - "4904732096\n", + "4755286056\n", "\n", "And\n", "(w <= 4.0) ∧ (x >= 4.0)\n", "\n", - "\n", + "\n", "\n", - "4904732096->4904732376\n", + "4755286056->4755286560\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904732040\n", + "4755286168\n", "\n", "LessThan\n", "w <= 4.0\n", "\n", - "\n", + "\n", "\n", - "4904732040->4904732096\n", + "4755286168->4755286056\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4508683096\n", + "4358885480\n", "\n", "w\n", "\n", - "\n", + "\n", "\n", - "4508683096->4904732040\n", + "4358885480->4755286168\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904772232\n", + "4755294320\n", "\n", "4.0\n", "\n", - "\n", + "\n", "\n", - "4904772232->4904732040\n", + "4755294320->4755286168\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904731984\n", + "4755286112\n", "\n", "GreaterThan\n", "x >= 4.0\n", "\n", - "\n", + "\n", "\n", - "4904731984->4904732096\n", + "4755286112->4755286056\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4508003216\n", + "4358179728\n", "\n", "x\n", "\n", - "\n", + "\n", "\n", - "4508003216->4904731984\n", + "4358179728->4755286112\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904772160\n", + "4755294248\n", "\n", "4.0\n", "\n", - "\n", + "\n", "\n", - "4904772160->4904731984\n", + "4755294248->4755286112\n", "\n", "\n", "\n", @@ -412,7 +412,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 7, @@ -443,97 +443,97 @@ "\n", "%3\n", "\n", - "\n", + "\n", "\n", - "4904732264\n", + "4755285776\n", "\n", "Eventually\n", "♢ [0, inf]( (w <= 4.0) ∧ (x >= 4.0) )\n", "\n", - "\n", + "\n", "\n", - "4904732096\n", + "4755286056\n", "\n", "And\n", "(w <= 4.0) ∧ (x >= 4.0)\n", "\n", - "\n", + "\n", "\n", - "4904732096->4904732264\n", + "4755286056->4755285776\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904732040\n", + "4755286168\n", "\n", "LessThan\n", "w <= 4.0\n", "\n", - "\n", + "\n", "\n", - "4904732040->4904732096\n", + "4755286168->4755286056\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4508683096\n", + "4358885480\n", "\n", "w\n", "\n", - "\n", + "\n", "\n", - "4508683096->4904732040\n", + "4358885480->4755286168\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904772232\n", + "4755294320\n", "\n", "4.0\n", "\n", - "\n", + "\n", "\n", - "4904772232->4904732040\n", + "4755294320->4755286168\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904731984\n", + "4755286112\n", "\n", "GreaterThan\n", "x >= 4.0\n", "\n", - "\n", + "\n", "\n", - "4904731984->4904732096\n", + "4755286112->4755286056\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4508003216\n", + "4358179728\n", "\n", "x\n", "\n", - "\n", + "\n", "\n", - "4508003216->4904731984\n", + "4358179728->4755286112\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904772160\n", + "4755294248\n", "\n", "4.0\n", "\n", - "\n", + "\n", "\n", - "4904772160->4904731984\n", + "4755294248->4755286112\n", "\n", "\n", "\n", @@ -541,7 +541,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 8, @@ -572,97 +572,97 @@ "\n", "%3\n", "\n", - "\n", + "\n", "\n", - "4544569184\n", + "4755287008\n", "\n", "Until\n", "(x >= 4.0) U (◻ [0, inf]( w <= 4.0 ))\n", "\n", - "\n", + "\n", "\n", - "4904731984\n", + "4755286112\n", "\n", "GreaterThan\n", "x >= 4.0\n", "\n", - "\n", + "\n", "\n", - "4904731984->4544569184\n", + "4755286112->4755287008\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4508003216\n", + "4358179728\n", "\n", "x\n", "\n", - "\n", + "\n", "\n", - "4508003216->4904731984\n", + "4358179728->4755286112\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904772160\n", + "4755294248\n", "\n", "4.0\n", "\n", - "\n", + "\n", "\n", - "4904772160->4904731984\n", + "4755294248->4755286112\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904648656\n", + "4755286840\n", "\n", "Always\n", "◻ [0, inf]( w <= 4.0 )\n", "\n", - "\n", + "\n", "\n", - "4904648656->4544569184\n", + "4755286840->4755287008\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904732040\n", + "4755286168\n", "\n", "LessThan\n", "w <= 4.0\n", "\n", - "\n", + "\n", "\n", - "4904732040->4904648656\n", + "4755286168->4755286840\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4508683096\n", + "4358885480\n", "\n", "w\n", "\n", - "\n", + "\n", "\n", - "4508683096->4904732040\n", + "4358885480->4755286168\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904772232\n", + "4755294320\n", "\n", "4.0\n", "\n", - "\n", + "\n", "\n", - "4904772232->4904732040\n", + "4755294320->4755286168\n", "\n", "\n", "\n", @@ -670,7 +670,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 9, @@ -765,78 +765,78 @@ "\n", "%3\n", "\n", - "\n", + "\n", "\n", - "4544569184\n", + "4755284880\n", "\n", "Then\n", "(w <= 1.0) T (w >= 6.0)\n", "\n", - "\n", + "\n", "\n", - "4903839952\n", + "4755285440\n", "\n", "LessThan\n", "w <= 1.0\n", "\n", - "\n", + "\n", "\n", - "4903839952->4544569184\n", + "4755285440->4755284880\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4508683096\n", + "4358885480\n", "\n", "w\n", "\n", - "\n", + "\n", "\n", - "4508683096->4903839952\n", + "4358885480->4755285440\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4903839280\n", + "4755284992\n", "\n", "GreaterThan\n", "w >= 6.0\n", "\n", - "\n", + "\n", "\n", - "4508683096->4903839280\n", + "4358885480->4755284992\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904771656\n", + "4754943360\n", "\n", "1.0\n", "\n", - "\n", + "\n", "\n", - "4904771656->4903839952\n", + "4754943360->4755285440\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4903839280->4544569184\n", + "4755284992->4755284880\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "4904772088\n", + "4755294176\n", "\n", "6.0\n", "\n", - "\n", + "\n", "\n", - "4904772088->4903839280\n", + "4755294176->4755284992\n", "\n", "\n", "\n", @@ -844,7 +844,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 13, @@ -1039,20 +1039,20 @@ "iteration: 159 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", "iteration: 160 -- loss: 0.075 ---- c:3.984 ---- d:2.016\n", "iteration: 161 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", - "iteration: 162 -- loss: 0.074 ---- c:4.016 ---- d:1.984\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "iteration: 162 -- loss: 0.074 ---- c:4.016 ---- d:1.984\n", "iteration: 163 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", "iteration: 164 -- loss: 0.073 ---- c:3.984 ---- d:2.016\n", "iteration: 165 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", "iteration: 166 -- loss: 0.073 ---- c:4.016 ---- d:1.984\n", "iteration: 167 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", "iteration: 168 -- loss: 0.072 ---- c:3.984 ---- d:2.016\n", - "iteration: 169 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", + "iteration: 169 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "iteration: 170 -- loss: 0.072 ---- c:4.015 ---- d:1.985\n", "iteration: 171 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", "iteration: 172 -- loss: 0.071 ---- c:3.985 ---- d:2.015\n", @@ -1193,13 +1193,7 @@ "iteration: 307 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", "iteration: 308 -- loss: 0.037 ---- c:3.992 ---- d:2.008\n", "iteration: 309 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", - "iteration: 310 -- loss: 0.037 ---- c:4.008 ---- d:1.992\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "iteration: 310 -- loss: 0.037 ---- c:4.008 ---- d:1.992\n", "iteration: 311 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", "iteration: 312 -- loss: 0.036 ---- c:3.992 ---- d:2.008\n", "iteration: 313 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", @@ -1227,7 +1221,13 @@ "iteration: 335 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", "iteration: 336 -- loss: 0.031 ---- c:3.993 ---- d:2.007\n", "iteration: 337 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", - "iteration: 338 -- loss: 0.030 ---- c:4.007 ---- d:1.993\n", + "iteration: 338 -- loss: 0.030 ---- c:4.007 ---- d:1.993\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "iteration: 339 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", "iteration: 340 -- loss: 0.030 ---- c:3.993 ---- d:2.007\n", "iteration: 341 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", @@ -1348,13 +1348,7 @@ "iteration: 456 -- loss: 0.007 ---- c:3.998 ---- d:2.002\n", "iteration: 457 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", "iteration: 458 -- loss: 0.007 ---- c:4.002 ---- d:1.998\n", - "iteration: 459 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "iteration: 459 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", "iteration: 460 -- loss: 0.007 ---- c:3.998 ---- d:2.002\n", "iteration: 461 -- loss: 0.000 ---- c:4.000 ---- d:2.000\n", "iteration: 462 -- loss: 0.006 ---- c:4.002 ---- d:1.998\n", @@ -1415,6 +1409,13 @@ " d -= learning_rate * d.grad\n", " d.grad.zero_()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {