Skip to content

Commit 6113334

Browse files
committed
[#2] organize code a bit
1 parent fe37a66 commit 6113334

File tree

4 files changed

+92
-137
lines changed

4 files changed

+92
-137
lines changed

main.py renamed to sorting.py

Lines changed: 20 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,31 @@
1-
import random
2-
3-
import svg
4-
5-
6-
def swap(history, idx1, idx2):
7-
8-
arr = history[-1].copy()
9-
tmp = arr[idx1]
10-
arr[idx1] = arr[idx2]
11-
arr[idx2] = tmp
12-
13-
return history + [arr]
14-
1+
def insertion_sort_history(arr):
2+
history = [arr.copy()]
153

16-
def insertion_sort(arr):
174
i = 1
18-
19-
while i < len(arr[0]):
5+
while i < len(arr):
206
j = i
21-
while j > 0 and arr[-1][j - 1] > arr[-1][j]:
22-
arr = swap(arr, j, j - 1)
7+
while j > 0 and arr[j - 1] > arr[j]:
8+
arr[j], arr[j - 1] = arr[j - 1], arr[j]
9+
history.append(arr.copy())
2310
j -= 1
2411
i += 1
2512

26-
svg.generate(arr, "insertion_sort")
13+
return history
14+
2715

16+
def bubble_sort_history(arr):
17+
history = [arr.copy()]
2818

29-
def bubble_sort(arr):
3019
flag = True
3120
while flag:
3221
flag = False
33-
for i in range(len(arr[-1]) - 1):
34-
if arr[-1][i] > arr[-1][i + 1]:
35-
arr = swap(arr, i, i + 1)
22+
for i in range(len(arr) - 1):
23+
if arr[i] > arr[i + 1]:
24+
arr[i], arr[i + 1] = arr[i + 1], arr[i]
25+
history.append(arr.copy())
3626
flag = True
3727

38-
svg.generate(arr, "bubble_sort")
28+
return history
3929

4030

4131
def partition_lomuto(arr, history, low, high):
@@ -61,10 +51,10 @@ def quicksort_lomuto(arr, history, low, high):
6151
quicksort_lomuto(arr, history, p + 1, high)
6252

6353

64-
def quicksort_lomuto_svg(arr):
54+
def quicksort_lomuto_history(arr):
6555
history = [arr.copy()]
6656
quicksort_lomuto(arr, history, 0, len(arr) - 1)
67-
svg.generate(history, "quicksort_lomuto")
57+
return history
6858

6959

7060
def partition_hoare(arr, history, low, high):
@@ -94,18 +84,7 @@ def quicksort_hoare(arr, history, low, high):
9484
quicksort_hoare(arr, history, split_index + 1, high)
9585

9686

97-
def quicksort_hoare_svg(arr):
87+
def quicksort_hoare_history(arr):
9888
history = [arr.copy()]
99-
quicksort_hoare(list_to_sort, history, 0, len(arr) - 1)
100-
svg.generate(history, "quicksort_hoare")
101-
102-
103-
if __name__ == "__main__":
104-
list_length = 80
105-
list_to_sort = list(range(list_length))
106-
random.Random(777).shuffle(list_to_sort)
107-
108-
insertion_sort([list_to_sort[:20]])
109-
bubble_sort([list_to_sort[:20]])
110-
quicksort_lomuto_svg(list_to_sort.copy())
111-
quicksort_hoare_svg(list_to_sort)
89+
quicksort_hoare(arr, history, 0, len(arr) - 1)
90+
return history

macaroni.py renamed to svg_composites.py

Lines changed: 9 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import math
22

3+
import svg_primitives
4+
35

46
class Circle:
57
def __init__(self, x, y, radius):
@@ -56,58 +58,12 @@ def get_tangent_lines(c1, c2):
5658
return (c1t1, c2t1), (c1t2, c2t2)
5759

5860

59-
def circle_to_svg(c):
60-
return f"""
61-
<circle
62-
cx="{c.x}"
63-
cy="{c.y}"
64-
r="{c.radius}"
65-
fill="transparent"
66-
stroke="black"
67-
/>
68-
"""
69-
70-
71-
def make_line_svg(
72-
x1, y1, x2, y2, stroke_width=10, stroke_color="red", stroke_linecap="round"
73-
):
74-
return (
75-
f'<line x1="{round(x1, 4)}" y1="{round(y1, 4)}"'
76-
f' x2="{round(x2, 4)}" y2="{round(y2, 4)}"'
77-
f' stroke="{stroke_color}"'
78-
f' stroke-width="{stroke_width}"'
79-
f' stroke-linecap="{stroke_linecap}"/>'
80-
)
81-
82-
83-
def make_arc_svg(
84-
pt1,
85-
pt2,
86-
radius,
87-
arc_flag,
88-
stroke_width=10,
89-
stroke_color="red",
90-
stroke_linecap="round",
91-
):
92-
x1, y1 = pt1
93-
x2, y2 = pt2
94-
95-
return (
96-
f'<path d="M {round(x1, 4)} {round(y1, 4)} A {round(radius, 4)}'
97-
f' {round(radius, 4)} 0 0 {arc_flag} {round(x2, 4)} {round(y2, 4)}"'
98-
f' fill="transparent"'
99-
f' stroke="{stroke_color}"'
100-
f' stroke-linecap="{stroke_linecap}"'
101-
f' stroke-width="{stroke_width}"/>'
102-
)
103-
104-
10561
def choose_tangent_line(tl1, tl2):
10662
# this is pretty hacky but I have luxury of never running into edge cases
10763
return tl1 if tl1[0] < tl2[0] else tl2
10864

10965

110-
def make_double_macaroni_connection_svg(
66+
def double_macaroni(
11167
x1,
11268
y1,
11369
x2,
@@ -157,43 +113,10 @@ def make_double_macaroni_connection_svg(
157113
arc_flag = int(a.y > b.y)
158114

159115
return (
160-
make_line_svg(*p1, *p2, **{**bg_kwargs, "stroke_linecap": "round"})
161-
+ make_arc_svg(**arc1_kwargs, **bg_kwargs, arc_flag=arc_flag)
162-
+ make_arc_svg(**arc2_kwargs, **bg_kwargs, arc_flag=arc_flag)
163-
+ make_line_svg(*p1, *p2, **fg_kwargs)
164-
+ make_arc_svg(**arc1_kwargs, **fg_kwargs, arc_flag=arc_flag)
165-
+ make_arc_svg(**arc2_kwargs, **fg_kwargs, arc_flag=arc_flag)
116+
svg_primitives.line(*p1, *p2, **{**bg_kwargs, "stroke_linecap": "round"})
117+
+ svg_primitives.arc(**arc1_kwargs, **bg_kwargs, arc_flag=arc_flag)
118+
+ svg_primitives.arc(**arc2_kwargs, **bg_kwargs, arc_flag=arc_flag)
119+
+ svg_primitives.line(*p1, *p2, **fg_kwargs)
120+
+ svg_primitives.arc(**arc1_kwargs, **fg_kwargs, arc_flag=arc_flag)
121+
+ svg_primitives.arc(**arc2_kwargs, **fg_kwargs, arc_flag=arc_flag)
166122
)
167-
168-
169-
if __name__ == "__main__":
170-
inner_svg = ""
171-
inner_svg += make_double_macaroni_connection_svg(
172-
x1=60,
173-
y1=60,
174-
x2=200,
175-
y2=200,
176-
radius=50,
177-
stroke_width=30,
178-
stroke_color="red",
179-
stroke_linecap="flat",
180-
)
181-
182-
svg = ""
183-
if display_circles:
184-
svg = circle_to_svg(a) + circle_to_svg(b)
185-
186-
187-
with open("tangent_test.html", "w+") as text_file:
188-
text_file.write(
189-
f"""
190-
<!DOCTYPE html>
191-
<html>
192-
<body>
193-
<svg width="400" height="400">
194-
{inner_svg}
195-
</svg>
196-
</body>
197-
</html>
198-
"""
199-
)

svg_primitives.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
def circle(circle):
2+
return (
3+
f'<circle cx="{circle.x}" cy="{circle.y}" r="{circle.radius}"'
4+
'fill="transparent" stroke="black"/>'
5+
)
6+
7+
8+
def line(x1, y1, x2, y2, stroke_width=10, stroke_color="red", stroke_linecap="round"):
9+
return (
10+
f'<line x1="{round(x1, 4)}" y1="{round(y1, 4)}"'
11+
f' x2="{round(x2, 4)}" y2="{round(y2, 4)}"'
12+
f' stroke="{stroke_color}"'
13+
f' stroke-width="{stroke_width}"'
14+
f' stroke-linecap="{stroke_linecap}"/>'
15+
)
16+
17+
18+
def arc(
19+
pt1,
20+
pt2,
21+
radius,
22+
arc_flag,
23+
stroke_width=10,
24+
stroke_color="red",
25+
stroke_linecap="round",
26+
):
27+
x1, y1 = pt1
28+
x2, y2 = pt2
29+
30+
return (
31+
f'<path d="M {round(x1, 4)} {round(y1, 4)} A {round(radius, 4)}'
32+
f' {round(radius, 4)} 0 0 {arc_flag} {round(x2, 4)} {round(y2, 4)}"'
33+
f' fill="transparent"'
34+
f' stroke="{stroke_color}"'
35+
f' stroke-linecap="{stroke_linecap}"'
36+
f' stroke-width="{stroke_width}"/>'
37+
)

svg.py renamed to visualization.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from matplotlib import cm
44

5-
from macaroni import make_line_svg, make_double_macaroni_connection_svg
5+
import svg_primitives
6+
import svg_composites
67

78

89
def unnest_svg_list(svg_list):
@@ -25,6 +26,8 @@ def get_color(color_dict, val):
2526

2627

2728
def get_swaps_and_straights(history):
29+
"""Transform history into format more useful for drawing"""
30+
2831
# pad history at the beginning and end to make my life easier
2932
padded_history = history.copy()
3033
padded_history.insert(0, padded_history[0])
@@ -74,7 +77,7 @@ def make_straight_path(val, indices, color_dict, transform_kwargs):
7477
stroke_width = transform_kwargs["line_width"]
7578
stroke_linecap = "round"
7679

77-
return make_line_svg(*coords, stroke_width, stroke_color, stroke_linecap)
80+
return svg_primitives.line(*coords, stroke_width, stroke_color, stroke_linecap)
7881

7982

8083
def make_swap_path(
@@ -89,7 +92,7 @@ def make_swap_path(
8992
curve_radius = min_curve_radius + curve_radius_delta / distance
9093
stroke_width = transform_kwargs["line_width"]
9194

92-
return make_double_macaroni_connection_svg(
95+
return svg_composites.double_macaroni(
9396
x1,
9497
y1,
9598
x2,
@@ -138,6 +141,7 @@ def generate(
138141
):
139142

140143
# set up some params
144+
final_state = history[-1]
141145
y_offset = line_width / 2
142146
transform_kwargs = {
143147
"spacing": spacing,
@@ -146,35 +150,36 @@ def generate(
146150
"y_offset": y_offset,
147151
}
148152

149-
# compute curve stuff
153+
# compute curve params
150154
min_curve_radius = (line_height - line_width) / min_curve_radius_denominator
151155
curve_kwargs = {
152156
"min_curve_radius": min_curve_radius,
153157
"curve_radius_delta": line_height - line_width - spacing - min_curve_radius,
154158
}
155159

156160
# set up colors
157-
color_dict = get_color_dict(history[-1])
161+
color_dict = get_color_dict(final_state)
158162

159163
# get path histories
160164
swaps, straights = get_swaps_and_straights(history)
161165

166+
# compute svg dimensions
167+
num_vals = len(final_state)
168+
total_width = (num_vals) * spacing + num_vals * line_width
169+
total_height = (len(history) + 1) * line_height + 2 * y_offset
170+
162171
# actually make the svgs
163172
swap_kwargs = {
164173
"swaps": swaps,
165174
"color_dict": color_dict,
166175
"transform_kwargs": transform_kwargs,
167176
"curve_kwargs": curve_kwargs,
168177
}
178+
# ensure that left-right swaps render over right-left swaps
169179
under_swap_paths = make_swap_paths(mode="under", **swap_kwargs)
170180
over_swap_paths = make_swap_paths(mode="over", **swap_kwargs)
171181
straight_paths = make_straight_paths(straights, color_dict, transform_kwargs)
172182

173-
# compute svg dimensions
174-
num_vals = len(history[0])
175-
total_width = (num_vals) * spacing + num_vals * line_width
176-
total_height = (len(history) + 1) * line_height + 2 * y_offset
177-
178183
with open(f"{filename}.svg", "w+") as text_file:
179184
text_file.write(
180185
f'<svg role="img" height="{int(total_height)}" width="{int(total_width)}"'
@@ -184,3 +189,14 @@ def generate(
184189
+ over_swap_paths
185190
+ "</svg>\n"
186191
)
192+
193+
194+
if __name__ == "__main__":
195+
import random
196+
import sorting
197+
198+
list_length = 80
199+
list_to_sort = list(range(list_length))
200+
random.Random(777).shuffle(list_to_sort)
201+
202+
generate(sorting.quicksort_hoare_history(list_to_sort), "quicksort_hoare")

0 commit comments

Comments
 (0)