-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtask.py
77 lines (65 loc) · 1.95 KB
/
task.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from itertools import product
from mdp_lib.domains.gridworld import GridWorld
#=============================#
#   Task constants            #
#=============================#
# Rewards.
goal_reward = 50          # reward assigned to the 'y' feature in every MDP
true_belief_reward = 50   # forwarded unchanged to the observer-belief params
danger_reward = -10       # non-zero alternative reward for 'o', 'p' and 'c'
step_cost = 0

# Action-set switches forwarded to every ground MDP.
wall_action = False
wait_action = False

# Start state and the single absorbing state of each ground MDP.
init_ground_state = (0, 2)
ground_goal_state = (5, 2)

# Discount rates and softmax temperatures.
base_discount_rate = 0.99
base_softmax_temp = 1
obmdp_discount_rate = 0.99
obmdp_softmax_temp = 1

belief_reward_isterminal = False
seed_trajs = None
discretized_tf = None

# One character per grid cell, one string per row.  'o', 'p' and 'c' are the
# features whose reward differs between candidate MDPs; 'y' receives
# goal_reward and '.' receives 0 in all of them.
state_features = [
    '.oooo.',
    '.oppp.',
    '.opccy',
    '.oppc.',
    '.cccc.',
]

#=============================#
#   Ground MDP params         #
#=============================#
# Features whose reward is either 0 or danger_reward.
_VARIABLE_FEATURES = 'opc'

# Every assignment of {0, danger_reward} to the variable features, in
# itertools.product order (the last feature varies fastest): 2**3 = 8 dicts.
feature_rewards = [
    dict(zip(_VARIABLE_FEATURES, rewards))
    for rewards in product([0, danger_reward], repeat=len(_VARIABLE_FEATURES))
]

# One short code per reward assignment: position i is 'o' when the i-th
# variable feature has reward 0 and 'x' otherwise (e.g. 'oxo').
mdp_codes = [
    ''.join('o' if frewards[f] == 0 else 'x' for f in _VARIABLE_FEATURES)
    for frewards in feature_rewards
]

# Rewards shared by every candidate MDP.  Added after the codes are computed
# so that only the variable features take part in the code.
for frewards in feature_rewards:
    frewards['y'] = goal_reward
    frewards['.'] = 0

# One parameter dict per reward assignment, in the same order as mdp_codes.
# Each dict gets its own 'absorbing_states' list; state_features is shared.
mdp_params = [
    {
        'gridworld_array': state_features,
        'feature_rewards': frewards,
        'absorbing_states': [ground_goal_state],
        'init_state': init_ground_state,
        'wall_action': wall_action,
        'step_cost': step_cost,
        'wait_action': wait_action,
        'discount_rate': base_discount_rate,
    }
    for frewards in feature_rewards
]
#===========================================#
# Observer Belief MDP params #
#===========================================#
# Parameter dictionary for the observer-belief MDP, assembled from the
# module-level settings above.  The consumer of this dict is not shown in
# this file.
# NOTE(review): obmdp_softmax_temp and seed_trajs are defined at module level
# but are not forwarded here -- confirm whether that is intentional.
ob_mdp_params = {
    'init_ground_state': init_ground_state,
    'mdp_params': mdp_params,    # one parameter dict per candidate ground MDP
    'mdp_codes': mdp_codes,      # codes in the same order as mdp_params
    'MDP': GridWorld,            # the class itself, not an instance
    'base_softmax_temp': base_softmax_temp,
    'true_belief_reward': true_belief_reward,
    'base_policy_type': 'softmax',
    'true_mdp_i': None,          # presumably set by the caller -- TODO confirm
    # Read the module-level setting instead of repeating the literal, so that
    # editing the constant actually takes effect here.
    'belief_reward_isterminal': belief_reward_isterminal,
    'discount_rate': obmdp_discount_rate,
    'discretized_tf': discretized_tf,
}