Skip to content

Commit

Permalink
start interpreting using calf
Browse files Browse the repository at this point in the history
  • Loading branch information
khairulislam committed Nov 30, 2024
1 parent f95d764 commit e545477
Show file tree
Hide file tree
Showing 29 changed files with 44,393 additions and 7,895 deletions.
10 changes: 9 additions & 1 deletion exp/exp_interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,16 @@ def run_classifier(self, dataloader, name):
results.extend(rows)
if results_df.shape[0]>0:
min_batch_index = results_df['batch_index'].max() + 1
else:
min_batch_index = 0
else:
min_batch_index = 0
# create and write header row if the file doesn't exists or it is to be overwritten
result_file = open(batch_filename, 'w', newline='')
writer = csv.writer(result_file)
writer.writerow(results[0])
else:
min_batch_index = 0

attrs = []
if min_batch_index>0:
Expand All @@ -165,7 +169,11 @@ def run_classifier(self, dataloader, name):
inputs = batch_x
# baseline must be a scaler or tuple of tensors with same dimension as input
baselines = get_baseline(inputs, mode=self.args.baseline_mode)
additional_forward_args = (padding_mask, None, None)

if self.args.model in ['CALF', 'OFA']:
additional_forward_args = None
else:
additional_forward_args = (padding_mask, None, None)

# get attributions
batch_results, batch_attr = self.evaluate(
Expand Down
3 changes: 0 additions & 3 deletions interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ def main(args):

exp.load_best_model()

# Some models don't work with gradient based explainers
# explainers = ['deep_lift', 'gradient_shap', 'integrated_gradients']

interpreter = Exp_Interpret(exp, dataloader)
interpreter.interpret(dataloader)

Expand Down
3 changes: 0 additions & 3 deletions interpret_CALF.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ def main(args):

exp.load_best_model()

# Some models don't work with gradient based explainers
# explainers = ['deep_lift', 'gradient_shap', 'integrated_gradients']

interpreter = Exp_Interpret(exp, dataloader)
interpreter.interpret(dataloader)

Expand Down
3 changes: 0 additions & 3 deletions interpret_TimeLLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ def main(args):

exp.load_best_model()

# Some models don't work with gradient based explainers
# explainers = ['deep_lift', 'gradient_shap', 'integrated_gradients']

interpreter = Exp_Interpret(exp, dataloader)
interpreter.interpret(dataloader)

Expand Down
9 changes: 9 additions & 0 deletions results/electricity_CALF/1/augmented_occlusion.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
metric,area,comp,suff
mae,0.05,7.97165,18.21702
mae,0.075,9.332975,17.384995
mae,0.1,10.038847,16.889594
mae,0.15,11.445272,15.737271
mse,0.05,6.148502,22.12033
mse,0.075,7.665615,20.257282
mse,0.1,8.547635,19.182259
mse,0.15,10.432541,16.808907
Loading

0 comments on commit e545477

Please sign in to comment.