losses.py (forked from facebookresearch/impact-driven-exploration; repository archived and read-only since Mar 8, 2022)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np


def compute_baseline_loss(advantages):
    # Mean-squared value error, averaged over the batch dimension and summed over time.
    return 0.5 * torch.sum(torch.mean(advantages ** 2, dim=1))


def compute_entropy_loss(logits):
    # Negative policy entropy; minimizing it encourages exploration.
    policy = F.softmax(logits, dim=-1)
    log_policy = F.log_softmax(logits, dim=-1)
    entropy_per_timestep = torch.sum(-policy * log_policy, dim=-1)
    return -torch.sum(torch.mean(entropy_per_timestep, dim=1))


def compute_policy_gradient_loss(logits, actions, advantages):
    # Cross-entropy of the taken actions, weighted by the advantages.
    cross_entropy = F.nll_loss(
        F.log_softmax(torch.flatten(logits, 0, 1), dim=-1),
        target=torch.flatten(actions, 0, 1),
        reduction='none')
    cross_entropy = cross_entropy.view_as(advantages)
    # Advantages are treated as constants: no gradient flows through them.
    advantages.requires_grad = False
    policy_gradient_loss_per_timestep = cross_entropy * advantages
    return torch.sum(torch.mean(policy_gradient_loss_per_timestep, dim=1))


def compute_forward_dynamics_loss(pred_next_emb, next_emb):
    # L2 distance between predicted and actual next-state embeddings.
    forward_dynamics_loss = torch.norm(pred_next_emb - next_emb, dim=2, p=2)
    return torch.sum(torch.mean(forward_dynamics_loss, dim=1))


def compute_inverse_dynamics_loss(pred_actions, true_actions):
    # Cross-entropy between predicted action logits and the actions actually taken.
    inverse_dynamics_loss = F.nll_loss(
        F.log_softmax(torch.flatten(pred_actions, 0, 1), dim=-1),
        target=torch.flatten(true_actions, 0, 1),
        reduction='none')
    inverse_dynamics_loss = inverse_dynamics_loss.view_as(true_actions)
    return torch.sum(torch.mean(inverse_dynamics_loss, dim=1))
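

# --- Usage sketch (not part of the original file) ---
# A minimal example of how these losses might be combined into a single
# actor-critic objective with intrinsic-motivation terms. The tensor shapes
# (T, B) and (T, B, num_actions), the loss coefficients, and the reuse of
# `advantages` for the baseline term are illustrative assumptions, not
# values or wiring taken from this repository.
if __name__ == "__main__":
    T, B, num_actions, emb_dim = 5, 4, 6, 8
    logits = torch.randn(T, B, num_actions)        # policy logits per timestep/actor
    actions = torch.randint(num_actions, (T, B))   # actions actually taken
    advantages = torch.randn(T, B)                 # e.g. V-trace advantages
    pred_next_emb = torch.randn(T, B, emb_dim)     # forward-model prediction
    next_emb = torch.randn(T, B, emb_dim)          # actual next-state embedding

    total_loss = (
        compute_policy_gradient_loss(logits, actions, advantages)
        + 0.5 * compute_baseline_loss(advantages)      # hypothetical baseline coefficient
        + 0.0005 * compute_entropy_loss(logits)        # hypothetical entropy coefficient
        + compute_forward_dynamics_loss(pred_next_emb, next_emb)
        + compute_inverse_dynamics_loss(logits, actions)  # logits stand in for predicted-action logits
    )
    print(total_loss.item())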