Skip to content

Commit 85a3184

Browse files
authored
Merge pull request Reed-CompBio#212 from ntalluri/param-tuning-ensembling-2.0
Param tuning: ensembling (version 2 but all the same code as version 1)
2 parents 4f6fc5e + b6a86df commit 85a3184

15 files changed

+496
-35
lines changed

Snakefile

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,11 @@ def make_final_input(wildcards):
103103
final_input.extend(expand('{out_dir}{sep}{dataset}-ml{sep}{algorithm}-jaccard-heatmap.png',out_dir=out_dir,sep=SEP,dataset=dataset_labels,algorithm=algorithms))
104104

105105
if _config.config.analysis_include_evaluation:
106-
final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-evaluation.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs,algorithm_params=algorithms_with_params))
107-
106+
final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}pr-curve-ensemble-nodes.png',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs))
107+
final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}pr-curve-ensemble-nodes.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs))
108+
if _config.config.analysis_include_evaluation_aggregate_algo:
109+
final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}pr-curve-ensemble-nodes-per-algorithm.png',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs))
110+
final_input.extend(expand('{out_dir}{sep}{dataset_gold_standard_pair}-eval{sep}pr-curve-ensemble-nodes-per-algorithm.txt',out_dir=out_dir,sep=SEP,dataset_gold_standard_pair=dataset_gold_standard_pairs))
108111
if len(final_input) == 0:
109112
# No analysis added yet, so add reconstruction output files if they exist.
110113
# (if analysis is specified, these should be implicitly run).
@@ -397,22 +400,55 @@ def get_gold_standard_pickle_file(wildcards):
397400
parts = wildcards.dataset_gold_standard_pairs.split('-')
398401
gs = parts[1]
399402
return SEP.join([out_dir, f'{gs}-merged.pickle'])
400-
403+
401404
# Returns the dataset corresponding to the gold standard pair
402405
def get_dataset_label(wildcards):
403406
parts = wildcards.dataset_gold_standard_pairs.split('-')
404407
dataset = parts[0]
405408
return dataset
406409

407-
# Run evaluation code for a specific dataset's pathway outputs against its paired gold standard
408-
rule evaluation:
410+
# Return the dataset pickle file for a specific dataset
411+
def get_dataset_pickle_file(wildcards):
412+
dataset_label = get_dataset_label(wildcards)
413+
return SEP.join([out_dir, f'{dataset_label}-merged.pickle'])
414+
415+
# Returns ensemble file for each dataset
416+
def collect_ensemble_per_dataset(wildcards):
417+
dataset_label = get_dataset_label(wildcards)
418+
return expand('{out_dir}{sep}{dataset}-ml{sep}ensemble-pathway.txt', out_dir=out_dir, sep=SEP, dataset=dataset_label)
419+
420+
# Run precision-recall curves for each ensemble pathway within a dataset evaluated against its corresponding gold standard
421+
rule evaluation_ensemble_pr_curve:
409422
input:
410423
gold_standard_file = get_gold_standard_pickle_file,
411-
pathways = expand('{out_dir}{sep}{dataset_label}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params, dataset_label=get_dataset_label),
412-
output: eval_file = SEP.join([out_dir, "{dataset_gold_standard_pairs}-evaluation.txt"])
424+
dataset_file = get_dataset_pickle_file,
425+
ensemble_file = collect_ensemble_per_dataset
426+
output:
427+
pr_curve_png = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes.png']),
428+
pr_curve_file = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes.txt']),
429+
run:
430+
node_table = Evaluation.from_file(input.gold_standard_file).node_table
431+
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_file, input.dataset_file)
432+
Evaluation.precision_recall_curve_node_ensemble(node_ensemble_dict, node_table, output.pr_curve_png, output.pr_curve_file)
433+
434+
# Returns list of algorithm specific ensemble files per dataset
435+
def collect_ensemble_per_algo_per_dataset(wildcards):
436+
dataset_label = get_dataset_label(wildcards)
437+
return expand('{out_dir}{sep}{dataset}-ml{sep}{algorithm}-ensemble-pathway.txt', out_dir=out_dir, sep=SEP, dataset=dataset_label, algorithm=algorithms)
438+
439+
# Run precision-recall curves for each algorithm's ensemble pathway within a dataset evaluated against its corresponding gold standard
440+
rule evaluation_per_algo_ensemble_pr_curve:
441+
input:
442+
gold_standard_file = get_gold_standard_pickle_file,
443+
dataset_file = get_dataset_pickle_file,
444+
ensemble_files = collect_ensemble_per_algo_per_dataset
445+
output:
446+
pr_curve_png = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes-per-algorithm.png']),
447+
pr_curve_file = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes-per-algorithm.txt']),
413448
run:
414449
node_table = Evaluation.from_file(input.gold_standard_file).node_table
415-
Evaluation.precision(input.pathways, node_table, output.eval_file)
450+
node_ensembles_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_files, input.dataset_file)
451+
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, node_table, output.pr_curve_png, output.pr_curve_file)
416452

417453
# Remove the output directory
418454
rule clean:

config/egfr.yaml

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,28 @@
11
# The length of the hash used to identify a parameter combination
22
hash_length: 7
33

4-
# If true, use Singularity instead of Docker
5-
# Singularity support is only available on Unix
6-
singularity: false
4+
# Specify the container framework used by each PRM wrapper. Valid options include:
5+
# - docker (default if not specified)
6+
# - singularity -- Also known as apptainer, useful in HPC/HTC environments where docker isn't allowed
7+
# - dsub -- experimental with limited support, used for running on Google Cloud
8+
container_framework: docker
9+
10+
# Only used if container_framework is set to singularity, this will unpack the singularity containers
11+
# to the local filesystem. This is useful when PRM containers need to run inside another container,
12+
# such as would be the case in an HTCondor/OSPool environment.
13+
# NOTE: This unpacks singularity containers to the local filesystem, which will take up space in a way
14+
# that persists after the workflow is complete. To clean up the unpacked containers, the user must
15+
# manually delete them.
16+
unpack_singularity: false
17+
18+
# Allow the user to configure which container registry containers should be pulled from
19+
# Note that this assumes container names are consistent across registries, and that the
20+
# registry being passed doesn't require authentication for pull actions
21+
container_registry:
22+
base_url: docker.io
23+
# The owner or project of the registry
24+
# For example, "reedcompbio" if the image is available as docker.io/reedcompbio/allpairs
25+
owner: reedcompbio
726

827
algorithms:
928
- name: pathlinker
@@ -13,6 +32,7 @@ algorithms:
1332
k:
1433
- 10
1534
- 20
35+
- 70
1636
- name: omicsintegrator1
1737
params:
1838
include: true
@@ -55,6 +75,16 @@ algorithms:
5575
- 3
5676
rand_restarts:
5777
- 10
78+
run2:
79+
local_search:
80+
- "No"
81+
max_path_length:
82+
- 2
83+
rand_restarts:
84+
- 10
85+
- name: allpairs
86+
params:
87+
include: true
5888
- name: domino
5989
params:
6090
include: true
@@ -63,6 +93,24 @@ algorithms:
6393
- 0.3
6494
module_threshold:
6595
- 0.05
96+
- name: mincostflow
97+
params:
98+
include: true
99+
run1:
100+
capacity:
101+
- 15
102+
flow:
103+
- 80
104+
run2:
105+
capacity:
106+
- 1
107+
flow:
108+
- 6
109+
run3:
110+
capacity:
111+
- 5
112+
flow:
113+
- 60
66114
datasets:
67115
- data_dir: input
68116
edge_files:
@@ -71,6 +119,13 @@ datasets:
71119
node_files:
72120
- tps-egfr-prizes.txt
73121
other_files: []
122+
gold_standards:
123+
- label: gs_egfr
124+
node_files:
125+
- gs-egfr.txt
126+
data_dir: input
127+
dataset_labels:
128+
- tps_egfr
74129
reconstruction_settings:
75130
locations:
76131
reconstruction_dir: output/egfr
@@ -81,6 +136,9 @@ analysis:
81136
summary:
82137
include: true
83138
ml:
84-
include: false
139+
include: true
140+
aggregate_per_algorithm: true
141+
labels: true
85142
evaluation:
86-
include: false
143+
include: true
144+
aggregate_per_algorithm: true

spras/evaluation.py

Lines changed: 164 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
import os
22
import pickle as pkl
33
from pathlib import Path
4-
from typing import Dict, Iterable
4+
from typing import Dict
55

6+
import matplotlib.pyplot as plt
7+
import numpy as np
68
import pandas as pd
7-
from sklearn.metrics import precision_score
9+
from sklearn.metrics import (
10+
average_precision_score,
11+
precision_recall_curve,
12+
)
13+
14+
from spras.analysis.ml import create_palette
815

916

1017
class Evaluation:
@@ -71,29 +78,164 @@ def load_files_from_dict(self, gold_standard_dict: Dict):
7178
# TODO: later iteration - chose between node and edge file, or allow both
7279

7380
@staticmethod
74-
def precision(file_paths: Iterable[Path], node_table: pd.DataFrame, output_file: str):
81+
def edge_frequency_node_ensemble(node_table: pd.DataFrame, ensemble_files: list, dataset_file: str) -> dict:
7582
"""
76-
Takes in file paths for a specific dataset and an associated gold standard node table.
77-
Calculates precision for each pathway file
78-
Returns output back to output_file
79-
@param file_paths: file paths of pathway reconstruction algorithm outputs
80-
@param node_table: the gold standard nodes
81-
@param output_file: the filename to save the precision of each pathway
83+
Generates a dictionary of node ensembles using edge frequency data from a list of ensemble files.
84+
A list of ensemble files can contain an aggregated ensemble or algorithm-specific ensembles per dataset
85+
86+
1. Prepare a set of default nodes (from the interactome and gold standard) with frequency 0,
87+
ensuring all nodes are represented in the ensemble.
88+
- Answers "Did the algorithm(s) select the correct nodes from the entire network?"
89+
- It measures whether the algorithm(s) can distinguish relevant gold standard nodes
90+
from the full 'universe' of possible nodes present in the input network.
91+
2. For each edge ensemble file:
92+
a. Read edges and their frequencies.
93+
b. Convert edges frequencies into node-level frequencies for Node1 and Node2.
94+
c. Merge with the default node set and group by node, taking the maximum frequency per node.
95+
3. Store the resulting node-frequency ensemble under the corresponding ensemble source (label).
96+
97+
If the interactome or gold standard table is empty, a ValueError is raised.
98+
99+
@param node_table: dataFrame of gold standard nodes (column: NODEID)
100+
@param ensemble_files: list of file paths containing edge ensemble outputs
101+
@param dataset_file: path to the dataset file used to load the interactome
102+
@return: dictionary mapping each ensemble source to its node ensemble DataFrame
82103
"""
83-
y_true = set(node_table['NODEID'])
84-
results = []
85104

86-
for file in file_paths:
87-
df = pd.read_table(file, sep="\t", header=0, usecols=["Node1", "Node2"])
88-
y_pred = set(df['Node1']).union(set(df['Node2']))
89-
all_nodes = y_true.union(y_pred)
90-
y_true_binary = [1 if node in y_true else 0 for node in all_nodes]
91-
y_pred_binary = [1 if node in y_pred else 0 for node in all_nodes]
105+
node_ensembles_dict = dict()
106+
107+
pickle = Evaluation.from_file(dataset_file)
108+
interactome = pickle.get_interactome()
109+
110+
if interactome.empty:
111+
raise ValueError(
112+
f"Cannot compute PR curve or generate node ensemble. Input network for dataset '{dataset_file.split('-')[0]}' is empty."
113+
)
114+
if node_table.empty:
115+
raise ValueError(
116+
f"Cannot compute PR curve or generate node ensemble. Gold standard associated with dataset '{dataset_file.split('-')[0]}' is empty."
117+
)
118+
119+
# set the initial default frequencies to 0 for all interactome and gold standard nodes
120+
node1_interactome = interactome[['Interactor1']].rename(columns={'Interactor1': 'Node'})
121+
node1_interactome['Frequency'] = 0.0
122+
node2_interactome = interactome[['Interactor2']].rename(columns={'Interactor2': 'Node'})
123+
node2_interactome['Frequency'] = 0.0
124+
gs_nodes = node_table[[Evaluation.NODE_ID]].rename(columns={Evaluation.NODE_ID: 'Node'})
125+
gs_nodes['Frequency'] = 0.0
126+
127+
# combine gold standard and network nodes
128+
other_nodes = pd.concat([node1_interactome, node2_interactome, gs_nodes])
129+
130+
for ensemble_file in ensemble_files:
131+
label = Path(ensemble_file).name.split('-')[0]
132+
ensemble_df = pd.read_table(ensemble_file, sep='\t', header=0)
92133

93-
# default to 0.0 if there is a divide by 0 error
94-
precision = precision_score(y_true_binary, y_pred_binary, zero_division=0.0)
134+
if not ensemble_df.empty:
135+
node1 = ensemble_df[['Node1', 'Frequency']].rename(columns={'Node1': 'Node'})
136+
node2 = ensemble_df[['Node2', 'Frequency']].rename(columns={'Node2': 'Node'})
137+
all_nodes = pd.concat([node1, node2, other_nodes])
138+
node_ensemble = all_nodes.groupby(['Node']).max().reset_index()
139+
else:
140+
node_ensemble = other_nodes.groupby(['Node']).max().reset_index()
95141

96-
results.append({"Pathway": file, "Precision": precision})
142+
node_ensembles_dict[label] = node_ensemble
97143

98-
precision_df = pd.DataFrame(results)
99-
precision_df.to_csv(output_file, sep="\t", index=False)
144+
return node_ensembles_dict
145+
146+
@staticmethod
147+
def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.DataFrame, output_png: str,
148+
output_file: str):
149+
"""
150+
Plots precision-recall (PR) curves for a set of node ensembles evaluated against a gold standard.
151+
152+
Takes in a dictionary containing either algorithm-specific node ensembles or an aggregated node ensemble
153+
for a given dataset, along with the corresponding gold standard node table. Computes PR curves for
154+
each ensemble and plots all curves on a single figure.
155+
156+
@param node_ensembles: dict of the pre-computed node_ensemble(s)
157+
@param node_table: gold standard nodes
158+
@param output_png: filename to save the precision and recall curves as a .png image
159+
@param output_file: filename to save the precision, recall, threshold values, average precision, and baseline
160+
average precision
161+
"""
162+
gold_standard_nodes = set(node_table[Evaluation.NODE_ID])
163+
164+
# make color palette per ensemble label name
165+
label_names = list(node_ensembles.keys())
166+
color_palette = create_palette(label_names)
167+
168+
plt.figure(figsize=(10, 7))
169+
170+
prc_dfs = []
171+
metric_dfs = []
172+
173+
baseline = None
174+
175+
for label, node_ensemble in node_ensembles.items():
176+
if not node_ensemble.empty:
177+
y_true = [1 if node in gold_standard_nodes else 0 for node in node_ensemble['Node']]
178+
y_scores = node_ensemble['Frequency'].tolist()
179+
precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
180+
# avg precision summarizes a precision-recall curve as the weighted mean of precisions achieved at each threshold
181+
avg_precision = average_precision_score(y_true, y_scores)
182+
183+
# only set baseline precision once
184+
# the same for every algorithm per dataset/goldstandard pair
185+
if baseline is None:
186+
baseline = np.sum(y_true) / len(y_true)
187+
plt.axhline(y=baseline, color="black", linestyle='--', label=f'Baseline: {baseline:.4f}')
188+
189+
plt.plot(recall, precision, color=color_palette[label], marker='o',
190+
label=f'{label.capitalize()} (AP: {avg_precision:.4f})')
191+
192+
# Dropping last elements because scikit-learn adds (1, 0) to precision/recall for plotting, not tied to real thresholds
193+
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html#sklearn.metrics.precision_recall_curve:~:text=Returns%3A-,precision,predictions%20with%20score%20%3E%3D%20thresholds%5Bi%5D%20and%20the%20last%20element%20is%200.,-thresholds
194+
prc_data = {
195+
'Threshold': thresholds,
196+
'Precision': precision[:-1],
197+
'Recall': recall[:-1],
198+
}
199+
200+
metric_data = {
201+
'Average_Precision': [avg_precision],
202+
}
203+
204+
ensemble_source = label.capitalize() if label != 'ensemble' else "Aggregated"
205+
prc_data = {'Ensemble_Source': [ensemble_source] * len(thresholds), **prc_data}
206+
metric_data = {'Ensemble_Source': [ensemble_source], **metric_data}
207+
208+
prc_df = pd.DataFrame.from_dict(prc_data)
209+
prc_dfs.append(prc_df)
210+
metric_df = pd.DataFrame.from_dict(metric_data)
211+
metric_dfs.append(metric_df)
212+
213+
else:
214+
raise ValueError(
215+
"Cannot compute PR curve: the ensemble network is empty."
216+
f"This should not happen unless the input network for pathway reconstruction is empty."
217+
)
218+
219+
if 'ensemble' not in label_names:
220+
plt.title('Precision-Recall Curve Per Algorithm Specific Ensemble')
221+
else:
222+
plt.title('Precision-Recall Curve for Aggregated Ensemble Across Algorithms')
223+
224+
plt.xlim(0, 1)
225+
plt.ylim(0, 1)
226+
plt.xlabel('Recall')
227+
plt.ylabel('Precision')
228+
plt.legend(loc='lower left', bbox_to_anchor=(1, 0.5))
229+
plt.grid(True)
230+
plt.savefig(output_png, bbox_inches='tight')
231+
plt.close()
232+
233+
combined_prc_df = pd.concat(prc_dfs, ignore_index=True)
234+
combined_metrics_df = pd.concat(metric_dfs, ignore_index=True)
235+
combined_metrics_df["Baseline"] = baseline
236+
237+
# merge dfs and NaN out metric values except for first row of each Ensemble_Source
238+
complete_df = combined_prc_df.merge(combined_metrics_df, on="Ensemble_Source", how="left")
239+
not_last_rows = complete_df.duplicated(subset="Ensemble_Source", keep='first')
240+
complete_df.loc[not_last_rows, ["Average_Precision", "Baseline"]] = None
241+
complete_df.to_csv(output_file, index=False, sep="\t")

0 commit comments

Comments
 (0)