@@ -42,6 +42,10 @@ class RaySampler(Sampler):
42
42
the Worker interface.
43
43
worker_args (dict or None): Additional arguments that should be passed
44
44
to the worker.
45
+ n_gpus (int or float): Number of GPUs to to use in total for sampling.
46
+ If `n_workers` is not a power of two, this may need to be set
47
+ slightly below the true value, since `n_workers / n_gpus` gpus are
48
+ allocated to each worker.
45
49
46
50
"""
47
51
@@ -56,8 +60,8 @@ def __init__(
56
60
seed = get_seed (),
57
61
n_workers = psutil .cpu_count (logical = False ),
58
62
worker_class = DefaultWorker ,
59
- worker_args = None ):
60
- # pylint: disable=super-init-not-called
63
+ worker_args = None ,
64
+ n_gpus = 0 ):
61
65
if not ray .is_initialized ():
62
66
ray .init (log_to_driver = False , ignore_reinit_error = True )
63
67
if worker_factory is None and max_episode_length is None :
@@ -73,7 +77,8 @@ def __init__(
73
77
n_workers = n_workers ,
74
78
worker_class = worker_class ,
75
79
worker_args = worker_args )
76
- self ._sampler_worker = ray .remote (SamplerWorker )
80
+ remote_wrapper = ray .remote (num_gpus = n_gpus / n_workers )
81
+ self ._sampler_worker = remote_wrapper (SamplerWorker )
77
82
self ._agents = agents
78
83
self ._envs = self ._worker_factory .prepare_worker_messages (envs )
79
84
self ._all_workers = defaultdict (None )
@@ -103,7 +108,10 @@ def from_worker_factory(cls, worker_factory, agents, envs):
103
108
Sampler: An instance of `cls`.
104
109
105
110
"""
106
- return cls (agents , envs , worker_factory = worker_factory )
111
+ return cls (agents ,
112
+ envs ,
113
+ worker_factory = worker_factory ,
114
+ n_workers = worker_factory .n_workers )
107
115
108
116
def start_worker (self ):
109
117
"""Initialize a new ray worker."""
0 commit comments