Skip to content

Commit 3af5eeb

Browse files
committed
Support GPUs in RaySampler
1 parent c56513f commit 3af5eeb

File tree

3 files changed

+62
-26
lines changed

3 files changed

+62
-26
lines changed

src/garage/sampler/ray_sampler.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,17 @@ class RaySampler(Sampler):
3737
The maximum length episodes which will be sampled.
3838
is_tf_worker (bool): Whether it is workers for TFTrainer.
3939
seed(int): The seed to use to initialize random number generators.
40-
n_workers(int): The number of workers to use.
40+
n_workers(int or None): The number of workers to use. Defaults to
41+
number of physical cpus, if worker_factory is also None.
4142
worker_class(type): Class of the workers. Instances should implement
4243
the Worker interface.
4344
worker_args (dict or None): Additional arguments that should be passed
4445
to the worker.
46+
n_gpus (int or float): Number of GPUs to to use in total for sampling.
47+
If `n_workers` is not a power of two, this may need to be set
48+
slightly below the true value, since `n_workers / n_gpus` gpus are
49+
allocated to each worker. Defaults to zero, because otherwise
50+
nothing would run if no gpus were available.
4551
4652
"""
4753

@@ -54,26 +60,32 @@ def __init__(
5460
max_episode_length=None,
5561
is_tf_worker=False,
5662
seed=get_seed(),
57-
n_workers=psutil.cpu_count(logical=False),
63+
n_workers=None,
5864
worker_class=DefaultWorker,
59-
worker_args=None):
60-
# pylint: disable=super-init-not-called
65+
worker_args=None,
66+
n_gpus=0):
6167
if not ray.is_initialized():
6268
ray.init(log_to_driver=False, ignore_reinit_error=True)
6369
if worker_factory is None and max_episode_length is None:
6470
raise TypeError('Must construct a sampler from WorkerFactory or'
6571
'parameters (at least max_episode_length)')
66-
if isinstance(worker_factory, WorkerFactory):
72+
if worker_factory is not None:
73+
if n_workers is None:
74+
n_workers = worker_factory.n_workers
6775
self._worker_factory = worker_factory
6876
else:
77+
if n_workers is None:
78+
n_workers = psutil.cpu_count(logical=False)
6979
self._worker_factory = WorkerFactory(
7080
max_episode_length=max_episode_length,
7181
is_tf_worker=is_tf_worker,
7282
seed=seed,
7383
n_workers=n_workers,
7484
worker_class=worker_class,
7585
worker_args=worker_args)
76-
self._sampler_worker = ray.remote(SamplerWorker)
86+
remote_wrapper = ray.remote(num_gpus=n_gpus / n_workers)
87+
self._n_gpus = n_gpus
88+
self._sampler_worker = remote_wrapper(SamplerWorker)
7789
self._agents = agents
7890
self._envs = self._worker_factory.prepare_worker_messages(envs)
7991
self._all_workers = defaultdict(None)
@@ -103,7 +115,10 @@ def from_worker_factory(cls, worker_factory, agents, envs):
103115
Sampler: An instance of `cls`.
104116
105117
"""
106-
return cls(agents, envs, worker_factory=worker_factory)
118+
return cls(agents,
119+
envs,
120+
worker_factory=worker_factory,
121+
n_workers=worker_factory.n_workers)
107122

108123
def start_worker(self):
109124
"""Initialize a new ray worker."""
@@ -308,7 +323,8 @@ def __getstate__(self):
308323
"""
309324
return dict(factory=self._worker_factory,
310325
agents=self._agents,
311-
envs=self._envs)
326+
envs=self._envs,
327+
n_gpus=self._n_gpus)
312328

313329
def __setstate__(self, state):
314330
"""Unpickle the state.
@@ -319,7 +335,8 @@ def __setstate__(self, state):
319335
"""
320336
self.__init__(state['agents'],
321337
state['envs'],
322-
worker_factory=state['factory'])
338+
worker_factory=state['factory'],
339+
n_gpus=state['n_gpus'])
323340

324341

325342
class SamplerWorker:

src/garage/sampler/sampler.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,9 @@ class Sampler(abc.ABC):
1313
`Sampler` needs. Specifically, it specifies how to construct `Worker`s,
1414
which know how to collect episodes and update both agents and environments.
1515
16-
Currently, `__init__` is also part of the interface, but calling it is
17-
deprecated. `start_worker` is also deprecated, and does not need to be
18-
implemented.
16+
`start_worker` is deprecated, and does not need to be implemented.
1917
"""
2018

21-
def __init__(self, algo, env):
22-
"""Construct a Sampler from an Algorithm.
23-
24-
Args:
25-
algo (RLAlgorithm): The RL Algorithm controlling this
26-
sampler.
27-
env (Environment): The environment being sampled from.
28-
29-
Calling this method is deprecated.
30-
31-
"""
32-
self.algo = algo
33-
self.env = env
34-
3519
@classmethod
3620
def from_worker_factory(cls, worker_factory, agents, envs):
3721
"""Construct this sampler.

tests/garage/sampler/test_ray_batched_sampler.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Tests for ray_batched_sampler."""
2+
import pickle
23
from unittest.mock import Mock
34

45
import numpy as np
@@ -138,6 +139,40 @@ def test_init_with_env_updates(ray_local_session_fixture):
138139
assert sum(episodes.lengths) >= 160
139140

140141

142+
def test_pickle(ray_local_session_fixture):
143+
del ray_local_session_fixture
144+
assert ray.is_initialized()
145+
max_episode_length = 16
146+
env = PointEnv()
147+
policy = FixedPolicy(env.spec,
148+
scripted_actions=[
149+
env.action_space.sample()
150+
for _ in range(max_episode_length)
151+
])
152+
tasks = SetTaskSampler(PointEnv)
153+
n_workers = 4
154+
workers = WorkerFactory(seed=100,
155+
max_episode_length=max_episode_length,
156+
n_workers=n_workers)
157+
sampler = RaySampler.from_worker_factory(workers, policy, env)
158+
sampler_pickled = pickle.dumps(sampler)
159+
sampler.shutdown_worker()
160+
sampler2 = pickle.loads(sampler_pickled)
161+
episodes = sampler2.obtain_samples(0,
162+
500,
163+
np.asarray(policy.get_param_values()),
164+
env_update=tasks.sample(n_workers))
165+
mean_rewards = []
166+
goals = []
167+
for eps in episodes.split():
168+
mean_rewards.append(eps.rewards.mean())
169+
goals.append(eps.env_infos['task'][0]['goal'])
170+
assert np.var(mean_rewards) > 0
171+
assert np.var(goals) > 0
172+
sampler2.shutdown_worker()
173+
env.close()
174+
175+
141176
def test_init_without_worker_factory(ray_local_session_fixture):
142177
del ray_local_session_fixture
143178
assert ray.is_initialized()

0 commit comments

Comments
 (0)