[Bug] Fix DDP for PPO
lz1oceani committed Oct 18, 2023
1 parent: 7b2c0a6 · commit: 4521de9
Showing 7 changed files with 8 additions and 15 deletions.
5 changes: 3 additions & 2 deletions maniskill2_learn/apis/run_rl.py
@@ -281,7 +281,8 @@ def main_rl(rollout, evaluator, replay, args, cfg, expert_replay=None, recent_tr
             agent = SpSyncBatchNorm.convert_sync_batchnorm(agent)
         except:
             pass
-        agent.to_ddp(device_ids=["cuda"])
+        is_ppo = "ppo" in cfg.agent_cfg.type.lower()
+        agent.to_ddp(device_ids=["cuda"], find_unused_parameters=not is_ppo)
 
     logger.info(f"Work directory of this run {args.work_dir}")
     if len(args.gpu_ids) > 0:
@@ -380,7 +381,7 @@ def run_one_process(rank, world_size, args, cfg):
         # Only the first process will do evaluation
         from maniskill2_learn.env import build_evaluation
 
-        logger.info(f"Build evaluation!")
+        logger.info("Build evaluation!")
         eval_cfg = cfg.eval_cfg
         # Evaluation environment setup can be different from the training set-up. (Like early-stop or object sets)
         if eval_cfg.get("env_cfg", None) is None:
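The run_rl.py hunk above is the heart of the fix: when the agent is a PPO variant, DDP is built with find_unused_parameters=False, while other agents keep True. A minimal standalone sketch of the same rule (wrap_for_ddp and agent_type are hypothetical names; the repo does this inside main_rl via agent.to_ddp, and the sketch assumes torch.distributed is already initialized with one process per GPU):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_ddp(agent: torch.nn.Module, agent_type: str) -> DDP:
    # Mirrors the commit: PPO's update touches every trainable parameter, so the
    # extra per-iteration graph traversal that find_unused_parameters=True adds is
    # wasted for it (PyTorch warns when the flag is set but nothing is unused).
    is_ppo = "ppo" in agent_type.lower()
    return DDP(
        agent.cuda(),
        device_ids=[torch.cuda.current_device()],
        find_unused_parameters=not is_ppo,
    )

Other agent types keep the permissive default, presumably because their forward passes do not always touch every wrapped parameter.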
2 changes: 0 additions & 2 deletions maniskill2_learn/networks/backbones/pointnet.py
@@ -15,8 +15,6 @@
 from maniskill2_learn.utils.data import dict_to_seq, split_dim, GDict, repeat
 from maniskill2_learn.utils.torch import masked_average, masked_max, ExtendedModule
-
-from pytorch3d.transforms import quaternion_to_matrix
 
 
 class STNkd(ExtendedModule):
     def __init__(self, k=3, mlp_spec=[64, 128, 1024], norm_cfg=dict(type="BN1d", eps=1e-6), act_cfg=dict(type="ReLU")):
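The deleted pytorch3d import was unused here (and in the two files below), so dropping it removes a heavy optional dependency. For reference only, not part of this commit: if the conversion were ever needed without pytorch3d, a minimal torch equivalent of quaternion_to_matrix (assuming unit quaternions in pytorch3d's (w, x, y, z) order) is:

import torch

def quaternion_to_matrix(q: torch.Tensor) -> torch.Tensor:
    # q: (..., 4) unit quaternions as (w, x, y, z); returns (..., 3, 3) rotation matrices.
    w, x, y, z = q.unbind(-1)
    return torch.stack(
        [
            1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
            2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
            2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y),
        ],
        dim=-1,
    ).reshape(q.shape[:-1] + (3, 3))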
1 change: 0 additions & 1 deletion maniskill2_learn/networks/backbones/sp_resnet.py
Expand Up @@ -8,7 +8,6 @@
import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
from pytorch3d.transforms import quaternion_to_matrix

from torchsparse import SparseTensor
import torchsparse.nn as spnn
Expand Down
4 changes: 0 additions & 4 deletions maniskill2_learn/networks/backbones/transformer.py
@@ -6,13 +6,9 @@
 
 from maniskill2_learn.utils.torch import ExtendedModule
 from maniskill2_learn.utils.data import split_dim, GDict
-
-from pytorch3d.transforms import quaternion_to_matrix
 
 from .mlp import LinearMLP
 
 import numpy as np
-import open3d as o3d
-import time
 
 
2 changes: 1 addition & 1 deletion maniskill2_learn/utils/meta/config.py
@@ -367,7 +367,7 @@ def _format_dict(input_dict, outest_level=False):
     text = _format_dict(cfg_dict, outest_level=True)
     # copied from setup.cfg
     yapf_style = dict(based_on_style="pep8", blank_line_before_nested_class_or_def=True, split_before_expression_after_opening_paren=True)
-    text, _ = FormatCode(text, style_config=yapf_style, verify=True)
+    text, _ = FormatCode(text, style_config=yapf_style)
 
     return text
 
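The config.py change tracks newer yapf releases, which removed FormatCode's verify keyword (around yapf 0.40); passing it there now raises a TypeError, while omitting it works on both old and new versions. A quick standalone check (a sketch, reusing the same dict-style style_config the repo passes):

from yapf.yapflib.yapf_api import FormatCode

source = "x=dict(a = 1,b= 2)\n"
# No verify= keyword: compatible with yapf before and after its removal.
formatted, changed = FormatCode(source, style_config=dict(based_on_style="pep8"))
print(formatted, changed)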
8 changes: 4 additions & 4 deletions maniskill2_learn/utils/torch/module_utils.py
@@ -279,9 +279,9 @@ def critic_grad(self, with_shared=True):
         if self.actor.final_mlp is not None:
             ret["grad/mlp_grad"] = self.actor.final_mlp.grad_norm
 
-    def to_ddp(self, device_ids=None):
+    def to_ddp(self, device_ids=None, find_unused_parameters=True):
         self._device_ids = device_ids
-        self.recover_ddp()
+        self.recover_ddp(find_unused_parameters=find_unused_parameters)
 
     def to_normal(self):
         if self._be_data_parallel and self._device_ids is not None:
@@ -291,15 +291,15 @@ def to_normal(self):
             if isinstance(item, DDP):
                 setattr(self, module_name, item.module)
 
-    def recover_ddp(self):
+    def recover_ddp(self, find_unused_parameters=True):
         if self._device_ids is None:
             return
         self._be_data_parallel = True
         for module_name in dir(self):
             item = getattr(self, module_name)
             if isinstance(item, ExtendedModule) and len(item.trainable_parameters) > 0:
                 if module_name not in self._tmp_attrs:
-                    self._tmp_attrs[module_name] = ExtendedDDP(item, device_ids=self._device_ids, find_unused_parameters=True)
+                    self._tmp_attrs[module_name] = ExtendedDDP(item, device_ids=self._device_ids, find_unused_parameters=find_unused_parameters)
                 setattr(self, module_name, self._tmp_attrs[module_name])
 
     def is_data_parallel(self):
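module_utils.py threads the new flag from to_ddp through recover_ddp down to each ExtendedDDP wrapper instead of hard-coding find_unused_parameters=True. Note that wrappers are cached in self._tmp_attrs, so the flag only takes effect the first time a module is wrapped. A condensed sketch of the wrap/unwrap pattern in plain torch (Agent and _wrapped are hypothetical stand-ins for ExtendedModule's machinery; assumes torch.distributed is initialized):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class Agent(nn.Module):
    def __init__(self):
        super().__init__()
        self.actor = nn.Linear(8, 2)
        self.critic = nn.Linear(8, 1)
        self._wrapped = {}  # mirrors self._tmp_attrs: one cached DDP wrapper per submodule

    def to_ddp(self, device_ids, find_unused_parameters=True):
        for name in ("actor", "critic"):
            if name not in self._wrapped:  # wrap once; later calls reuse the cache
                self._wrapped[name] = DDP(
                    getattr(self, name),
                    device_ids=device_ids,
                    find_unused_parameters=find_unused_parameters,
                )
            setattr(self, name, self._wrapped[name])

    def to_normal(self):
        for name in ("actor", "critic"):
            item = getattr(self, name)
            if isinstance(item, DDP):
                setattr(self, name, item.module)  # unwrap back to the plain module

Caching the wrappers keeps repeated to_ddp()/to_normal() toggles cheap, at the cost that a later call with a different find_unused_parameters value will silently reuse the old wrapper.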
1 change: 0 additions & 1 deletion requirements.txt
@@ -33,7 +33,6 @@ docker
 crc32c
 pypi-simple
 numpy-quaternion
-scikit-image==0.18.3
 termcolor
 pymeshlab
 plyfile
