Commit 7f099fb
intervention doesn't delete your passed reprs, also generation notebook
nathankim7 committed Jun 7, 2024
1 parent 0e3ecb2 commit 7f099fb
Showing 4 changed files with 335 additions and 37 deletions.
16 changes: 8 additions & 8 deletions pyvene/models/gpt_neo/modelings_intervenable_gpt_neo.py
@@ -19,16 +19,16 @@
     "mlp_activation": ("h[%s].mlp.act", CONST_OUTPUT_HOOK),
     "mlp_output": ("h[%s].mlp", CONST_OUTPUT_HOOK),
     "mlp_input": ("h[%s].mlp", CONST_INPUT_HOOK),
-    "attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK),
-    "head_attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
+    "attention_value_output": ("h[%s].attn.attention.out_proj", CONST_INPUT_HOOK),
+    "head_attention_value_output": ("h[%s].attn.attention.out_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
     "attention_output": ("h[%s].attn", CONST_OUTPUT_HOOK),
     "attention_input": ("h[%s].attn", CONST_INPUT_HOOK),
-    "query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK),
-    "key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK),
-    "value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK),
-    "head_query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
-    "head_key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
-    "head_value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
+    "query_output": ("h[%s].attn.attention.q_proj", CONST_OUTPUT_HOOK),
+    "key_output": ("h[%s].attn.attention.k_proj", CONST_OUTPUT_HOOK),
+    "value_output": ("h[%s].attn.attention.v_proj", CONST_OUTPUT_HOOK),
+    "head_query_output": ("h[%s].attn.attention.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
+    "head_key_output": ("h[%s].attn.attention.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
+    "head_value_output": ("h[%s].attn.attention.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
 }


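In Hugging Face's GPT-Neo, each block's `attn` is a thin `GPTNeoAttention` wrapper and the q/k/v/out projections live on the nested `attn.attention` submodule, which is why the mapping paths gain the extra `.attention` segment. A minimal sketch to confirm the corrected paths resolve (assumes `transformers` is installed and the EleutherAI/gpt-neo-125m checkpoint can be downloaded; this check is not part of the commit):

from transformers import GPTNeoModel

model = GPTNeoModel.from_pretrained("EleutherAI/gpt-neo-125m")

# "h[%s]" in the mapping is a layer-index placeholder; check that the new
# paths exist on a couple of layers, and that the old ones do not.
for layer in (0, 5):
    for suffix in ("out_proj", "q_proj", "k_proj", "v_proj"):
        new_path = f"h.{layer}.attn.attention.{suffix}"
        print(new_path, "->", type(model.get_submodule(new_path)).__name__)

try:
    model.get_submodule("h.0.attn.out_proj")  # the pre-fix path
except AttributeError as e:
    print("old path fails:", e)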
43 changes: 25 additions & 18 deletions pyvene/models/intervenable_base.py
@@ -856,10 +856,7 @@ def _intervention_setter(
         ] # batch_size

         def hook_callback(model, args, kwargs, output=None):
-            if (
-                self._skip_forward
-                and state.setter_timestep <= 0
-            ):
+            if self._skip_forward and state.setter_timestep <= 0:
                 state.setter_timestep += 1
                 return
@@ -881,7 +878,8 @@ def hook_callback(model, args, kwargs, output=None):
                 else [
                     (
                         [0]
-                        if timestep_selector != None and timestep_selector[key_i](
+                        if timestep_selector != None
+                        and timestep_selector[key_i](
                             state.setter_timestep, output[i]
                         )
                         else None
@@ -1054,11 +1052,11 @@ def _wait_for_forward_with_parallel_intervention(
                     group_get_handlers.remove()
             else:
                 # simply patch in the ones passed in
-                self.activations = activations_sources
-                for _, passed_in_key in enumerate(self.activations):
+                for passed_in_key, v in activations_sources.items():
                     assert (
                         passed_in_key in self.sorted_keys
                     ), f"{passed_in_key} not in {self.sorted_keys}, {unit_locations}"
+                    self.activations[passed_in_key] = torch.clone(v)

             # in parallel mode, we swap cached activations all into
             # base at once
@@ -1094,17 +1092,25 @@ def _wait_for_forward_with_serial_intervention(
             if sources[group_id] is None:
                 continue # smart jump for advance usage only

-            group_dest = "base" if group_id >= len(self._intervention_group) - 1 else f"source_{group_id+1}"
-            group_key = f'source_{group_id}->{group_dest}'
+            group_dest = (
+                "base"
+                if group_id >= len(self._intervention_group) - 1
+                else f"source_{group_id+1}"
+            )
+            group_key = f"source_{group_id}->{group_dest}"
             unit_locations_source = unit_locations[group_key][0]
             unit_locations_base = unit_locations[group_key][1]

             if activations_sources != None:
                 for key in keys:
                     self.activations[key] = activations_sources[key]
             else:
-                keys_with_source = [k for i, k in enumerate(keys) if unit_locations_source[i] != None]
-                get_handlers = self._intervention_getter(keys_with_source, unit_locations_source)
+                keys_with_source = [
+                    k for i, k in enumerate(keys) if unit_locations_source[i] != None
+                ]
+                get_handlers = self._intervention_getter(
+                    keys_with_source, unit_locations_source
+                )

             # call once per group. each intervention is by its own group by default
             if activations_sources is None:
@@ -1402,11 +1408,11 @@ def forward(

            self._output_validation()

-            collected_activations = []
+            collected_activations = {}
            if self.return_collect_activations:
                for key in self.sorted_keys:
                    if isinstance(self.interventions[key][0], CollectIntervention):
-                        collected_activations += self.activations[key]
+                        collected_activations[key] = self.activations[key]

        except Exception as e:
            raise e
@@ -1439,15 +1445,16 @@ def generate(
        self,
        base,
        sources: Optional[List] = None,
+        source_representations: Optional[Dict] = None,
+        intervene_on_prompt: bool = True,
        unit_locations: Optional[Dict] = None,
        timestep_selector: Optional[TIMESTEP_SELECTOR_TYPE] = None,
-        intervene_on_prompt: bool = True,
-        source_representations: Optional[Dict] = None,
        subspaces: Optional[List] = None,
        output_original_output: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[
-        ModelOutput | Tuple[ModelOutput | None, List[torch.Tensor]] | None, ModelOutput
+        Optional[ModelOutput | Tuple[Optional[ModelOutput], Dict[str, torch.Tensor]]],
+        ModelOutput,
    ]:
        """
        Intervenable generation function that serves a
@@ -1532,11 +1539,11 @@ def generate(
            # run intervened generate
            counterfactual_outputs = self.model.generate(**base, **kwargs)

-            collected_activations = []
+            collected_activations = {}
            if self.return_collect_activations:
                for key in self.sorted_keys:
                    if isinstance(self.interventions[key][0], CollectIntervention):
-                        collected_activations += self.activations[key]
+                        collected_activations[key] = self.activations[key]
        except Exception as e:
            raise e
        finally:
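The behavioral fix named in the commit message is in `_wait_for_forward_with_parallel_intervention`: the old `self.activations = activations_sources` aliased the caller's dict, so internal mutation of `self.activations` also clobbered the representations the caller passed in, whereas the new loop clones each tensor into the model's own dict. The companion `forward`/`generate` change turns `collected_activations` from a flat list into a dict keyed by intervention key, matching the new `Dict[str, torch.Tensor]` return annotation. A minimal sketch of the aliasing pitfall and the fix (the key name is illustrative, not a real pyvene key):

import torch

passed_in = {"layer.0.block_output": torch.ones(1, 4, 8)}

# Old behavior: plain assignment shares one dict, so overwriting an entry
# internally destroys the caller's copy of the representation.
activations = passed_in
activations["layer.0.block_output"] = torch.zeros(1, 4, 8)
print(passed_in["layer.0.block_output"].sum())  # tensor(0.) -- caller's repr gone

# New behavior: clone each passed-in tensor into a private dict, leaving
# the caller's dict and tensors untouched.
passed_in = {"layer.0.block_output": torch.ones(1, 4, 8)}
activations = {key: torch.clone(v) for key, v in passed_in.items()}
activations["layer.0.block_output"].zero_()
print(passed_in["layer.0.block_output"].sum())  # tensor(32.) -- preserved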
72 changes: 61 additions & 11 deletions pyvene_101.ipynb
@@ -160,15 +160,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 5,
    "id": "128be2dd-f089-4291-bfc5-7002d031b1e9",
-   "metadata": {},
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "loaded model\n"
+      "loaded GPT2 model gpt2\n",
+      "torch.Size([12, 14, 14])\n"
      ]
     }
    ],
@@ -186,10 +189,11 @@
    " \"intervention\": pv.CollectIntervention()}, model=gpt2)\n",
    "\n",
    "base = \"When John and Mary went to the shops, Mary gave the bag to\"\n",
-   "collected_attn_w = pv_gpt2(\n",
+   "(_, collected_attn_w), _ = pv_gpt2(\n",
    "  base = tokenizer(base, return_tensors=\"pt\"\n",
    "  ), unit_locations={\"base\": [h for h in range(12)]}\n",
-   ")[0][-1][0]"
+   ")\n",
+   "print(collected_attn_w[0].shape)"
   ]
  },
  {
@@ -1753,15 +1757,31 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 28,
+   "execution_count": 1,
    "id": "61cd8fc9",
-   "metadata": {},
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "loaded model\n",
+      "loaded GPT2 model gpt2\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/juice/scr/nathangk/text-intervention/pyvene/pyvene/models/intervenable_base.py:796: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+      "  cached_activations = torch.tensor(self.activations[key])\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
       "True True\n"
      ]
     }
@@ -2017,9 +2037,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": 2,
    "id": "acce6e8f",
-   "metadata": {},
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [
     {
      "name": "stderr",
@@ -2028,6 +2050,34 @@
      "You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "517de63768da4f7f8f58e5018c6f75f6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "model.safetensors: 0%| | 0.00/308M [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "378463b4bd174067b6269d64a1ddf1fe",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "generation_config.json: 0%| | 0.00/147 [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
   ],
   "source": [
@@ -2676,7 +2726,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.13"
+   "version": "3.11.7"
   },
  "toc-autonumbering": true,
  "toc-showcode": false,
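The fourth changed file (per the commit message, a generation notebook) did not load above, so the following is only a hedged sketch of how the reordered `generate` signature and dict-valued collected activations would be used; the config keys, component name, and unit locations are assumptions based on the diff, not code from this commit:

import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")

# Collect activations from one attention component during generation.
pv_gpt2 = pv.IntervenableModel({
    "layer": 10,                             # assumed config keys
    "component": "head_attention_value_output",
    "intervention": pv.CollectIntervention()}, model=gpt2)

base = tokenizer("When John and Mary went to the shops,", return_tensors="pt")

# source_representations and intervene_on_prompt now sit directly after
# sources in the signature; collected activations come back as a dict
# keyed by intervention key rather than a flat list.
(_, collected), counterfactual = pv_gpt2.generate(
    base,
    unit_locations={"base": 0},              # assumed unit location
    intervene_on_prompt=True,
    max_new_tokens=8,
)
for key, acts in collected.items():
    print(key, type(acts))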
[diff for the fourth changed file was not loaded]
