Skip to content

Commit dcd0ed9

Browse files
committed
Fix smodel initialization
Fix kwargs use for mpl in contributivity S-model initialization could fail if the confusion matrix has not the right shape, which can be the case if some labels are not included in the dataset of a partner. btw I noticed that smodel can only work with datasets with 10 labels, that's only cifar and mnist. I opened an issue about that Signed-off-by: arthurPignet <[email protected]>
1 parent 2b476d5 commit dcd0ed9

File tree

4 files changed

+5
-9
lines changed

4 files changed

+5
-9
lines changed

.coverage

0 Bytes
Binary file not shown.

mplc/contributivity.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,12 +1124,7 @@ def compute_relative_perf_matrix(self):
11241124
def statistcal_distances_via_smodel(self):
11251125

11261126
start = timer()
1127-
try:
1128-
mpl_pretrain = self.scenario.mpl.pretrain_epochs
1129-
except AttributeError:
1130-
mpl_pretrain = 2
1131-
1132-
mpl = fast_mpl.FastFedAvgSmodel(self.scenario, mpl_pretrain)
1127+
mpl = fast_mpl.FastFedAvgSmodel(self.scenario, **self.scenario.mpl_kwargs)
11331128
mpl.fit()
11341129
cross_entropy = tf.keras.metrics.CategoricalCrossentropy()
11351130
self.contributivity_scores = {'Kullback Leiber divergence': [0 for _ in mpl.partners_list],

mplc/multi_partner_learning/basic_mpl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,8 @@ def fit(self):
544544
for p in self.partners_list:
545545
confusion = confusion_matrix(np.argmax(p.y_train, axis=1),
546546
np.argmax(pretrain_model.predict(p.x_train), axis=1),
547-
normalize='pred')
547+
normalize='pred',
548+
labels=list(range(10)))
548549
p.noise_layer_weights = [np.log(confusion.T + 1e-8)]
549550
self.model_weights[:-1] = self.pretrain_mpl.model_weights[:-1]
550551
else:

mplc/multi_partner_learning/fast_mpl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def fit_minibatch(model, partners_minibatches, partners_optimizers, partners_wei
376376
for p in self.partners_list:
377377
confusion = confusion_matrix(np.argmax(p.y_train, axis=1),
378378
np.argmax(self.model.predict(p.x_train), axis=1),
379-
normalize='pred')
379+
normalize='pred', labels=list(range(10)))
380380
p.noise_layer_weights = [np.log(confusion.T + 1e-8)]
381381
else:
382382
for p in self.partners_list:
@@ -549,7 +549,7 @@ def fit_epoch(model, train_dataset, partners_grads, smodel_list, global_grad, ag
549549
for p in self.partners_list:
550550
confusion = confusion_matrix(np.argmax(p.y_train, axis=1),
551551
np.argmax(self.model.predict(p.x_train), axis=1),
552-
normalize='pred')
552+
normalize='pred', labels=list(range(10)))
553553
p.noise_layer_weights = [np.log(confusion.T + 1e-8)]
554554
else:
555555
for p in self.partners_list:

0 commit comments

Comments
 (0)