From a17e19aeb2ad8d7bc410cc755391c993a2abf174 Mon Sep 17 00:00:00 2001 From: nathankim7 Date: Tue, 23 Apr 2024 17:32:11 -0700 Subject: [PATCH] collected activations now return as dict --- pyvene/models/intervenable_base.py | 7 ++- pyvene_101.ipynb | 50 +++++++++++++++---- .../IntervenableBasicTestCase.py | 6 +-- 3 files changed, 48 insertions(+), 15 deletions(-) diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index 14589192..314e06a5 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -1102,8 +1102,11 @@ def _wait_for_forward_with_serial_intervention( unit_locations_base = unit_locations[group_key][1] if activations_sources != None: - for key in keys: - self.activations[key] = activations_sources[key] + for passed_in_key, v in activations_sources.items(): + assert ( + passed_in_key in self.sorted_keys + ), f"{passed_in_key} not in {self.sorted_keys}, {unit_locations}" + self.activations[passed_in_key] = torch.clone(v) else: keys_with_source = [ k for i, k in enumerate(keys) if unit_locations_source[i] != None diff --git a/pyvene_101.ipynb b/pyvene_101.ipynb index 5f2829e0..893cf0c8 100644 --- a/pyvene_101.ipynb +++ b/pyvene_101.ipynb @@ -126,10 +126,27 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "17c7f2f6-b0d3-4fe2-8e4f-c044b93f3ef0", - "metadata": {}, - "outputs": [], + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fce745d6f2ca453b98f7b10868b1ab7d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "generation_config.json: 0%| | 0.00/124 [00:00