@@ -37,11 +37,17 @@ class RaySampler(Sampler):
             The maximum length episodes which will be sampled.
         is_tf_worker (bool): Whether it is workers for TFTrainer.
         seed(int): The seed to use to initialize random number generators.
-        n_workers(int): The number of workers to use.
+        n_workers(int or None): The number of workers to use. Defaults to
+            the number of physical CPUs, if worker_factory is also None.
         worker_class(type): Class of the workers. Instances should implement
             the Worker interface.
         worker_args (dict or None): Additional arguments that should be passed
             to the worker.
+        n_gpus (int or float): Number of GPUs to use in total for sampling.
+            If `n_workers` is not a power of two, this may need to be set
+            slightly below the true value, since `n_gpus / n_workers` GPUs are
+            allocated to each worker. Defaults to zero, because otherwise
+            nothing would run if no GPUs were available.


     """
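Since the allocation described in the `n_gpus` docstring is a plain floating-point division, here is a minimal sketch of the arithmetic behind that note; the numbers are illustrative and not part of this change:

```python
# Illustrative sketch of the per-worker GPU share that RaySampler requests.
# The values of n_gpus and n_workers here are hypothetical.
n_gpus = 1
n_workers = 6  # not a power of two, so the share is inexact in binary

per_worker = n_gpus / n_workers
print(per_worker)  # 0.16666666666666666

# Ray can only schedule every worker if the fractional requests fit into the
# GPUs it actually sees. If rounding makes the shares sum to slightly more
# than n_gpus, some actors never start, which is why the docstring recommends
# passing a slightly smaller n_gpus in that case.
```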
@@ -54,26 +60,32 @@ def __init__(
             max_episode_length=None,
             is_tf_worker=False,
             seed=get_seed(),
-            n_workers=psutil.cpu_count(logical=False),
+            n_workers=None,
             worker_class=DefaultWorker,
-            worker_args=None):
-        # pylint: disable=super-init-not-called
+            worker_args=None,
+            n_gpus=0):
         if not ray.is_initialized():
             ray.init(log_to_driver=False, ignore_reinit_error=True)
         if worker_factory is None and max_episode_length is None:
             raise TypeError('Must construct a sampler from WorkerFactory or'
                             'parameters (at least max_episode_length)')
-        if isinstance(worker_factory, WorkerFactory):
+        if worker_factory is not None:
+            if n_workers is None:
+                n_workers = worker_factory.n_workers
             self._worker_factory = worker_factory
         else:
+            if n_workers is None:
+                n_workers = psutil.cpu_count(logical=False)
             self._worker_factory = WorkerFactory(
                 max_episode_length=max_episode_length,
                 is_tf_worker=is_tf_worker,
                 seed=seed,
                 n_workers=n_workers,
                 worker_class=worker_class,
                 worker_args=worker_args)
-        self._sampler_worker = ray.remote(SamplerWorker)
+        remote_wrapper = ray.remote(num_gpus=n_gpus / n_workers)
+        self._n_gpus = n_gpus
+        self._sampler_worker = remote_wrapper(SamplerWorker)
         self._agents = agents
         self._envs = self._worker_factory.prepare_worker_messages(envs)
         self._all_workers = defaultdict(None)
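For context on the `remote_wrapper` lines above: calling `ray.remote` with only resource keywords returns a decorator, which is then applied to the class. A self-contained sketch of that pattern follows; the `Counter` class and `num_gpus=0` are illustrative and not part of this change:

```python
import ray

ray.init(ignore_reinit_error=True, log_to_driver=False)


class Counter:
    """Ordinary Python class to be turned into a Ray actor."""

    def __init__(self):
        self._value = 0

    def increment(self):
        self._value += 1
        return self._value


# ray.remote(num_gpus=...) returns a wrapper, mirroring how the sampler builds
# remote_wrapper and then applies it to SamplerWorker. num_gpus=0 keeps this
# sketch runnable on machines without GPUs; RaySampler passes
# n_gpus / n_workers instead.
remote_wrapper = ray.remote(num_gpus=0)
RemoteCounter = remote_wrapper(Counter)

actor = RemoteCounter.remote()
print(ray.get(actor.increment.remote()))  # -> 1
```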
@@ -103,7 +115,10 @@ def from_worker_factory(cls, worker_factory, agents, envs):
             Sampler: An instance of `cls`.

         """
-        return cls(agents, envs, worker_factory=worker_factory)
+        return cls(agents,
+                   envs,
+                   worker_factory=worker_factory,
+                   n_workers=worker_factory.n_workers)

     def start_worker(self):
         """Initialize a new ray worker."""
@@ -308,7 +323,8 @@ def __getstate__(self):
         """
         return dict(factory=self._worker_factory,
                     agents=self._agents,
-                    envs=self._envs)
+                    envs=self._envs,
+                    n_gpus=self._n_gpus)

     def __setstate__(self, state):
         """Unpickle the state.
@@ -319,7 +335,8 @@ def __setstate__(self, state):
         """
         self.__init__(state['agents'],
                       state['envs'],
-                      worker_factory=state['factory'])
+                      worker_factory=state['factory'],
+                      n_gpus=state['n_gpus'])


 class SamplerWorker:
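The two pickling hooks changed above follow a common pattern: `__getstate__` stores just the constructor arguments and `__setstate__` re-runs `__init__`, so the new `n_gpus` value now survives a pickle round trip. A small self-contained illustration of that pattern (the class below is a hypothetical stand-in, not garage code):

```python
import pickle


class TinySampler:
    """Toy stand-in that re-initializes itself on unpickling."""

    def __init__(self, n_workers, n_gpus=0):
        self._n_workers = n_workers
        self._n_gpus = n_gpus

    def __getstate__(self):
        # Store only what __init__ needs, like RaySampler.__getstate__ above.
        return dict(n_workers=self._n_workers, n_gpus=self._n_gpus)

    def __setstate__(self, state):
        # Rebuild the object by calling __init__ with the saved arguments.
        self.__init__(state['n_workers'], n_gpus=state['n_gpus'])


sampler = TinySampler(n_workers=4, n_gpus=1)
restored = pickle.loads(pickle.dumps(sampler))
print(restored._n_gpus)  # -> 1, the extra field survives the round trip
```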