Skip to content

Commit ca9af51

Browse files
committed
TEST: Plot results in test_pkg_lnn.py
1 parent 5265a01 commit ca9af51

File tree

1 file changed

+51
-1
lines changed

1 file changed

+51
-1
lines changed

integration_tests/test_pkg_lnn.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,53 @@
11
from lnn.perceptron import init_perceptron, print_perceptron, normalize_input_vectors, Perceptron, train_dataset
2-
from lpython import i32, f64
2+
from lpdraw import Line, Circle, Display, Clear
3+
from lpython import i32, f64, Const
4+
from numpy import empty, int32
5+
6+
7+
def compute_decision_boundary(p: Perceptron, x: f64) -> f64:
8+
bias: f64 = p.weights[-1]
9+
slope: f64 = (-p.weights[0] / p.weights[1])
10+
intercept: f64 = (-bias / p.weights[1])
11+
return slope * x + intercept
12+
13+
def plot_graph(p: Perceptron, input_vectors: list[list[f64]], outputs: list[i32]):
14+
Width: Const[i32] = 500 # x-axis limits [0, 499]
15+
Height: Const[i32] = 500 # y-axis limits [0, 499]
16+
Screen: i32[Height, Width] = empty((Height, Width), dtype=int32)
17+
Clear(Height, Width, Screen)
18+
19+
x1: f64 = 2.0
20+
y1: f64 = compute_decision_boundary(p, x1)
21+
x2: f64 = -2.0
22+
y2: f64 = compute_decision_boundary(p, x2)
23+
24+
# center the graph using the following offset
25+
scale_offset: f64 = Width / 4
26+
shift_offset: f64 = Width / 2
27+
x1 *= scale_offset
28+
y1 *= scale_offset
29+
x2 *= scale_offset
30+
y2 *= scale_offset
31+
32+
# print (x1, y1, x2, y2)
33+
Line(Height, Width, Screen, i32(x1 + shift_offset), i32(y1 + shift_offset), i32(x2 + shift_offset), i32(y2 + shift_offset))
34+
35+
i: i32
36+
point_size: i32 = 5
37+
for i in range(len(input_vectors)):
38+
input_vectors[i][0] *= scale_offset
39+
input_vectors[i][1] *= scale_offset
40+
input_vectors[i][0] += shift_offset
41+
input_vectors[i][1] += shift_offset
42+
if outputs[i] == 1:
43+
x: i32 = i32(input_vectors[i][0])
44+
y: i32 = i32(input_vectors[i][1])
45+
Line(Height, Width, Screen, x - point_size, y, x + point_size, y)
46+
Line(Height, Width, Screen, x, y - point_size, x, y + point_size)
47+
else:
48+
Circle(Height, Width, Screen, i32(input_vectors[i][0]), i32(input_vectors[i][1]), f64(point_size))
49+
50+
Display(Height, Width, Screen)
351

452
def main0():
553
p: Perceptron
@@ -17,4 +65,6 @@ def main0():
1765
assert p.cur_accuracy > 50.0
1866
assert p.epochs_cnt > 1
1967

68+
plot_graph(p, input_vectors, outputs)
69+
2070
main0()

0 commit comments

Comments
 (0)