Commit fbaf2d3

Move build_website to autoencoder dir
1 parent 3a66fd9 commit fbaf2d3

5 files changed: +35 -20 lines changed

autoencoder/feature-browser/build_website.py renamed to autoencoder/build_website.py

Lines changed: 4 additions & 5 deletions
@@ -20,17 +20,16 @@
 """

 import logging
+from pathlib import Path
 from tqdm.auto import trange
 from dataclasses import dataclass
 import torch
 from tensordict import TensorDict
 import os
-import sys
 from math import ceil
-from main_page import create_main_html_page
-from subpages import write_alive_feature_page, write_dead_feature_page, write_ultralow_density_feature_page
+from feature_browser.main_page import create_main_html_page
+from feature_browser.subpages import write_alive_feature_page, write_dead_feature_page, write_ultralow_density_feature_page

-sys.path.insert(1, '../')
 from resource_loader import ResourceLoader
 from utils.plotting_utils import make_activations_histogram, make_logits_histogram

@@ -399,7 +398,7 @@ def write_main_page(self):
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 # -----------------------------------------------------------------------------
 config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
-configurator = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'configurator.py')
+configurator = Path(__file__).parent / 'configurator.py'
 exec(open(configurator).read()) # overrides from command line or config file
 config = {k: globals()[k] for k in config_keys} # will be useful for logging
 # -----------------------------------------------------------------------------
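
Both hunks follow from the relocation: build_website.py now sits directly inside autoencoder/, so the sys.path hack is gone, feature_browser is imported as a package, and configurator.py is a sibling of the script. The sketch below only illustrates why the old and new expressions resolve to the same configurator.py; the /repo prefix and literal paths are placeholders for illustration, not taken from the diff.

from pathlib import Path
import os

# Where the script lived before the rename, and where it lives now (illustrative literals).
old_script = '/repo/autoencoder/feature-browser/build_website.py'
new_script = '/repo/autoencoder/build_website.py'

# Old resolution: climb two directories from the old location, then join configurator.py.
old_cfg = os.path.join(os.path.dirname(os.path.dirname(old_script)), 'configurator.py')

# New resolution: configurator.py is now a sibling of the script itself.
new_cfg = Path(new_script).parent / 'configurator.py'

print(old_cfg)       # /repo/autoencoder/configurator.py
print(str(new_cfg))  # /repo/autoencoder/configurator.py  (same target, simpler expression)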

autoencoder/resource_loader.py

Lines changed: 5 additions & 15 deletions
@@ -8,8 +8,7 @@
 # Extend the Python path to include the transformer subdirectory for GPT class import
 base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.insert(0, os.path.join(base_dir, 'transformer'))
-from model import GPTConfig
-from hooked_model import HookedGPT
+from model import GPTConfig, HookedGPT


 class ResourceLoader:
@@ -163,17 +162,8 @@ def select_resampling_data(self, size: int):
         return resampling_data

     def load_tokenizer(self):
-        load_meta = False
         meta_path = os.path.join(self.base_dir, 'transformer', 'data', self.dataset, 'meta.pkl')
-        load_meta = os.path.exists(meta_path)
-        if load_meta:
-            print(f"Loading meta from {meta_path}...")
-            with open(meta_path, 'rb') as f:
-                meta = pickle.load(f)
-            # TODO want to make this more general to arbitrary encoder/decoder schemes
-            stoi, itos = meta['stoi'], meta['itos']
-            encode = lambda s: [stoi[c] for c in s]
-            decode = lambda l: ''.join([itos[i] for i in l])
-        else:
-            raise DeprecationWarning('must load from dataset dir')
-        return encode, decode
+        with open(meta_path, 'rb') as f:
+            meta = pickle.load(f)
+
+        return meta['encode'], meta['decode']
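
The simplified load_tokenizer makes meta.pkl itself carry callable 'encode' and 'decode' entries instead of raw stoi/itos tables that get rebuilt into lambdas on every load. Below is a hedged sketch of a data-preparation step that would satisfy that contract; CharCodec, input.txt, and the character-level scheme are illustrative assumptions, not part of this commit. Since the standard pickle module cannot serialize lambdas, the stored callables have to come from something picklable, for example bound methods of a class that is importable again wherever load_tokenizer runs.

import pickle

class CharCodec:
    """Picklable character-level encoder/decoder (illustrative)."""
    def __init__(self, chars):
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}

    def encode(self, s):
        return [self.stoi[c] for c in s]

    def decode(self, ids):
        return ''.join(self.itos[i] for i in ids)

if __name__ == '__main__':
    chars = sorted(set(open('input.txt').read()))  # hypothetical corpus file
    codec = CharCodec(chars)
    with open('meta.pkl', 'wb') as f:
        # bound methods of a picklable instance survive a pickle round trip,
        # provided CharCodec is importable again when meta.pkl is loaded
        pickle.dump({'encode': codec.encode, 'decode': codec.decode}, f)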

transformer/model.py

Lines changed: 26 additions & 0 deletions
@@ -166,6 +166,32 @@ def _init_weights(self, module):
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

+    # TODO: Remove this and only use the forward method below
+    def forward(self, idx, targets=None):
+        device = idx.device
+        b, t = idx.size()
+        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
+
+        # forward the GPT model itself
+        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
+        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
+        x = self.transformer.drop(tok_emb + pos_emb)
+        for block in self.transformer.h:
+            x = block(x)
+        x = self.transformer.ln_f(x)
+
+        if targets is not None:
+            # if we are given some desired targets also calculate the loss
+            logits = self.lm_head(x)
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+        else:
+            # inference-time mini-optimization: only forward the lm_head on the very last position
+            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
+            loss = None
+
+        return logits, loss
+
     def configure_optimizers(self, weight_decay, learning_rate, betas):
         # start with all of the candidate parameters
         param_dict = {pn: p for pn, p in self.named_parameters()}
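
One detail worth calling out in the restored method is the inference branch's indexing: lm_head is applied to x[:, [-1], :] rather than x[:, -1, :] so the time dimension survives and the logits keep the (batch, time, vocab) layout. A minimal stand-alone illustration (tensor sizes are arbitrary, chosen only for this sketch):

import torch

# Plain -1 drops the time dimension; wrapping it in a list keeps a length-1 time axis,
# so downstream code sees logits of shape (batch, 1, vocab) in the no-targets branch.
x = torch.randn(8, 32, 128)          # (batch, time, n_embd), illustrative sizes

last_squeezed = x[:, -1, :]          # shape (8, 128)   - time dimension gone
last_kept     = x[:, [-1], :]        # shape (8, 1, 128) - time dimension preserved

print(last_squeezed.shape, last_kept.shape)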
