
Commit 62b8c8b

kurtamohler authored and Vincent Moens committed
[Doc] Add docstring for MCTSForest.extend
ghstack-source-id: 7fa8834 Pull Request resolved: #2795 (cherry picked from commit a3a1ebe)
1 parent a1823af · commit 62b8c8b

File tree

1 file changed: +127 −0 lines


torchrl/data/map/tree.py

Lines changed: 127 additions & 0 deletions
@@ -967,6 +967,133 @@ def _make_node_map(self, source, dest):
         self.max_size = self.data_map.max_size

     def extend(self, rollout, *, return_node: bool = False):
+        """Add a rollout to the forest.
+
+        Nodes are only added to a tree at points where rollouts diverge from
+        each other and at the endpoints of rollouts.
+
+        If there is no existing tree that matches the first steps of the
+        rollout, a new tree is added. Only one node is created, for the final
+        step.
+
+        If there is an existing tree that matches, the rollout is added to that
+        tree. If the rollout diverges from all other rollouts in the tree at
+        some step, a new node is created before the step where the rollouts
+        diverge, and a leaf node is created for the final step of the rollout.
+        If all of the rollout's steps match with a previously added rollout,
+        nothing changes. If the rollout matches up to a leaf node of a tree but
+        continues beyond it, that node is extended to the end of the rollout,
+        and no new nodes are created.
+
+        Args:
+            rollout (TensorDict): The rollout to add to the forest.
+            return_node (bool, optional): If ``True``, the method returns the
+                added node. Default is ``False``.
+
+        Returns:
+            Tree: The node that was added to the forest. This is only
+                returned if ``return_node`` is True.
+
+        Examples:
+            >>> from torchrl.data import MCTSForest
+            >>> from tensordict import TensorDict
+            >>> import torch
+            >>> forest = MCTSForest()
+            >>> r0 = TensorDict({
+            ...     'action': torch.tensor([1, 2, 3, 4, 5]),
+            ...     'next': {'observation': torch.tensor([123, 392, 989, 809, 847])},
+            ...     'observation': torch.tensor([  0, 123, 392, 989, 809])
+            ... }, [5])
+            >>> r1 = TensorDict({
+            ...     'action': torch.tensor([1, 2, 6, 7]),
+            ...     'next': {'observation': torch.tensor([123, 392, 235, 38])},
+            ...     'observation': torch.tensor([  0, 123, 392, 235])
+            ... }, [4])
+            >>> td_root = r0[0].exclude("next")
+            >>> forest.extend(r0)
+            >>> forest.extend(r1)
+            >>> tree = forest.get_tree(td_root)
+            >>> print(tree)
+            Tree(
+                count=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                index=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
+                node_data=TensorDict(
+                    fields={
+                        observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([]),
+                    device=cpu,
+                    is_shared=False),
+                node_id=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
+                rollout=TensorDict(
+                    fields={
+                        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
+                        next: TensorDict(
+                            fields={
+                                observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
+                            batch_size=torch.Size([2]),
+                            device=cpu,
+                            is_shared=False),
+                        observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([2]),
+                    device=cpu,
+                    is_shared=False),
+                subtree=Tree(
+                    _parent=NonTensorStack(
+                        [<weakref at 0x716eeb78fbf0; to 'TensorDict' at 0x...,
+                        batch_size=torch.Size([2]),
+                        device=None),
+                    count=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
+                    hash=NonTensorStack(
+                        [4341220243998689835, 6745467818783115365],
+                        batch_size=torch.Size([2]),
+                        device=None),
+                    node_data=LazyStackedTensorDict(
+                        fields={
+                            observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
+                        exclusive_fields={
+                        },
+                        batch_size=torch.Size([2]),
+                        device=cpu,
+                        is_shared=False,
+                        stack_dim=0),
+                    node_id=NonTensorStack(
+                        [1, 2],
+                        batch_size=torch.Size([2]),
+                        device=None),
+                    rollout=LazyStackedTensorDict(
+                        fields={
+                            action: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False),
+                            next: LazyStackedTensorDict(
+                                fields={
+                                    observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
+                                exclusive_fields={
+                                },
+                                batch_size=torch.Size([2, -1]),
+                                device=cpu,
+                                is_shared=False,
+                                stack_dim=0),
+                            observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
+                        exclusive_fields={
+                        },
+                        batch_size=torch.Size([2, -1]),
+                        device=cpu,
+                        is_shared=False,
+                        stack_dim=0),
+                    wins=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                    index=None,
+                    subtree=None,
+                    specs=None,
+                    batch_size=torch.Size([2]),
+                    device=None,
+                    is_shared=False),
+                wins=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                hash=None,
+                _parent=None,
+                specs=None,
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+        """
         source, dest = (
             rollout.exclude("next").copy(),
             rollout.select("next", *self.action_keys).copy(),
