@@ -99,6 +99,7 @@ class ReplayBuffer:
            is used with PyTree structures (see example below).
        batch_size (int, optional): the batch size to be used when sample() is
            called.
+
            .. note::
              The batch-size can be specified at construction time via the
              ``batch_size`` argument, or at sampling time. The former should
@@ -108,6 +109,7 @@ class ReplayBuffer:
              incompatible with prefetching (since this requires to know the
              batch-size in advance) as well as with samplers that have a
              ``drop_last`` argument.
+
        dim_extend (int, optional): indicates the dim to consider for
            extension when calling :meth:`extend`. Defaults to ``storage.ndim-1``.
            When using ``dim_extend > 0``, we recommend using the ``ndim``
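For context on the ``batch_size`` note in the hunks above, here is a minimal usage sketch; the storage size and batch sizes are arbitrary illustrations, not values from this diff:

    >>> import torch
    >>> from torchrl.data import ReplayBuffer, LazyTensorStorage
    >>> rb = ReplayBuffer(storage=LazyTensorStorage(100), batch_size=32)  # fixed batch size, prefetch-friendly
    >>> rb.extend(torch.arange(100))
    >>> fixed = rb.sample()        # uses the construction-time batch_size
    >>> variable = rb.sample(8)    # per-call batch size; incompatible with prefetching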
@@ -128,6 +130,7 @@ class ReplayBuffer:
            >>> for d in data.unbind(1):
            ...     rb.add(d)
            >>> rb.extend(data)
+
        generator (torch.Generator, optional): a generator to use for sampling.
            Using a dedicated generator for the replay buffer can allow a fine-grained control
            over seeding, for instance keeping the global seed different but the RB seed identical
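A sketch of the dedicated-generator pattern described in this hunk, assuming one only wants the buffer's sampling stream to be reproducible (the seed value is arbitrary):

    >>> import torch
    >>> from torchrl.data import ReplayBuffer, LazyTensorStorage
    >>> gen = torch.Generator()
    >>> _ = gen.manual_seed(0)     # seed only the buffer's sampling stream
    >>> rb = ReplayBuffer(storage=LazyTensorStorage(100), batch_size=4, generator=gen)
    >>> rb.extend(torch.arange(100))
    >>> sample = rb.sample()       # reproducible regardless of the global torch seed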
@@ -582,6 +585,7 @@ def register_save_hook(self, hook: Callable[[Any], Any]):

        .. note:: Hooks are currently not serialized when saving a replay buffer: they must
            be manually re-initialized every time the buffer is created.
+
        """
        self._storage.register_save_hook(hook)
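A minimal sketch of the re-registration pattern implied by the note above; the identity hook is a placeholder and the exact hook semantics are an assumption, not something stated in this diff:

    >>> from torchrl.data import ReplayBuffer, LazyTensorStorage
    >>> def my_hook(data):         # hypothetical hook: pass stored data through unchanged
    ...     return data
    >>> rb = ReplayBuffer(storage=LazyTensorStorage(100))
    >>> rb.register_save_hook(my_hook)
    >>> # hooks are not serialized, so a freshly created (or re-loaded) buffer
    >>> # needs the hook registered again by hand:
    >>> rb2 = ReplayBuffer(storage=LazyTensorStorage(100))
    >>> rb2.register_save_hook(my_hook)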
@@ -926,15 +930,16 @@ class PrioritizedReplayBuffer(ReplayBuffer):
            construct a tensordict from the non-tensordict content.
        batch_size (int, optional): the batch size to be used when sample() is
            called.
-            .. note::
-              The batch-size can be specified at construction time via the
+
+            .. note:: The batch-size can be specified at construction time via the
              ``batch_size`` argument, or at sampling time. The former should
              be preferred whenever the batch-size is consistent across the
              experiment. If the batch-size is likely to change, it can be
              passed to the :meth:`sample` method. This option is
              incompatible with prefetching (since this requires to know the
              batch-size in advance) as well as with samplers that have a
              ``drop_last`` argument.
+
        dim_extend (int, optional): indicates the dim to consider for
            extension when calling :meth:`extend`. Defaults to ``storage.ndim-1``.
            When using ``dim_extend > 0``, we recommend using the ``ndim``
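A hedged sketch of how this class is typically put together; the alpha/beta values and the ``return_info``/``update_priority`` round trip are illustrative assumptions rather than content of this diff:

    >>> import torch
    >>> from torchrl.data import PrioritizedReplayBuffer, ListStorage
    >>> rb = PrioritizedReplayBuffer(alpha=0.7, beta=0.9, storage=ListStorage(10), batch_size=3)
    >>> rb.extend(list(range(10)))
    >>> sample, info = rb.sample(return_info=True)   # info carries indices and importance weights
    >>> rb.update_priority(info["index"], priority=torch.ones(3) * 2)  # refresh priorities after a loss pass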
@@ -1051,6 +1056,7 @@ class TensorDictReplayBuffer(ReplayBuffer):
            construct a tensordict from the non-tensordict content.
        batch_size (int, optional): the batch size to be used when sample() is
            called.
+
            .. note::
              The batch-size can be specified at construction time via the
              ``batch_size`` argument, or at sampling time. The former should
@@ -1060,6 +1066,7 @@ class TensorDictReplayBuffer(ReplayBuffer):
              incompatible with prefetching (since this requires to know the
              batch-size in advance) as well as with samplers that have a
              ``drop_last`` argument.
+
        priority_key (str, optional): the key at which priority is assumed to
            be stored within TensorDicts added to this ReplayBuffer.
            This is to be used when the sampler is of type
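A sketch of the ``priority_key`` mechanism documented above, pairing the buffer with a prioritized sampler; the key name ``"td_error"`` and the sampler hyperparameters are assumptions for illustration:

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage
    >>> from torchrl.data.replay_buffers import PrioritizedSampler
    >>> rb = TensorDictReplayBuffer(
    ...     storage=LazyTensorStorage(100),
    ...     sampler=PrioritizedSampler(max_capacity=100, alpha=0.7, beta=0.9),
    ...     priority_key="td_error",
    ...     batch_size=4,
    ... )
    >>> data = TensorDict({"obs": torch.randn(10, 3), "td_error": torch.rand(10)}, batch_size=[10])
    >>> rb.extend(data)            # priorities are read from data["td_error"]
    >>> sample = rb.sample()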
@@ -1085,6 +1092,7 @@ class TensorDictReplayBuffer(ReplayBuffer):
            >>> for d in data.unbind(1):
            ...     rb.add(d)
            >>> rb.extend(data)
+
        generator (torch.Generator, optional): a generator to use for sampling.
            Using a dedicated generator for the replay buffer can allow a fine-grained control
            over seeding, for instance keeping the global seed different but the RB seed identical
@@ -1394,6 +1402,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
            construct a tensordict from the non-tensordict content.
        batch_size (int, optional): the batch size to be used when sample() is
            called.
+
            .. note::
              The batch-size can be specified at construction time via the
              ``batch_size`` argument, or at sampling time. The former should
@@ -1403,6 +1412,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
              incompatible with prefetching (since this requires to know the
              batch-size in advance) as well as with samplers that have a
              ``drop_last`` argument.
+
        priority_key (str, optional): the key at which priority is assumed to
            be stored within TensorDicts added to this ReplayBuffer.
            This is to be used when the sampler is of type
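A hedged sketch of the prioritized tensordict buffer with a priority-update round trip; the hyperparameters and the ``update_tensordict_priority`` call are illustrative assumptions, not part of this diff:

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.data import TensorDictPrioritizedReplayBuffer, LazyTensorStorage
    >>> rb = TensorDictPrioritizedReplayBuffer(
    ...     alpha=0.7,
    ...     beta=1.1,
    ...     priority_key="td_error",
    ...     storage=LazyTensorStorage(100),
    ...     batch_size=5,
    ... )
    >>> data = TensorDict({"obs": torch.randn(10, 3)}, batch_size=[10])
    >>> rb.extend(data)
    >>> sample = rb.sample()
    >>> sample["td_error"] = torch.rand(5)      # e.g. the per-sample loss
    >>> rb.update_tensordict_priority(sample)   # write the new priorities back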
@@ -1431,6 +1441,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
            >>> for d in data.unbind(1):
            ...     rb.add(d)
            >>> rb.extend(data)
+
        generator (torch.Generator, optional): a generator to use for sampling.
            Using a dedicated generator for the replay buffer can allow a fine-grained control
            over seeding, for instance keeping the global seed different but the RB seed identical
@@ -1669,6 +1680,7 @@ class ReplayBufferEnsemble(ReplayBuffer):
            Defaults to ``None`` (global default generator).

            .. warning:: As of now, the generator has no effect on the transforms.
+
        shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
            Defaults to ``False``.
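To situate the ``generator`` and ``shared`` entries above, here is a hedged sketch of a two-buffer ensemble; the sampling probabilities ``p`` and the buffer contents are purely illustrative:

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.data import TensorDictReplayBuffer, ReplayBufferEnsemble, LazyTensorStorage
    >>> rb0 = TensorDictReplayBuffer(storage=LazyTensorStorage(100))
    >>> rb1 = TensorDictReplayBuffer(storage=LazyTensorStorage(100))
    >>> rb0.extend(TensorDict({"obs": torch.zeros(10, 3)}, batch_size=[10]))
    >>> rb1.extend(TensorDict({"obs": torch.ones(10, 3)}, batch_size=[10]))
    >>> ensemble = ReplayBufferEnsemble(rb0, rb1, p=[0.5, 0.5])
    >>> sample = ensemble.sample(8)   # draws across the two buffers according to p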