Skip to content

Commit e5c5054

Browse files
memoWuTheFWasThat
authored andcommitted
allow models to be in a separate folder via models_dir argument (#129)
* models_dir argument to allow models in a separate folder * default value for models_dir to be same as before * allow environment variables and user home in models_dir
1 parent dd75299 commit e5c5054

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

src/encoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,10 @@ def decode(self, tokens):
105105
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
106106
return text
107107

108-
def get_encoder(model_name):
109-
with open(os.path.join('models', model_name, 'encoder.json'), 'r') as f:
108+
def get_encoder(model_name, models_dir):
109+
with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
110110
encoder = json.load(f)
111-
with open(os.path.join('models', model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
111+
with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
112112
bpe_data = f.read()
113113
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
114114
return Encoder(

src/generate_unconditional_samples.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def sample_model(
1616
length=None,
1717
temperature=1,
1818
top_k=0,
19+
models_dir='models',
1920
):
2021
"""
2122
Run the sample_model
@@ -35,10 +36,13 @@ def sample_model(
3536
considered for each step (token), resulting in deterministic completions,
3637
while 40 means 40 words are considered at each step. 0 (default) is a
3738
special setting meaning no restrictions. 40 generally is a good value.
39+
:models_dir : path to parent folder containing model subfolders
40+
(i.e. contains the <model_name> folder)
3841
"""
39-
enc = encoder.get_encoder(model_name)
42+
models_dir = os.path.expanduser(os.path.expandvars(models_dir))
43+
enc = encoder.get_encoder(model_name, models_dir)
4044
hparams = model.default_hparams()
41-
with open(os.path.join('models', model_name, 'hparams.json')) as f:
45+
with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
4246
hparams.override_from_dict(json.load(f))
4347

4448
if length is None:
@@ -58,7 +62,7 @@ def sample_model(
5862
)[:, 1:]
5963

6064
saver = tf.train.Saver()
61-
ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
65+
ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
6266
saver.restore(sess, ckpt)
6367

6468
generated = 0

src/interactive_conditional_samples.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def interact_model(
1616
length=None,
1717
temperature=1,
1818
top_k=0,
19+
models_dir='models',
1920
):
2021
"""
2122
Interactively run the model
@@ -34,14 +35,17 @@ def interact_model(
3435
considered for each step (token), resulting in deterministic completions,
3536
while 40 means 40 words are considered at each step. 0 (default) is a
3637
special setting meaning no restrictions. 40 generally is a good value.
38+
:models_dir : path to parent folder containing model subfolders
39+
(i.e. contains the <model_name> folder)
3740
"""
41+
models_dir = os.path.expanduser(os.path.expandvars(models_dir))
3842
if batch_size is None:
3943
batch_size = 1
4044
assert nsamples % batch_size == 0
4145

42-
enc = encoder.get_encoder(model_name)
46+
enc = encoder.get_encoder(model_name, models_dir)
4347
hparams = model.default_hparams()
44-
with open(os.path.join('models', model_name, 'hparams.json')) as f:
48+
with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
4549
hparams.override_from_dict(json.load(f))
4650

4751
if length is None:
@@ -61,7 +65,7 @@ def interact_model(
6165
)
6266

6367
saver = tf.train.Saver()
64-
ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
68+
ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
6569
saver.restore(sess, ckpt)
6670

6771
while True:

0 commit comments

Comments
 (0)