
Commit 196f02a

Update maddpg and the report (#470)
* update trajectory_extension
* update maddpg
* update the experiment
* update the report
* update the report
1 parent 6766c8f commit 196f02a

File tree: 4 files changed, +67 −21 lines

docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl
Lines changed: 4 additions & 3 deletions

@@ -2,7 +2,7 @@
 # title: JuliaRL\_MADDPG\_KuhnPoker
 # cover: assets/JuliaRL_MADDPG_KuhnPoker.png
 # description: MADDPG applied to KuhnPoker
-# date: 2021-08-09
+# date: 2021-08-18
 # author: "[Peter Chen](https://github.com/peterchen96)"
 # ---

@@ -43,7 +43,7 @@ function RL.Experiment(
         state_space_mapping = ss -> [[findfirst(==(s), state_space(env))] for s in state_space(env)]
     ),
     ## drop the dummy action of the other agent.
-    action_mapping = x -> length(x) == 1 ? x : Int(x[current_player(env)] + 1),
+    action_mapping = x -> length(x) == 1 ? x : Int(ceil(x[current_player(env)]) + 1),
 )
 ns, na = 1, 1 # dimension of the state and action.
 n_players = 2 # number of players

@@ -101,9 +101,10 @@ function RL.Experiment(
         policy = NamedPolicy(player, deepcopy(policy)),
         trajectory = deepcopy(trajectory),
     )) for player in players(env) if player != chance_player(env)),
+    SARTS, # traces
     128, # batch_size
     128, # update_freq
-    0, # step_counter
+    0, # initial update_step
     rng
 )
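The `action_mapping` change above wraps the actor's output in `ceil` before the integer cast, presumably because `Int` on a non-integral continuous output would throw an `InexactError`. A minimal sketch of the assumed behaviour (not part of the commit; the output range and the `x[1]` indexing are illustrative stand-ins for `x[current_player(env)]`):

```Julia
## Hypothetical standalone version of the new mapping, assuming the DDPG actor
## emits a continuous value in (-1, 1] for the acting player.
action_mapping = x -> length(x) == 1 ? x : Int(ceil(x[1]) + 1)

action_mapping([3])          # a single (chance-player) action passes through unchanged
action_mapping([-0.4, 0.0])  # ceil(-0.4) = -0.0 -> discrete action 1
action_mapping([0.7, 0.0])   # ceil(0.7)  =  1.0 -> discrete action 2
```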

docs/homepage/blog/ospp_report_210370190/index.md
Lines changed: 5 additions & 4 deletions

@@ -301,7 +301,7 @@ As for updating the policy, the process is mainly the same as the [`DDPGPolicy`]

 #### Usage

-Here `MADDPGManager` is used for simultaneous games, or you can add an [action-related wrapper](https://juliareinforcementlearning.org/docs/rlenvs/#ReinforcementLearningEnvironments.ActionTransformedEnv-Tuple{Any}) to the sequential game to drop the dummy action of other players. And there is one [experiment](https://juliareinforcementlearning.org/docs/experiments/experiments/Policy%20Gradient/JuliaRL_MADDPG_KuhnPoker/#JuliaRL\\_MADDPG\\_KuhnPoker) `JuliaRL_MADDPG_KuhnPoker` as one usage example, which tests the algorithm on the Kuhn Poker game. Since the Kuhn Poker is one sequential game, I wrap the game just like the following:
+Here `MADDPGManager` is used for [`SIMULTANEOUS`](https://juliareinforcementlearning.org/docs/rlbase/#ReinforcementLearningBase.SIMULTANEOUS) environments with a continuous action space (see the blog [Diagonal Gaussian Policies](https://spinningup.openai.com/en/latest/spinningup/rl_intro.html#stochastic-policies)); otherwise you can add an [action-related wrapper](https://juliareinforcementlearning.org/docs/rlenvs/#ReinforcementLearningEnvironments.ActionTransformedEnv-Tuple{Any}) to the environment to make it work with the algorithm. The [experiment](https://juliareinforcementlearning.org/docs/experiments/experiments/Policy%20Gradient/JuliaRL_MADDPG_KuhnPoker/#JuliaRL\\_MADDPG\\_KuhnPoker) `JuliaRL_MADDPG_KuhnPoker` serves as a usage example, testing the algorithm on the Kuhn Poker game. Since Kuhn Poker is a [`SEQUENTIAL`](ReinforcementLearningBase.SEQUENTIAL) game with a discrete action space (see also [Diagonal Gaussian Policies](https://spinningup.openai.com/en/latest/spinningup/rl_intro.html#stochastic-policies)), I wrap the environment as follows:
 ```Julia
 wrapped_env = ActionTransformedEnv(
     StateTransformedEnv(

@@ -310,7 +310,7 @@ wrapped_env = ActionTransformedEnv(
         state_space_mapping = ss -> [[findfirst(==(s), state_space(env))] for s in state_space(env)]
     ),
     ## drop the dummy action of the other agent.
-    action_mapping = x -> length(x) == 1 ? x : Int(x[current_player(env)] + 1),
+    action_mapping = x -> length(x) == 1 ? x : Int(ceil(x[current_player(env)]) + 1),
 )
 ```

@@ -376,9 +376,10 @@ agents = MADDPGManager(
         policy = NamedPolicy(player, deepcopy(policy)),
         trajectory = deepcopy(trajectory),
     )) for player in players(env) if player != chance_player(env)),
+    SARTS, # traces
     128, # batch_size
     128, # update_freq
-    0, # update_step
+    0, # initial update_step
     rng
 )
 ```

@@ -387,4 +388,4 @@ Plus on the [`stop_condition`](https://github.com/JuliaReinforcementLearning/Rei

 \dfig{body;JuliaRL_MADDPG_KuhnPoker.png;Result of the experiment.}

-**Note that** the current `MADDPGManager` still only works on the envs of [`MINIMAL_ACTION_SET`](https://juliareinforcementlearning.org/docs/rlbase/#ReinforcementLearningBase.MINIMAL_ACTION_SET). And since **MADDPG** is one deterministic algorithm, i.e., the state's response is one deterministic action, the Kuhn Poker game may not be suitable for testing the performance. In the next weeks, I'll update the algorithm and try to test it on other games.
+**Note that** since **MADDPG** is a deterministic algorithm, i.e., the response to a state is a single deterministic action, the Kuhn Poker game may not be suitable for testing its performance. In the coming weeks, I'll update the algorithm and try to test it on other games.
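For context, a hedged sketch of how the wrapped environment and the `MADDPGManager` above are assumed to be wired together when the experiment runs; `StopAfterEpisode`, `EmptyHook`, and the episode count are illustrative stand-ins, not taken from this commit:

```Julia
## Assumed wiring (sketch): train the MADDPG agents on the wrapped Kuhn Poker env.
stop_condition = StopAfterEpisode(100_000)  # illustrative stop condition
hook = EmptyHook()                          # placeholder; the experiment defines its own hook
run(agents, wrapped_env, stop_condition, hook)
```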

src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl
Lines changed: 16 additions & 5 deletions

@@ -85,11 +85,22 @@ function fetch!(s::BatchSampler, t::AbstractTrajectory, inds::Vector{Int})
     end
 end

-function fetch!(s::BatchSampler{SARTS}, t::CircularArraySARTTrajectory, inds::Vector{Int})
-    batch = NamedTuple{SARTS}((
-        (consecutive_view(t[x], inds) for x in SART)...,
-        consecutive_view(t[:state], inds .+ 1),
-    ))
+function fetch!(s::BatchSampler{traces}, t::Union{CircularArraySARTTrajectory, CircularArraySLARTTrajectory}, inds::Vector{Int}) where {traces}
+    if traces == SARTS
+        batch = NamedTuple{SARTS}((
+            (consecutive_view(t[x], inds) for x in SART)...,
+            consecutive_view(t[:state], inds .+ 1),
+        ))
+    elseif traces == SLARTSL
+        batch = NamedTuple{SLARTSL}((
+            (consecutive_view(t[x], inds) for x in SLART)...,
+            consecutive_view(t[:state], inds .+ 1),
+            consecutive_view(t[:legal_actions_mask], inds .+ 1),
+        ))
+    else
+        @error "unsupported traces $traces"
+    end
+
     if isnothing(s.cache)
         s.cache = map(batch) do x
             convert(Array, x)
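The new `fetch!` dispatches on the sampler's trace names. For reference, a sketch of the trace-name tuples this assumes (following the usual RLCore naming convention; not taken from this diff), which is why the `SLARTSL` branch appends two extra views at `inds .+ 1` — the next state and the next legal-actions mask:

```Julia
## Assumed trace layouts (hedged; defined in RLCore, reproduced here for reference).
const SART    = (:state, :action, :reward, :terminal)
const SARTS   = (:state, :action, :reward, :terminal, :next_state)
const SLART   = (:state, :legal_actions_mask, :action, :reward, :terminal)
const SLARTSL = (:state, :legal_actions_mask, :action, :reward, :terminal,
                 :next_state, :next_legal_actions_mask)
```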

src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl
Lines changed: 42 additions & 9 deletions

@@ -6,26 +6,28 @@ Multi-agent Deep Deterministic Policy Gradient(MADDPG) implemented in Julia. Her
 See the paper https://arxiv.org/abs/1706.02275 for more details.

 # Keyword arguments
-- `agents::Dict{<:Any, <:NamedPolicy{<:Agent{<:DDPGPolicy, <:AbstractTrajectory}, <:Any}}`, here each agent collects its own information. While updating the policy, each `critic` will assemble all agents' trajectory to update its own network.
+- `agents::Dict{<:Any, <:NamedPolicy{<:Agent{<:DDPGPolicy, <:AbstractTrajectory}, <:Any}}`, here each agent collects its own information. While updating the policy, each **critic** will assemble all agents' trajectories to update its own network.
+- `traces`, set to `SARTS` when applying the algorithm to an environment of `MINIMAL_ACTION_SET`, or `SLARTSL` when applying it to an environment of `FULL_ACTION_SET`.
 - `batch_size::Int`
 - `update_freq::Int`
 - `update_step::Int`, count the step.
 - `rng::AbstractRNG`.
 """
 mutable struct MADDPGManager{P<:DDPGPolicy, T<:AbstractTrajectory, N<:Any} <: AbstractPolicy
     agents::Dict{<:N, <:Agent{<:NamedPolicy{<:P, <:N}, <:T}}
+    traces
     batch_size::Int
     update_freq::Int
     update_step::Int
     rng::AbstractRNG
 end

-# for simultaneous game with a discrete action space.
+# used for simultaneous environments.
 function (π::MADDPGManager)(env::AbstractEnv)
     while current_player(env) == chance_player(env)
         env |> legal_action_space |> rand |> env
     end
-    Dict((player, ceil(agent.policy(env))) for (player, agent) in π.agents)
+    Dict((player, agent.policy(env)) for (player, agent) in π.agents)
 end

 function (π::MADDPGManager)(stage::Union{PreEpisodeStage, PostActStage}, env::AbstractEnv)

@@ -42,7 +44,7 @@ function (π::MADDPGManager)(stage::PreActStage, env::AbstractEnv, actions)
     end

     # update policy
-    update!(π)
+    update!(π, env)
 end

 function (π::MADDPGManager)(stage::PostEpisodeStage, env::AbstractEnv)

@@ -52,11 +54,11 @@ function (π::MADDPGManager)(stage::PostEpisodeStage, env::AbstractEnv)
     end

     # update policy
-    update!(π)
+    update!(π, env)
 end

 # update policy
-function RLBase.update!(π::MADDPGManager)
+function RLBase.update!(π::MADDPGManager, env::AbstractEnv)
     π.update_step += 1
     π.update_step % π.update_freq == 0 || return

@@ -69,7 +71,7 @@ function RLBase.update!(π::MADDPGManager)
     temp_player = collect(keys(π.agents))[1]
     t = π.agents[temp_player].trajectory
     inds = rand(π.rng, 1:length(t), π.batch_size)
-    batches = Dict((player, RLCore.fetch!(BatchSampler{SARTS}(π.batch_size), agent.trajectory, inds))
+    batches = Dict((player, RLCore.fetch!(BatchSampler{π.traces}(π.batch_size), agent.trajectory, inds))
         for (player, agent) in π.agents)

     # get s, a, s′ for critic

@@ -95,7 +97,8 @@ function RLBase.update!(π::MADDPGManager)
     )

     for (player, agent) in π.agents
-        p = agent.policy.policy # get DDPGPolicy struct
+        p = agent.policy.policy # get agent's concrete DDPGPolicy.
+
         A = p.behavior_actor
         C = p.behavior_critic
         Aₜ = p.target_actor

@@ -104,6 +107,28 @@ function RLBase.update!(π::MADDPGManager)
         γ = p.γ
         ρ = p.ρ

+        if π.traces == SLARTSL
+            # Note that by default **MADDPG** is used for environments with a continuous action space, while
+            # `legal_action_space_mask` is defined for environments with a discrete action space. So we need
+            # `env.action_mapping` to transform the actions taken from the trajectory.
+            @assert env isa ActionTransformedEnv
+
+            mask = batches[player][:next_legal_actions_mask]
+            mu_actions, new_actions = send_to_host((mu_actions, new_actions)) # make sure the actions are on the cpu.
+            mu_l′ = Flux.batch(
+                (begin
+                    actions = env.action_mapping(mu_actions[:, i])
+                    mask[actions[player]]
+                end for i = 1:π.batch_size)
+            )
+            new_l′ = Flux.batch(
+                (begin
+                    actions = env.action_mapping(new_actions[:, i])
+                    mask[actions[player]]
+                end for i = 1:π.batch_size)
+            )
+        end
+
         _device(x) = send_to_device(device(A), x)

         # Note that here default A, C, Aₜ, Cₜ on the same device.

@@ -114,6 +139,10 @@ function RLBase.update!(π::MADDPGManager)
         t = _device(batches[player][:terminal])

         qₜ = Cₜ(vcat(s′, new_actions)) |> vec
+        if π.traces == SLARTSL
+            mu_l′, new_l′ = _device((mu_l′, new_l′))
+            qₜ .+= ifelse.(new_l′, 0.0f0, typemin(Float32))
+        end
         y = r .+ γ .* (1 .- t) .* qₜ

         gs1 = gradient(Flux.params(C)) do

@@ -128,7 +157,11 @@ function RLBase.update!(π::MADDPGManager)
         update!(C, gs1)

         gs2 = gradient(Flux.params(A)) do
-            loss = -mean(C(vcat(s, mu_actions)))
+            v = C(vcat(s, mu_actions)) |> vec
+            if π.traces == SLARTSL
+                v .+= ifelse.(mu_l′, 0.0f0, typemin(Float32))
+            end
+            loss = -mean(v)
             ignore() do
                 p.actor_loss = loss
             end
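A minimal numeric sketch (not part of the commit) of the broadcasting used in the two `SLARTSL` branches above: entries of the Q-value / critic-output vector whose sampled action is illegal are pushed down to roughly `typemin(Float32)`.

```Julia
## Broadcasting mechanics of the legality masking (standalone, Base Julia only).
q    = Float32[0.3, -0.1, 0.7]      # per-sample critic outputs
mask = Bool[true, false, true]      # legality of each sample's chosen action
q .+= ifelse.(mask, 0.0f0, typemin(Float32))
q  # ≈ Float32[0.3, typemin(Float32), 0.7]: the illegal entry becomes a huge negative value
```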
