@@ -798,6 +798,10 @@ def done_keys(self) -> List[NestedKey]:
798
798
799
799
@done_keys.setter
def done_keys(self, value):
    """Store the done keys on ``self._done_keys``.

    Args:
        value: a single key (``str`` or ``tuple``), an iterable of keys,
            or ``None``.

    A lone ``str`` or ``tuple`` is wrapped in a one-element list. Every
    entry of a non-``None`` value is passed through ``unravel_key`` before
    the result is handed to ``_make_list_of_nestedkeys``.
    """
    # A str/tuple is treated as ONE key, not as an iterable of keys.
    keys = [value] if isinstance(value, (str, tuple)) else value
    if keys is not None:
        keys = list(map(unravel_key, keys))
    self._done_keys = _make_list_of_nestedkeys(keys, "done_keys")
802
806
803
807
@property
@@ -818,6 +822,10 @@ def reward_keys(self) -> List[NestedKey]:
818
822
819
823
@reward_keys.setter
def reward_keys(self, value):
    """Store the reward keys on ``self._reward_keys``.

    Args:
        value: a single key (``str`` or ``tuple``), an iterable of keys,
            or ``None``.

    A lone ``str`` or ``tuple`` is wrapped in a one-element list. Every
    entry of a non-``None`` value is passed through ``unravel_key`` before
    the result is handed to ``_make_list_of_nestedkeys``.
    """
    # A str/tuple is treated as ONE key, not as an iterable of keys.
    keys = [value] if isinstance(value, (str, tuple)) else value
    if keys is not None:
        keys = list(map(unravel_key, keys))
    self._reward_keys = _make_list_of_nestedkeys(keys, "reward_keys")
822
830
823
831
@property
@@ -838,6 +846,10 @@ def action_keys(self) -> List[NestedKey]:
838
846
839
847
@action_keys.setter
def action_keys(self, value):
    """Store the action keys on ``self._action_keys``.

    Args:
        value: a single key (``str`` or ``tuple``), an iterable of keys,
            or ``None``.

    A lone ``str`` or ``tuple`` is wrapped in a one-element list. Every
    entry of a non-``None`` value is passed through ``unravel_key`` before
    the result is handed to ``_make_list_of_nestedkeys``.
    """
    # A str/tuple is treated as ONE key, not as an iterable of keys.
    keys = [value] if isinstance(value, (str, tuple)) else value
    if keys is not None:
        keys = list(map(unravel_key, keys))
    self._action_keys = _make_list_of_nestedkeys(keys, "action_keys")
842
854
843
855
@property
@@ -857,6 +869,10 @@ def observation_keys(self) -> List[NestedKey]:
857
869
858
870
@observation_keys.setter
def observation_keys(self, value):
    """Store the observation keys on ``self._observation_keys``.

    Args:
        value: a single key (``str`` or ``tuple``), an iterable of keys,
            or ``None``.

    A lone ``str`` or ``tuple`` is wrapped in a one-element list. Every
    entry of a non-``None`` value is passed through ``unravel_key`` before
    the result is handed to ``_make_list_of_nestedkeys``.
    """
    # A str/tuple is treated as ONE key, not as an iterable of keys.
    keys = [value] if isinstance(value, (str, tuple)) else value
    if keys is not None:
        keys = list(map(unravel_key, keys))
    self._observation_keys = _make_list_of_nestedkeys(keys, "observation_keys")
861
877
862
878
@property
@@ -1012,6 +1028,27 @@ def add(self, step, *, return_node: bool = False):
1012
1028
if return_node :
1013
1029
return self .get_tree (step )
1014
1030
1031
def add(self, step):
    """Register one step in ``self.data_map`` and ``self.node_map``.

    Args:
        step: the step to record. It must support ``exclude``, ``select``
            and ``copy`` (presumably a TensorDict -- confirm with callers).

    The step is split into two copies: ``root`` is ``step`` with the
    ``"next"`` entry excluded, and ``child`` is ``step`` restricted to
    ``"next"`` plus ``self.action_keys``. Each map is created on first use,
    when it is still ``None``.
    """
    # NOTE(review): the diff context shows an ``add(self, step, *,
    # return_node=False)`` defined just above; if both live in the same
    # class, this later definition replaces it -- confirm this is intended.
    root = step.exclude("next").copy()
    child = step.select("next", *self.action_keys).copy()

    if self.data_map is None:
        self._make_storage(root, child)

    # The action travels with the child so that we can keep track of which
    # action led to which child.
    # Earlier idea, kept for reference: write the action into 'next' via
    #   dest[1:] = source[:-1].exclude(*self.done_keys)

    # Add ('observation', 'action') -> ('next', 'observation')
    self.data_map[root] = child

    if self.node_map is None:
        self._make_storage_branches(root, child)
    # NOTE(review): the original comment described this as mapping
    # ('observation',) -> ('indices',), but the stored value is the root
    # itself -- confirm.
    self.node_map[root] = root
1051
+
1015
1052
def get_child(self, root: TensorDictBase) -> TensorDictBase:
    """Return the entry stored in ``self.data_map`` under ``root``.

    Args:
        root: the key to look up in ``self.data_map``.

    Returns:
        Whatever ``self.data_map[root]`` yields.
    """
    child = self.data_map[root]
    return child
1017
1054
0 commit comments