actor_critic.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.autograd import Variable
import math

from ..args_provider import ArgsProvider
from .policy_gradient import PolicyGradient
from .discounted_reward import DiscountedReward
from .value_matcher import ValueMatcher
from .utils import add_err
# Actor-critic model.
class ActorCritic:
    ''' An actor-critic model. '''

    def __init__(self):
        ''' Initialize the child components `PolicyGradient`, `DiscountedReward`
        and `ValueMatcher`, and register the arguments this class needs
        (``num_games``, ``batchsize``, ``value_node``) together with those of
        the child providers.
        '''
        self.pg = PolicyGradient()
        self.discounted_reward = DiscountedReward()
        self.value_matcher = ValueMatcher()

        self.args = ArgsProvider(
            call_from = self,
            define_args = [
            ],
            more_args = ["num_games", "batchsize", "value_node"],
            child_providers = [ self.pg.args, self.discounted_reward.args, self.value_matcher.args ],
        )
    def update(self, mi, batch, stats):
        ''' Actor-critic model update.
        Feed stats for later summarization.

        Args:
            mi(`ModelInterface`): model interface used
            batch(dict): batch of data. Keys in a batch:
                ``s``: state,
                ``r``: immediate reward,
                ``terminal``: whether the game has terminated
            stats(`Stats`): feed stats for later summarization.
        '''
m = mi["model"]
args = self.args
value_node = self.args.value_node
T = batch["s"].size(0)
state_curr = m(batch.hist(T - 1))
self.discounted_reward.setR(state_curr[value_node].squeeze().data, stats)
err = None
for t in range(T - 2, -1, -1):
bht = batch.hist(t)
state_curr = m.forward(bht)
# go through the sample and get the rewards.
V = state_curr[value_node].squeeze()
R = self.discounted_reward.feed(
dict(r=batch["r"][t], terminal=batch["terminal"][t]),
stats=stats)
policy_err = self.pg.feed(R-V.data, state_curr, bht, stats, old_pi_s=bht)
err = add_err(err, policy_err)
err = add_err(err, self.value_matcher.feed({ value_node: V, "target" : R}, stats))
stats["cost"].feed(err.item() / (T - 1))
err.backward()
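

# ----------------------------------------------------------------------------
# Illustration only: the block below is NOT part of ELF's API. It is a
# minimal, self-contained sketch of the same update in plain PyTorch, to make
# the math above concrete: bootstrap R from the value estimate at the last
# time step, sweep t = T-2 .. 0 with R = r_t + gamma * R (reset at terminals),
# use the advantage (R - V) for the policy loss and a squared error to R for
# the value loss. All names here (ToyActorCritic, gamma, T, B, ...) are
# hypothetical stand-ins, not ELF's PolicyGradient / DiscountedReward /
# ValueMatcher classes.
if __name__ == "__main__":
    import torch.nn.functional as F

    class ToyActorCritic(nn.Module):
        def __init__(self, state_dim=4, num_actions=3):
            super().__init__()
            self.pi = nn.Linear(state_dim, num_actions)   # policy head
            self.v = nn.Linear(state_dim, 1)              # value head

        def forward(self, s):
            return F.log_softmax(self.pi(s), dim=1), self.v(s).squeeze(1)

    gamma, T, B = 0.99, 5, 16
    net = ToyActorCritic()
    s = torch.randn(T, B, 4)                  # states
    a = torch.randint(0, 3, (T, B))           # actions taken
    r = torch.randn(T, B)                     # immediate rewards
    terminal = torch.zeros(T, B)              # 1.0 where the game ended

    # Bootstrap the return from the value at the last step (cf. setR above).
    _, V_last = net(s[T - 1])
    R = V_last.detach()

    err = 0.0
    for t in range(T - 2, -1, -1):
        logp, V = net(s[t])
        R = r[t] + gamma * R * (1.0 - terminal[t])            # discounted return
        adv = R - V.detach()                                  # advantage, no grad
        chosen_logp = logp.gather(1, a[t].unsqueeze(1)).squeeze(1)
        policy_err = -(adv * chosen_logp).mean()              # policy gradient loss
        value_err = 0.5 * (V - R).pow(2).mean()               # value regression loss
        err = err + policy_err + value_err

    err.backward()
    print("toy actor-critic cost:", err.item() / (T - 1))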