Skip to content

Commit b2b0586

Browse files
committed
Curry cost functions for flexibility
1 parent 7c605fa commit b2b0586

File tree

2 files changed

+39
-28
lines changed

2 files changed

+39
-28
lines changed

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def is_valid_pos(cur_pos):
2020
(0, -1),
2121
(-1, 0)
2222
]
23-
path_finder = PathFinder(linear_cost, cost, is_valid_pos)
23+
path_finder = PathFinder(linear_cost(), linear_cost(), is_valid_pos)
2424
path = path_finder.find_path(moves, start, end)
2525
map_generator.view_path(map_data, path)
2626

movement_cost.py

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,44 +3,49 @@
33
import math
44

55

6-
def linear_cost(start, end, cost=1):
6+
def linear_cost(scale=1):
77
"""
88
Manhattan distance, only linear movement is allowed
99
1010
Args:
1111
start (int, int): x and y coordinates of start point
1212
end (int, int): x and y coordinates of end point
13-
cost (int): cost of one step
13+
scale (int): scale of one step
1414
1515
Returns:
16-
Linear cost between start and end point
16+
Returns linear cost function between start and end point
1717
"""
1818

19-
delta_x = abs(start[0] - end[0])
20-
delta_y = abs(start[1] - end[1])
21-
return (delta_x + delta_y) * cost
19+
def cost(start, end):
20+
delta_x = abs(start[0] - end[0])
21+
delta_y = abs(start[1] - end[1])
22+
return (delta_x + delta_y) * scale
2223

24+
return cost
2325

24-
def euclidean_cost(start, end, cost=1):
26+
def euclidean_cost(scale=1):
2527
"""
2628
Euclidean distance, linear and diagonal movement is allowed,
2729
cost of diagonal movement is calculated using square root method
2830
2931
Args:
3032
start (int, int): x and y coordinates of start point
3133
end (int, int): x and y coordinates of end point
32-
cost (int): cost of one step
34+
scale (int): scale of one step
3335
3436
Returns:
35-
Euclidean cost between start and end point
37+
Returns euclidean cost fuction between start and end point
3638
"""
3739

38-
delta_x = abs(start[0] - end[0])
39-
delta_y = abs(start[1] - end[1])
40-
return math.sqrt(delta_x * delta_x + delta_y * delta_y) * cost
40+
def cost(start, end):
41+
delta_x = abs(start[0] - end[0])
42+
delta_y = abs(start[1] - end[1])
43+
return math.sqrt(delta_x * delta_x + delta_y * delta_y) * scale
4144

45+
return cost
4246

43-
def diagonal_cost(start, end, lin=1, diag=1):
47+
48+
def diagonal_cost(lin=1, diag=1):
4449
"""
4550
Diagonal distance, 8 directions.
4651
Linear and diagonal movement is allowed at same cost
@@ -50,35 +55,39 @@ def diagonal_cost(start, end, lin=1, diag=1):
5055
Args:
5156
start (int, int): x and y coordinates of start point
5257
end (int, int): x and y coordinates of end point
53-
lin int: cost of one linear step
54-
diag int: cost of one diagonal step
58+
lin int: scale of one linear step
59+
diag int: scale of one diagonal step
5560
5661
Returns:
57-
Diagonal cost between start and end point
62+
Returns diagonal cost function between start and end point
5863
"""
5964

60-
delta_x = abs(start[0] - end[0])
61-
delta_y = abs(start[1] - end[1])
62-
return (delta_x + delta_y) * lin + min(delta_x, delta_y) * (diag - 2 * lin)
65+
def cost(start, end):
66+
delta_x = abs(start[0] - end[0])
67+
delta_y = abs(start[1] - end[1])
68+
return (delta_x + delta_y) * lin + min(delta_x, delta_y) * (diag - 2 * lin)
69+
70+
return cost
6371

6472

65-
def scaled_cost(h_func, p_scale, *args):
73+
def scaled_cost(h_func, p_scale):
6674
"""
6775
Scales cost function based on given parameter
6876
6977
Args:
7078
h_func: cost function
7179
p_scale: scales cost function multiple times
72-
*args: arguments passed to cost function
7380
7481
Returns:
75-
Scaled value of cost cost
82+
Scaled cost function
7683
"""
7784

78-
return h_func(*args) * p_scale
85+
def cost(start, end):
86+
return h_func(start, end) * p_scale
7987

88+
return cost
8089

81-
def randomized_cost(sigma, mu, h_func, *args):
90+
def randomized_cost(sigma, mu, h_func):
8291
"""
8392
Generates random number with normal distribution based on given sigma and mu.
8493
Scales cost function by generated random number. Suggested values are
@@ -88,10 +97,12 @@ def randomized_cost(sigma, mu, h_func, *args):
8897
sigma: standard deviation in normal distribution
8998
mu: average value in normal distribution
9099
h_func: cost function
91-
*args: arguments passed to cost function
92100
93101
Returns:
94-
Randomly scaled value of cost cost
102+
Randomly scaled cost function
95103
"""
96104

97-
return h_func(*args) * random.normalvariate(mu, sigma)
105+
def cost(start, end):
106+
return h_func(start, end) * random.normalvariate(mu, sigma)
107+
108+
return cost

0 commit comments

Comments
 (0)