
Commit e42ffd3

Author: Vincent Moens
[Doc] Fix formatting errors
ghstack-source-id: ac1f3da
Pull Request resolved: #2786
(cherry picked from commit 03d6586)
1 parent 882dc79 commit e42ffd3

File tree

10 files changed (+45, -31 lines)


torchrl/data/datasets/atari_dqn.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ class AtariDQNExperienceReplay(BaseDatasetExperienceReplay):
    root (Path or str, optional): The AtariDQN dataset root directory.
        The actual dataset memory-mapped files will be saved under
        `<root>/<dataset_id>`. If none is provided, it defaults to
-       ``~/.cache/torchrl/atari`.
+       ``~/.cache/torchrl/atari``.
    num_procs (int, optional): number of processes to launch for preprocessing.
        Has no effect whenever the data is already downloaded. Defaults to 0
        (no multiprocessing used).
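
The hunk above balances the backtick pair around the default cache path. As a hedged sketch of how the documented ``root`` argument is used (not part of this commit; the dataset id is illustrative):

    from torchrl.data.datasets import AtariDQNExperienceReplay

    # Data lands under <root>/<dataset_id>; omitting root falls back to
    # the ~/.cache/torchrl/atari default described in the docstring.
    dataset = AtariDQNExperienceReplay(
        "Pong/5",                    # illustrative dataset id
        root="/tmp/my_atari_cache",  # overrides the default cache root
        num_procs=4,                 # parallel preprocessing on first download
    )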

torchrl/data/datasets/d4rl.py

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ class D4RLExperienceReplay(BaseDatasetExperienceReplay):
    root (Path or str, optional): The D4RL dataset root directory.
        The actual dataset memory-mapped files will be saved under
        `<root>/<dataset_id>`. If none is provided, it defaults to
-       ``~/.cache/torchrl/d4rl`.
+       ``~/.cache/torchrl/d4rl``.
    download (bool, optional): Whether the dataset should be downloaded if
        not found. Defaults to ``True``.
    **env_kwargs (key-value pairs): additional kwargs for
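
Same one-character backtick fix for the D4RL default path. A minimal, hedged usage sketch (not from this commit; dataset id and root are illustrative):

    from torchrl.data.datasets import D4RLExperienceReplay

    data = D4RLExperienceReplay(
        "maze2d-umaze-v1",          # illustrative D4RL dataset id
        batch_size=32,
        root="/tmp/my_d4rl_cache",  # overrides the ~/.cache/torchrl/d4rl default
        download=True,              # fetch if not found, per the docstring
    )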

torchrl/data/datasets/gen_dgrl.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ class GenDGRLExperienceReplay(BaseDatasetExperienceReplay):
        dataset root directory.
        The actual dataset memory-mapped files will be saved under
        `<root>/<dataset_id>`. If none is provided, it defaults to
-       ``~/.cache/torchrl/gen_dgrl`.
+       ``~/.cache/torchrl/gen_dgrl``.
    download (bool or str, optional): Whether the dataset should be downloaded if
        not found. Defaults to ``True``. Download can also be passed as ``"force"``,
        in which case the downloaded data will be overwritten.
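
A hedged sketch of the ``download="force"`` behavior the docstring describes (not part of this commit; the dataset id is hypothetical):

    from torchrl.data.datasets import GenDGRLExperienceReplay

    # "force" re-downloads and overwrites any previously cached copy.
    data = GenDGRLExperienceReplay(
        "bigfish-1M_E",   # hypothetical id; check GenDGRLExperienceReplay.available_datasets
        download="force",
    )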

torchrl/data/datasets/minari_data.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ class MinariExperienceReplay(BaseDatasetExperienceReplay):
    root (Path or str, optional): The Minari dataset root directory.
        The actual dataset memory-mapped files will be saved under
        `<root>/<dataset_id>`. If none is provided, it defaults to
-       ``~/.cache/torchrl/minari`.
+       ``~/.cache/torchrl/minari``.
    download (bool or str, optional): Whether the dataset should be downloaded if
        not found. Defaults to ``True``. Download can also be passed as ``"force"``,
        in which case the downloaded data will be overwritten.

torchrl/data/datasets/openx.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ class for more information on how to interact with non-tensor data
    root (Path or str, optional): The OpenX dataset root directory.
        The actual dataset memory-mapped files will be saved under
        `<root>/<dataset_id>`. If none is provided, it defaults to
-       ``~/.cache/torchrl/openx`.
+       ``~/.cache/torchrl/openx``.
    streaming (bool, optional): if ``True``, the data won't be downloaded but
        read from a stream instead.
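
A hedged sketch of the ``streaming`` flag documented above (not part of this commit; the dataset id is illustrative):

    from torchrl.data.datasets import OpenXExperienceReplay

    # With streaming=True nothing is written under <root>; batches are
    # read from a remote stream instead.
    data = OpenXExperienceReplay(
        "cmu_stretch",    # illustrative OpenX dataset id
        batch_size=16,
        streaming=True,
    )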

torchrl/data/datasets/roboset.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ class RobosetExperienceReplay(BaseDatasetExperienceReplay):
    root (Path or str, optional): The Roboset dataset root directory.
        The actual dataset memory-mapped files will be saved under
        `<root>/<dataset_id>`. If none is provided, it defaults to
-       ``~/.cache/torchrl/roboset`.
+       ``~/.cache/torchrl/roboset``.
    download (bool or str, optional): Whether the dataset should be downloaded if
        not found. Defaults to ``True``. Download can also be passed as ``"force"``,
        in which case the downloaded data will be overwritten.

torchrl/data/datasets/vd4rl.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ class VD4RLExperienceReplay(BaseDatasetExperienceReplay):
    root (Path or str, optional): The V-D4RL dataset root directory.
        The actual dataset memory-mapped files will be saved under
        `<root>/<dataset_id>`. If none is provided, it defaults to
-       ``~/.cache/torchrl/vd4rl`.
+       ``~/.cache/torchrl/vd4rl``.
    download (bool or str, optional): Whether the dataset should be downloaded if
        not found. Defaults to ``True``. Download can also be passed as ``"force"``,
        in which case the downloaded data will be overwritten.

torchrl/data/map/query.py

Lines changed: 23 additions & 22 deletions
@@ -80,32 +80,33 @@ class QueryModule(TensorDictModuleBase):
    If a single ``hash_module`` is provided but no aggregator is passed, it will take
    the value of the hash_module. If no ``hash_module`` or a list of ``hash_modules`` is
    provided but no aggregator is passed, it will default to ``SipHash``.
-   clone (bool, optional): if ``True``, a shallow clone of the input TensorDict will be
+   clone (bool, optional): if ``True``, a shallow clone of the input TensorDict will be
        returned. This can be used to retrieve the integer index within the storage,
        corresponding to a given input tensordict. This can be overridden at runtime by
        providing the ``clone`` argument to the forward method.
        Defaults to ``False``.
-   d
-   Examples:
-       >>> query_module = QueryModule(
-       ...     in_keys=["key1", "key2"],
-       ...     index_key="index",
-       ...     hash_module=SipHash(),
-       ... )
-       >>> query = TensorDict(
-       ...     {
-       ...         "key1": torch.Tensor([[1], [1], [1], [2]]),
-       ...         "key2": torch.Tensor([[3], [3], [2], [3]]),
-       ...         "other": torch.randn(4),
-       ...     },
-       ...     batch_size=(4,),
-       ... )
-       >>> res = query_module(query)
-       >>> # The first two pairs of key1 and key2 match
-       >>> assert res["index"][0] == res["index"][1]
-       >>> # The last three pairs of key1 and key2 have at least one mismatching value
-       >>> assert res["index"][1] != res["index"][2]
-       >>> assert res["index"][2] != res["index"][3]
+
+   Examples:
+       >>> query_module = QueryModule(
+       ...     in_keys=["key1", "key2"],
+       ...     index_key="index",
+       ...     hash_module=SipHash(),
+       ... )
+       >>> query = TensorDict(
+       ...     {
+       ...         "key1": torch.Tensor([[1], [1], [1], [2]]),
+       ...         "key2": torch.Tensor([[3], [3], [2], [3]]),
+       ...         "other": torch.randn(4),
+       ...     },
+       ...     batch_size=(4,),
+       ... )
+       >>> res = query_module(query)
+       >>> # The first two pairs of key1 and key2 match
+       >>> assert res["index"][0] == res["index"][1]
+       >>> # The last three pairs of key1 and key2 have at least one mismatching value
+       >>> assert res["index"][1] != res["index"][2]
+       >>> assert res["index"][2] != res["index"][3]
+
    """

    def __init__(
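
The doctest above covers hashing and index retrieval; the ``clone`` flag it documents can also be set per call. A hedged sketch continuing the doctest's names (not part of the diff; assumes ``forward`` accepts ``clone`` as the docstring states):

    # Returns a shallow clone carrying the computed "index" entry,
    # presumably leaving the caller's tensordict untouched.
    res = query_module(query, clone=True)
    assert "index" in res.keys()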

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 14 additions & 2 deletions
@@ -99,6 +99,7 @@ class ReplayBuffer:
        is used with PyTree structures (see example below).
    batch_size (int, optional): the batch size to be used when sample() is
        called.
+
        .. note::
            The batch-size can be specified at construction time via the
            ``batch_size`` argument, or at sampling time. The former should

@@ -108,6 +109,7 @@ class ReplayBuffer:
            incompatible with prefetching (since this requires to know the
            batch-size in advance) as well as with samplers that have a
            ``drop_last`` argument.
+
    dim_extend (int, optional): indicates the dim to consider for
        extension when calling :meth:`extend`. Defaults to ``storage.ndim-1``.
        When using ``dim_extend > 0``, we recommend using the ``ndim``

@@ -128,6 +130,7 @@ class ReplayBuffer:
        >>> for d in data.unbind(1):
        ...     rb.add(d)
        >>> rb.extend(data)
+
    generator (torch.Generator, optional): a generator to use for sampling.
        Using a dedicated generator for the replay buffer can allow a fine-grained control
        over seeding, for instance keeping the global seed different but the RB seed identical

@@ -582,6 +585,7 @@ def register_save_hook(self, hook: Callable[[Any], Any]):

    .. note:: Hooks are currently not serialized when saving a replay buffer: they must
        be manually re-initialized every time the buffer is created.
+
    """
    self._storage.register_save_hook(hook)

@@ -926,15 +930,16 @@ class PrioritizedReplayBuffer(ReplayBuffer):
        construct a tensordict from the non-tensordict content.
    batch_size (int, optional): the batch size to be used when sample() is
        called.
-       .. note::
-           The batch-size can be specified at construction time via the
+
+   .. note:: The batch-size can be specified at construction time via the
        ``batch_size`` argument, or at sampling time. The former should
        be preferred whenever the batch-size is consistent across the
        experiment. If the batch-size is likely to change, it can be
        passed to the :meth:`sample` method. This option is
        incompatible with prefetching (since this requires to know the
        batch-size in advance) as well as with samplers that have a
        ``drop_last`` argument.
+
    dim_extend (int, optional): indicates the dim to consider for
        extension when calling :meth:`extend`. Defaults to ``storage.ndim-1``.
        When using ``dim_extend > 0``, we recommend using the ``ndim``

@@ -1051,6 +1056,7 @@ class TensorDictReplayBuffer(ReplayBuffer):
        construct a tensordict from the non-tensordict content.
    batch_size (int, optional): the batch size to be used when sample() is
        called.
+
        .. note::
            The batch-size can be specified at construction time via the
            ``batch_size`` argument, or at sampling time. The former should

@@ -1060,6 +1066,7 @@ class TensorDictReplayBuffer(ReplayBuffer):
            incompatible with prefetching (since this requires to know the
            batch-size in advance) as well as with samplers that have a
            ``drop_last`` argument.
+
    priority_key (str, optional): the key at which priority is assumed to
        be stored within TensorDicts added to this ReplayBuffer.
        This is to be used when the sampler is of type

@@ -1085,6 +1092,7 @@ class TensorDictReplayBuffer(ReplayBuffer):
        >>> for d in data.unbind(1):
        ...     rb.add(d)
        >>> rb.extend(data)
+
    generator (torch.Generator, optional): a generator to use for sampling.
        Using a dedicated generator for the replay buffer can allow a fine-grained control
        over seeding, for instance keeping the global seed different but the RB seed identical

@@ -1394,6 +1402,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
        construct a tensordict from the non-tensordict content.
    batch_size (int, optional): the batch size to be used when sample() is
        called.
+
        .. note::
            The batch-size can be specified at construction time via the
            ``batch_size`` argument, or at sampling time. The former should

@@ -1403,6 +1412,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
            incompatible with prefetching (since this requires to know the
            batch-size in advance) as well as with samplers that have a
            ``drop_last`` argument.
+
    priority_key (str, optional): the key at which priority is assumed to
        be stored within TensorDicts added to this ReplayBuffer.
        This is to be used when the sampler is of type

@@ -1431,6 +1441,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
        >>> for d in data.unbind(1):
        ...     rb.add(d)
        >>> rb.extend(data)
+
    generator (torch.Generator, optional): a generator to use for sampling.
        Using a dedicated generator for the replay buffer can allow a fine-grained control
        over seeding, for instance keeping the global seed different but the RB seed identical

@@ -1669,6 +1680,7 @@ class ReplayBufferEnsemble(ReplayBuffer):
        Defaults to ``None`` (global default generator).

        .. warning:: As of now, the generator has no effect on the transforms.
+
    shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
        Defaults to ``False``.
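
Most of the hunks above only insert blank lines so the ``.. note::`` blocks render correctly. A hedged sketch tying together two of the documented arguments, ``batch_size`` and ``generator`` (not part of this commit):

    import torch
    from torchrl.data import LazyTensorStorage, ReplayBuffer

    gen = torch.Generator()
    gen.manual_seed(0)        # RB seed kept separate from the global seed

    rb = ReplayBuffer(
        storage=LazyTensorStorage(100),
        batch_size=8,         # preferred when constant across the experiment
        generator=gen,
    )
    rb.extend(torch.arange(100))

    batch = rb.sample()       # uses the construction-time batch_size (8)
    batch = rb.sample(16)     # per-call override; incompatible with prefetching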

torchrl/data/rlhf/utils.py

Lines changed: 1 addition & 0 deletions
@@ -198,6 +198,7 @@ class RolloutFromModel:
        batch_size=torch.Size([4, 50]),
        device=cpu,
        is_shared=False)
+
    """

    EOS_TOKEN_ID = 50256
