Skip to content

Commit 188ed86

Browse files
author
Vincent Moens
committed
[Tutorial] Beam search with GPT models
ghstack-source-id: 36de30d Pull Request resolved: #2623
1 parent 133d709 commit 188ed86

File tree

5 files changed

+456
-0
lines changed

5 files changed

+456
-0
lines changed

docs/requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,7 @@ vmas
2828
onnxscript
2929
onnxruntime
3030
onnx
31+
plotly
32+
igraph
33+
transformers
34+
datasets
318 KB
Loading

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ Intermediate
105105
tutorials/dqn_with_rnn
106106
tutorials/rb_tutorial
107107
tutorials/export
108+
tutorials/beam_search_with_gpt
108109

109110
Advanced
110111
--------

torchrl/data/map/tree.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,10 @@ def done_keys(self) -> List[NestedKey]:
798798

799799
@done_keys.setter
800800
def done_keys(self, value):
801+
if isinstance(value, (str, tuple)):
802+
value = [value]
803+
if value is not None:
804+
value = [unravel_key(val) for val in value]
801805
self._done_keys = _make_list_of_nestedkeys(value, "done_keys")
802806

803807
@property
@@ -818,6 +822,10 @@ def reward_keys(self) -> List[NestedKey]:
818822

819823
@reward_keys.setter
820824
def reward_keys(self, value):
825+
if isinstance(value, (str, tuple)):
826+
value = [value]
827+
if value is not None:
828+
value = [unravel_key(val) for val in value]
821829
self._reward_keys = _make_list_of_nestedkeys(value, "reward_keys")
822830

823831
@property
@@ -838,6 +846,10 @@ def action_keys(self) -> List[NestedKey]:
838846

839847
@action_keys.setter
840848
def action_keys(self, value):
849+
if isinstance(value, (str, tuple)):
850+
value = [value]
851+
if value is not None:
852+
value = [unravel_key(val) for val in value]
841853
self._action_keys = _make_list_of_nestedkeys(value, "action_keys")
842854

843855
@property
@@ -857,6 +869,10 @@ def observation_keys(self) -> List[NestedKey]:
857869

858870
@observation_keys.setter
859871
def observation_keys(self, value):
872+
if isinstance(value, (str, tuple)):
873+
value = [value]
874+
if value is not None:
875+
value = [unravel_key(val) for val in value]
860876
self._observation_keys = _make_list_of_nestedkeys(value, "observation_keys")
861877

862878
@property
@@ -1012,6 +1028,27 @@ def add(self, step, *, return_node: bool = False):
10121028
if return_node:
10131029
return self.get_tree(step)
10141030

1031+
def add(self, step):
1032+
source, dest = (
1033+
step.exclude("next").copy(),
1034+
step.select("next", *self.action_keys).copy(),
1035+
)
1036+
1037+
if self.data_map is None:
1038+
self._make_storage(source, dest)
1039+
1040+
# We need to set the action somewhere to keep track of which action led to which child
1041+
# # Set the action in the 'next'
1042+
# dest[1:] = source[:-1].exclude(*self.done_keys)
1043+
1044+
# Add ('observation', 'action') -> ('next', 'observation')
1045+
self.data_map[source] = dest
1046+
value = source
1047+
if self.node_map is None:
1048+
self._make_storage_branches(source, dest)
1049+
# map ('observation',) -> ('indices',)
1050+
self.node_map[source] = value
1051+
10151052
def get_child(self, root: TensorDictBase) -> TensorDictBase:
10161053
return self.data_map[root]
10171054

0 commit comments

Comments
 (0)