main.py
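"""Entry points for checking and training RL agents on the bot3RLNav environments."""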
import cv2
import bot3RLNav  # imported for its side effect of registering the bot3RLNav envs with gym
import gym
from gym.utils.env_checker import check_env
# ----------------------------------------------------------------------
def check():
    """Sanity-check the environment against the gym API and sample its spaces."""
    env = gym.make('bot3RLNav/World-v5', map_file="data/map01.jpg",
                   robot_file="data/robot.png", learning_type=3)
    check_env(env)
    print(env.action_space.sample())
    print(env.observation_space.sample())
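    # Optionally render a few reset states to confirm the map and robot load.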
    test_simulate = True
    if test_simulate:
        cv2.namedWindow("bot3")
        wait = 100  # ms
        for i in range(10):
            obs = env.reset()
            img = env.render(mode="rgb_array")
            cv2.imshow("bot3", img)
            cv2.waitKey(wait)
# ----------------------------------------------------------------------
def train():
    """Train a DQN agent on the discrete world, replay it, then evaluate it."""
    from stable_baselines3 import DQN
    from stable_baselines3.common.evaluation import evaluate_policy
    env = gym.make('bot3RLNav/DiscreteWorld-v5', map_file="data/map01.jpg",
                   robot_file="data/robot.png", learning_type=1)
    model = DQN("MlpPolicy", env, verbose=1, learning_rate=0.001, gamma=0.1,
                exploration_fraction=0.8)
    model.learn(total_timesteps=80000, log_interval=5)
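    # Replay the learned policy step by step in an OpenCV window.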
    obs = env.reset()
    name = "bot3"
    cv2.namedWindow(name, cv2.WINDOW_NORMAL)
    rate = 100  # delay between frames in ms
    count = 1000
    while count > 0:
        frame = env.render(mode="rgb_array")
        cv2.imshow(name, frame)
        cv2.waitKey(rate)
        action, _states = model.predict(obs, deterministic=True)
        # action = 5
        obs, reward, done, info = env.step(action)
        print(count, info, reward, env.actions[action])
        if done:
            print("done.")
            frame = env.render(mode="rgb_array")
            cv2.imshow(name, frame)
            cv2.waitKey(rate * 20)
            break
        count -= 1
        print()
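    # Wrap in Monitor so evaluate_policy can report per-episode statistics.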
    from stable_baselines3.common.monitor import Monitor
    env1 = Monitor(env)
    o = evaluate_policy(model, env1, n_eval_episodes=10, render=False,
                        # return_episode_rewards=True
                        )
    print(o)
    env.close()
# ----------------------------------------------------------------------
def train_td3():
    """Train a TD3 agent on the continuous world and replay the learned policy."""
    from stable_baselines3 import TD3
    from stable_baselines3.common.monitor import Monitor
    from stable_baselines3.common.evaluation import evaluate_policy
    import numpy as np
    np.set_printoptions(precision=4)
    from stable_baselines3.common.noise import NormalActionNoise
    env = gym.make('bot3RLNav/World-v5', map_file="data/map01.jpg",
                   robot_file="data/robot.png")
    # The exploration noise objects for TD3.
    n_actions = env.action_space.shape[-1]
    action_noise = NormalActionNoise(mean=np.zeros(n_actions),
                                     sigma=0.1 * np.ones(n_actions))
    model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=1,
                learning_rate=0.001, gamma=0.1)
    model.learn(total_timesteps=8000, log_interval=100)
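    # Replay the trained policy in an OpenCV window.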
    obs = env.reset()
    name = "bot3"
    cv2.namedWindow(name, cv2.WINDOW_NORMAL)
    rate = 200  # delay between frames in ms
    count = 100
    while count > 0:
        frame = env.render(mode="rgb_array")
        cv2.imshow(name, frame)
        cv2.waitKey(rate)
        # action, _states = np.array([-0.5, 1]), 0
        action, _states = model.predict(obs, deterministic=True)
        # Shift and scale the action purely for logging; action_ is only printed.
        action_ = (action + np.array([0.5, 0])) * np.array([10, 5])
        obs, reward, done, info = env.step(action)
        print(count, action_, f"{reward:.4f}", info)
        print()
        if done:
            print("done.")
            frame = env.render(mode="rgb_array")
            cv2.imshow(name, frame)
            cv2.waitKey(rate * 20)
            break
        count -= 1
        print()
    env1 = Monitor(env)
    # o = evaluate_policy(model, env1, n_eval_episodes=10, render=False)
    # print(o)
    env.close()
if __name__ == '__main__':
    # check()
    train()
    # train_td3()