@@ -967,6 +967,133 @@ def _make_node_map(self, source, dest):
         self.max_size = self.data_map.max_size
 
     def extend(self, rollout, *, return_node: bool = False):
+        """Add a rollout to the forest.
+
+        Nodes are only added to a tree at points where rollouts diverge from
+        each other and at the endpoints of rollouts.
+
+        If there is no existing tree that matches the first steps of the
+        rollout, a new tree is added. Only one node is created, for the final
+        step.
+
+        If there is an existing tree that matches, the rollout is added to that
+        tree. If the rollout diverges from all other rollouts in the tree at
+        some step, a new node is created before the step where the rollouts
+        diverge, and a leaf node is created for the final step of the rollout.
+        If all of the rollout's steps match with a previously added rollout,
+        nothing changes. If the rollout matches up to a leaf node of a tree but
+        continues beyond it, that node is extended to the end of the rollout,
+        and no new nodes are created.
+
+        Args:
+            rollout (TensorDict): The rollout to add to the forest.
+            return_node (bool, optional): If ``True``, the method returns the
+                added node. Default is ``False``.
+
+        Returns:
+            Tree: The node that was added to the forest. This is only
+                returned if ``return_node`` is ``True``.
+
+        Examples:
+            >>> from torchrl.data import MCTSForest
+            >>> from tensordict import TensorDict
+            >>> import torch
+            >>> forest = MCTSForest()
+            >>> r0 = TensorDict({
+            ...     'action': torch.tensor([1, 2, 3, 4, 5]),
+            ...     'next': {'observation': torch.tensor([123, 392, 989, 809, 847])},
+            ...     'observation': torch.tensor([ 0, 123, 392, 989, 809])
+            ... }, [5])
+            >>> r1 = TensorDict({
+            ...     'action': torch.tensor([1, 2, 6, 7]),
+            ...     'next': {'observation': torch.tensor([123, 392, 235, 38])},
+            ...     'observation': torch.tensor([ 0, 123, 392, 235])
+            ... }, [4])
+            >>> td_root = r0[0].exclude("next")
+            >>> forest.extend(r0)
+            >>> forest.extend(r1)
+            >>> tree = forest.get_tree(td_root)
+            >>> print(tree)
+            Tree(
+                count=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                index=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
+                node_data=TensorDict(
+                    fields={
+                        observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([]),
+                    device=cpu,
+                    is_shared=False),
+                node_id=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
+                rollout=TensorDict(
+                    fields={
+                        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
+                        next: TensorDict(
+                            fields={
+                                observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
+                            batch_size=torch.Size([2]),
+                            device=cpu,
+                            is_shared=False),
+                        observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([2]),
+                    device=cpu,
+                    is_shared=False),
+                subtree=Tree(
+                    _parent=NonTensorStack(
+                        [<weakref at 0x716eeb78fbf0; to 'TensorDict' at 0x...,
+                        batch_size=torch.Size([2]),
+                        device=None),
+                    count=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
+                    hash=NonTensorStack(
+                        [4341220243998689835, 6745467818783115365],
+                        batch_size=torch.Size([2]),
+                        device=None),
+                    node_data=LazyStackedTensorDict(
+                        fields={
+                            observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
+                        exclusive_fields={
+                        },
+                        batch_size=torch.Size([2]),
+                        device=cpu,
+                        is_shared=False,
+                        stack_dim=0),
+                    node_id=NonTensorStack(
+                        [1, 2],
+                        batch_size=torch.Size([2]),
+                        device=None),
+                    rollout=LazyStackedTensorDict(
+                        fields={
+                            action: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False),
+                            next: LazyStackedTensorDict(
+                                fields={
+                                    observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
+                                exclusive_fields={
+                                },
+                                batch_size=torch.Size([2, -1]),
+                                device=cpu,
+                                is_shared=False,
+                                stack_dim=0),
+                            observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
+                        exclusive_fields={
+                        },
+                        batch_size=torch.Size([2, -1]),
+                        device=cpu,
+                        is_shared=False,
+                        stack_dim=0),
+                    wins=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                    index=None,
+                    subtree=None,
+                    specs=None,
+                    batch_size=torch.Size([2]),
+                    device=None,
+                    is_shared=False),
+                wins=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                hash=None,
+                _parent=None,
+                specs=None,
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+        """
         source, dest = (
             rollout.exclude("next").copy(),
             rollout.select("next", *self.action_keys).copy(),
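A quick illustration for reviewers, not part of the patch: a minimal sketch of what the new ``return_node`` flag enables, reusing the same two rollouts as the docstring's Examples block. The docstring does not pin down which of the freshly created nodes is returned when a single call creates both a branching node and a leaf, so the result is printed rather than asserted.

import torch
from tensordict import TensorDict
from torchrl.data import MCTSForest

forest = MCTSForest()

# Two rollouts that share their first two steps (actions 1 and 2), then diverge.
r0 = TensorDict({
    "action": torch.tensor([1, 2, 3, 4, 5]),
    "next": {"observation": torch.tensor([123, 392, 989, 809, 847])},
    "observation": torch.tensor([0, 123, 392, 989, 809]),
}, [5])
r1 = TensorDict({
    "action": torch.tensor([1, 2, 6, 7]),
    "next": {"observation": torch.tensor([123, 392, 235, 38])},
    "observation": torch.tensor([0, 123, 392, 235]),
}, [4])

# No existing tree matches r0, so a new tree is created with a single node
# for the rollout's final step.
node0 = forest.extend(r0, return_node=True)

# r1 matches r0 for two steps and then diverges, so a node is created before
# the divergence point and a leaf is created for r1's final step.
node1 = forest.extend(r1, return_node=True)

# The returned Tree exposes the fields shown in the docstring's printout,
# e.g. the slice of the rollout owned by that node.
print(node1.rollout)

Per the docstring, a further ``extend`` call whose rollout matches an existing branch up to a leaf and continues past it would stretch that leaf in place instead of creating new nodes.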