Skip to content

Commit 0fc574a

Browse files
committed
Fix bottleneck, add settings
1 parent 93d2618 commit 0fc574a

File tree

2 files changed

+27
-20
lines changed

2 files changed

+27
-20
lines changed

keras-transfer/dogcat-bottleneck.py

+23-20
Original file line numberDiff line numberDiff line change
@@ -18,46 +18,47 @@
1818
config.img_width = 224
1919
config.img_height = 224
2020
config.epochs = 50
21-
config.batch_size = 40
21+
config.batch_size = 40
2222

2323
top_model_weights_path = 'bottleneck.h5'
2424
train_dir = 'dogcat-data/train'
2525
validation_dir = 'dogcat-data/validation'
2626
nb_train_samples = 1000
2727
nb_validation_samples = 1000
2828

29+
2930
def save_bottlebeck_features():
3031
if os.path.exists('bottleneck_features_train.npy') and (len(sys.argv) == 1 or sys.argv[1] != "--force"):
3132
print("Using saved features, pass --force to save new features")
3233
return
3334
datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
3435
train_generator = datagen.flow_from_directory(
35-
train_dir,
36-
target_size=(config.img_width, config.img_height),
37-
batch_size=config.batch_size,
38-
class_mode="binary")
36+
train_dir,
37+
target_size=(config.img_width, config.img_height),
38+
batch_size=config.batch_size,
39+
class_mode="binary")
3940

4041
val_generator = datagen.flow_from_directory(
41-
validation_dir,
42-
target_size=(config.img_width, config.img_height),
43-
batch_size=config.batch_size,
44-
class_mode="binary")
45-
42+
validation_dir,
43+
target_size=(config.img_width, config.img_height),
44+
batch_size=config.batch_size,
45+
class_mode="binary")
46+
4647
# build the VGG16 network
4748
model = VGG16(include_top=False, weights='imagenet')
48-
49+
4950
print("Predicting bottleneck training features")
5051
training_labels = []
5152
training_features = []
52-
for batch in range(5): #nb_train_samples // config.batch_size):
53+
for batch in range(5): # nb_train_samples // config.batch_size):
5354
data, labels = next(train_generator)
5455
training_labels.append(labels)
5556
training_features.append(model.predict(data))
5657
training_labels = np.concatenate(training_labels)
5758
training_features = np.concatenate(training_features)
5859
np.savez(open('bottleneck_features_train.npy', 'wb'),
59-
features=training_features, labels=training_labels)
60-
60+
features=training_features, labels=training_labels)
61+
6162
print("Predicting bottleneck validation features")
6263
validation_labels = []
6364
validation_features = []
@@ -71,7 +72,7 @@ def save_bottlebeck_features():
7172
validation_features = np.concatenate(validation_features)
7273
validation_data = np.concatenate(validation_data)
7374
np.savez(open('bottleneck_features_validation.npy', 'wb'),
74-
features=training_features, labels=training_labels, data=validation_data)
75+
features=validation_features, labels=validation_labels, data=validation_data)
7576

7677

7778
def train_top_model():
@@ -88,18 +89,20 @@ def train_top_model():
8889

8990
model.compile(optimizer='rmsprop',
9091
loss='binary_crossentropy', metrics=['accuracy'])
91-
92+
9293
class Images(Callback):
9394
def on_epoch_end(self, epoch, logs):
9495
base_model = VGG16(include_top=False, weights='imagenet')
9596
indices = np.random.randint(val_data.shape[0], size=36)
9697
test_data = val_data[indices]
97-
features = base_model.predict(np.array([preprocess_input(data) for data in test_data]))
98+
features = base_model.predict(
99+
np.array([preprocess_input(data) for data in test_data]))
98100
pred_data = model.predict(features)
99101
wandb.log({
100-
"examples": [
101-
wandb.Image(test_data[i], caption="cat" if pred_data[i] < 0.5 else "dog")
102-
for i, data in enumerate(test_data)]
102+
"examples": [
103+
wandb.Image(
104+
test_data[i], caption="cat" if pred_data[i] < 0.5 else "dog")
105+
for i, data in enumerate(test_data)]
103106
}, commit=False)
104107

105108
model.fit(X_train, y_train,

simpsons-challenge/wandb/settings

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[default]
2+
entity: mlclass
3+
project: simpsons-nov5
4+
base_url: https://api.wandb.ai

0 commit comments

Comments (0)