Skip to content

Commit 9ceabe8

Browse files
committed
Support GPUs in RaySampler
1 parent c56513f commit 9ceabe8

File tree

2 files changed

+13
-21
lines changed

2 files changed

+13
-21
lines changed

src/garage/sampler/ray_sampler.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ class RaySampler(Sampler):
4242
the Worker interface.
4343
worker_args (dict or None): Additional arguments that should be passed
4444
to the worker.
45+
n_gpus (int or float): Number of GPUs to use in total for sampling.
46+
If `n_workers` is not a power of two, this may need to be set
47+
slightly below the true value, since `n_gpus / n_workers` gpus are
48+
allocated to each worker.
4549
4650
"""
4751

@@ -56,8 +60,8 @@ def __init__(
5660
seed=get_seed(),
5761
n_workers=psutil.cpu_count(logical=False),
5862
worker_class=DefaultWorker,
59-
worker_args=None):
60-
# pylint: disable=super-init-not-called
63+
worker_args=None,
64+
n_gpus=0):
6165
if not ray.is_initialized():
6266
ray.init(log_to_driver=False, ignore_reinit_error=True)
6367
if worker_factory is None and max_episode_length is None:
@@ -73,7 +77,8 @@ def __init__(
7377
n_workers=n_workers,
7478
worker_class=worker_class,
7579
worker_args=worker_args)
76-
self._sampler_worker = ray.remote(SamplerWorker)
80+
remote_wrapper = ray.remote(num_gpus=n_gpus / n_workers)
81+
self._sampler_worker = remote_wrapper(SamplerWorker)
7782
self._agents = agents
7883
self._envs = self._worker_factory.prepare_worker_messages(envs)
7984
self._all_workers = defaultdict(None)
@@ -103,7 +108,10 @@ def from_worker_factory(cls, worker_factory, agents, envs):
103108
Sampler: An instance of `cls`.
104109
105110
"""
106-
return cls(agents, envs, worker_factory=worker_factory)
111+
return cls(agents,
112+
envs,
113+
worker_factory=worker_factory,
114+
n_workers=worker_factory.n_workers)
107115

108116
def start_worker(self):
109117
"""Initialize a new ray worker."""

src/garage/sampler/sampler.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,9 @@ class Sampler(abc.ABC):
1313
`Sampler` needs. Specifically, it specifies how to construct `Worker`s,
1414
which know how to collect episodes and update both agents and environments.
1515
16-
Currently, `__init__` is also part of the interface, but calling it is
17-
deprecated. `start_worker` is also deprecated, and does not need to be
18-
implemented.
16+
`start_worker` is deprecated, and does not need to be implemented.
1917
"""
2018

21-
def __init__(self, algo, env):
22-
"""Construct a Sampler from an Algorithm.
23-
24-
Args:
25-
algo (RLAlgorithm): The RL Algorithm controlling this
26-
sampler.
27-
env (Environment): The environment being sampled from.
28-
29-
Calling this method is deprecated.
30-
31-
"""
32-
self.algo = algo
33-
self.env = env
34-
3519
@classmethod
3620
def from_worker_factory(cls, worker_factory, agents, envs):
3721
"""Construct this sampler.

0 commit comments

Comments
 (0)