Skip to content

Commit

Permalink
Collected activations are now returned as a dict
Browse files (browse the repository at this point in the history)
  • Loading branch information
nathankim7 committed Jun 7, 2024
1 parent 7f099fb commit a17e19a
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 15 deletions.
7 changes: 5 additions & 2 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,8 +1102,11 @@ def _wait_for_forward_with_serial_intervention(
unit_locations_base = unit_locations[group_key][1]

if activations_sources != None:
for key in keys:
self.activations[key] = activations_sources[key]
for passed_in_key, v in activations_sources.items():
assert (
passed_in_key in self.sorted_keys
), f"{passed_in_key} not in {self.sorted_keys}, {unit_locations}"
self.activations[passed_in_key] = torch.clone(v)
else:
keys_with_source = [
k for i, k in enumerate(keys) if unit_locations_source[i] != None
Expand Down
50 changes: 40 additions & 10 deletions pyvene_101.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,27 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "17c7f2f6-b0d3-4fe2-8e4f-c044b93f3ef0",
"metadata": {},
"outputs": [],
"metadata": {
"metadata": {}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fce745d6f2ca453b98f7b10868b1ab7d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"generation_config.json: 0%| | 0.00/124 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pyvene as pv\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
Expand All @@ -144,10 +161,10 @@
" \"intervention_type\": pv.CollectIntervention}, model=gpt2)\n",
"\n",
"base = \"When John and Mary went to the shops, Mary gave the bag to\"\n",
"collected_attn_w = pv_gpt2(\n",
"(_, collected_attn_w), _ = pv_gpt2(\n",
" base = tokenizer(base, return_tensors=\"pt\"\n",
" ), unit_locations={\"base\": [h for h in range(12)]}\n",
")[0][-1][0]"
")"
]
},
{
Expand All @@ -160,7 +177,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"id": "128be2dd-f089-4291-bfc5-7002d031b1e9",
"metadata": {
"metadata": {}
Expand All @@ -171,7 +188,7 @@
"output_type": "stream",
"text": [
"loaded GPT2 model gpt2\n",
"torch.Size([12, 14, 14])\n"
"torch.Size([1, 12, 14, 14])\n"
]
}
],
Expand All @@ -193,6 +210,7 @@
" base = tokenizer(base, return_tensors=\"pt\"\n",
" ), unit_locations={\"base\": [h for h in range(12)]}\n",
")\n",
"collected_attn_w = torch.stack(list(collected_attn_w.values()))\n",
"print(collected_attn_w[0].shape)"
]
},
Expand All @@ -206,16 +224,28 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 5,
"id": "678dc46f",
"metadata": {},
"metadata": {
"metadata": {}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loaded model\n"
"loaded GPT2 model gpt2\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
Expand Down
6 changes: 3 additions & 3 deletions tests/integration_tests/IntervenableBasicTestCase.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,10 +591,10 @@ def test_customized_intervention_function_get(self):
)

base = "When John and Mary went to the shops, Mary gave the bag to"
collected_attn_w = pv_gpt2(
(_, collected_attn_w), _ = pv_gpt2(
base=tokenizer(base, return_tensors="pt"),
unit_locations={"base": [h for h in range(12)]},
)[0][-1][0]
)

cached_w = {}

Expand All @@ -608,7 +608,7 @@ def pv_patcher(b, s):

base = "When John and Mary went to the shops, Mary gave the bag to"
_ = pv_gpt2(tokenizer(base, return_tensors="pt"))
torch.allclose(collected_attn_w, cached_w["attn_w"].unsqueeze(dim=0))
torch.allclose(list(collected_attn_w.values())[0], cached_w["attn_w"].unsqueeze(dim=0))

def test_customized_intervention_function_zeroout(self):

Expand Down

0 comments on commit a17e19a

Please sign in to comment.