|
1 | 1 | import os |
2 | 2 | import pickle as pkl |
3 | 3 | from pathlib import Path |
4 | | -from typing import Dict, Iterable |
| 4 | +from typing import Dict |
5 | 5 |
|
| 6 | +import matplotlib.pyplot as plt |
| 7 | +import numpy as np |
6 | 8 | import pandas as pd |
7 | | -from sklearn.metrics import precision_score |
| 9 | +from sklearn.metrics import ( |
| 10 | + average_precision_score, |
| 11 | + precision_recall_curve, |
| 12 | +) |
| 13 | + |
| 14 | +from spras.analysis.ml import create_palette |
8 | 15 |
|
9 | 16 |
|
10 | 17 | class Evaluation: |
@@ -71,29 +78,164 @@ def load_files_from_dict(self, gold_standard_dict: Dict): |
71 | 78 | # TODO: later iteration - chose between node and edge file, or allow both |
72 | 79 |
|
73 | 80 | @staticmethod |
74 | | - def precision(file_paths: Iterable[Path], node_table: pd.DataFrame, output_file: str): |
| 81 | + def edge_frequency_node_ensemble(node_table: pd.DataFrame, ensemble_files: list, dataset_file: str) -> dict: |
75 | 82 | """ |
76 | | - Takes in file paths for a specific dataset and an associated gold standard node table. |
77 | | - Calculates precision for each pathway file |
78 | | - Returns output back to output_file |
79 | | - @param file_paths: file paths of pathway reconstruction algorithm outputs |
80 | | - @param node_table: the gold standard nodes |
81 | | - @param output_file: the filename to save the precision of each pathway |
| 83 | + Generates a dictionary of node ensembles using edge frequency data from a list of ensemble files. |
| 84 | + The list of ensemble files may contain a single aggregated ensemble or per-algorithm ensembles for the dataset. |
| 85 | +
|
| 86 | + 1. Prepare a set of default nodes (from the interactome and gold standard) with frequency 0, |
| 87 | + ensuring all nodes are represented in the ensemble. |
| 88 | + - Answers "Did the algorithm(s) select the correct nodes from the entire network?" |
| 89 | + - It measures whether the algorithm(s) can distinguish relevant gold standard nodes |
| 90 | + from the full 'universe' of possible nodes present in the input network. |
| 91 | + 2. For each edge ensemble file: |
| 92 | + a. Read edges and their frequencies. |
| 93 | + b. Convert edge frequencies into node-level frequencies for Node1 and Node2. |
| 94 | + c. Merge with the default node set and group by node, taking the maximum frequency per node. |
| 95 | + 3. Store the resulting node-frequency ensemble under the corresponding ensemble source (label). |
| 96 | +
|
| 97 | + If the interactome or gold standard table is empty, a ValueError is raised. |
| 98 | +
|
| 99 | + @param node_table: DataFrame of gold standard nodes (column: NODEID) |
| 100 | + @param ensemble_files: list of file paths containing edge ensemble outputs |
| 101 | + @param dataset_file: path to the dataset file used to load the interactome |
| 102 | + @return: dictionary mapping each ensemble source to its node ensemble DataFrame |
82 | 103 | """ |
83 | | - y_true = set(node_table['NODEID']) |
84 | | - results = [] |
85 | 104 |
|
86 | | - for file in file_paths: |
87 | | - df = pd.read_table(file, sep="\t", header=0, usecols=["Node1", "Node2"]) |
88 | | - y_pred = set(df['Node1']).union(set(df['Node2'])) |
89 | | - all_nodes = y_true.union(y_pred) |
90 | | - y_true_binary = [1 if node in y_true else 0 for node in all_nodes] |
91 | | - y_pred_binary = [1 if node in y_pred else 0 for node in all_nodes] |
| 105 | + node_ensembles_dict = dict() |
| 106 | + |
| 107 | + pickle = Evaluation.from_file(dataset_file) |
| 108 | + interactome = pickle.get_interactome() |
| 109 | + |
| 110 | + if interactome.empty: |
| 111 | + raise ValueError( |
| 112 | + f"Cannot compute PR curve or generate node ensemble. Input network for dataset '{dataset_file.split('-')[0]}' is empty." |
| 113 | + ) |
| 114 | + if node_table.empty: |
| 115 | + raise ValueError( |
| 116 | + f"Cannot compute PR curve or generate node ensemble. Gold standard associated with dataset '{dataset_file.split('-')[0]}' is empty." |
| 117 | + ) |
| 118 | + |
| 119 | + # set the initial default frequencies to 0 for all interactome and gold standard nodes |
| 120 | + node1_interactome = interactome[['Interactor1']].rename(columns={'Interactor1': 'Node'}) |
| 121 | + node1_interactome['Frequency'] = 0.0 |
| 122 | + node2_interactome = interactome[['Interactor2']].rename(columns={'Interactor2': 'Node'}) |
| 123 | + node2_interactome['Frequency'] = 0.0 |
| 124 | + gs_nodes = node_table[[Evaluation.NODE_ID]].rename(columns={Evaluation.NODE_ID: 'Node'}) |
| 125 | + gs_nodes['Frequency'] = 0.0 |
| 126 | + |
| 127 | + # combine gold standard and network nodes |
| 128 | + other_nodes = pd.concat([node1_interactome, node2_interactome, gs_nodes]) |
| 129 | + |
| 130 | + for ensemble_file in ensemble_files: |
| 131 | + label = Path(ensemble_file).name.split('-')[0] |
| 132 | + ensemble_df = pd.read_table(ensemble_file, sep='\t', header=0) |
92 | 133 |
|
93 | | - # default to 0.0 if there is a divide by 0 error |
94 | | - precision = precision_score(y_true_binary, y_pred_binary, zero_division=0.0) |
| 134 | + if not ensemble_df.empty: |
| 135 | + node1 = ensemble_df[['Node1', 'Frequency']].rename(columns={'Node1': 'Node'}) |
| 136 | + node2 = ensemble_df[['Node2', 'Frequency']].rename(columns={'Node2': 'Node'}) |
| 137 | + all_nodes = pd.concat([node1, node2, other_nodes]) |
| 138 | + node_ensemble = all_nodes.groupby(['Node']).max().reset_index() |
| 139 | + else: |
| 140 | + node_ensemble = other_nodes.groupby(['Node']).max().reset_index() |
95 | 141 |
|
96 | | - results.append({"Pathway": file, "Precision": precision}) |
| 142 | + node_ensembles_dict[label] = node_ensemble |
97 | 143 |
|
98 | | - precision_df = pd.DataFrame(results) |
99 | | - precision_df.to_csv(output_file, sep="\t", index=False) |
| 144 | + return node_ensembles_dict |
| 145 | + |
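A minimal standalone sketch of the edge-to-node aggregation described above, using a tiny made-up ensemble table (the `Node1`/`Node2`/`Frequency` columns mirror the ensemble format read by `edge_frequency_node_ensemble`; the toy values, and the omission of the default interactome/gold standard nodes, are assumptions for illustration only):

# Illustrative sketch: convert edge frequencies to node frequencies by
# keeping the maximum frequency of any edge each node participates in.
import pandas as pd

ensemble_df = pd.DataFrame({
    'Node1': ['A', 'A', 'B'],
    'Node2': ['B', 'C', 'C'],
    'Frequency': [0.8, 0.5, 0.2],
})

node1 = ensemble_df[['Node1', 'Frequency']].rename(columns={'Node1': 'Node'})
node2 = ensemble_df[['Node2', 'Frequency']].rename(columns={'Node2': 'Node'})

# Expected result: A -> 0.8, B -> 0.8, C -> 0.5
node_ensemble = pd.concat([node1, node2]).groupby('Node').max().reset_index()
print(node_ensemble)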
| 146 | + @staticmethod |
| 147 | + def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.DataFrame, output_png: str, |
| 148 | + output_file: str): |
| 149 | + """ |
| 150 | + Plots precision-recall (PR) curves for a set of node ensembles evaluated against a gold standard. |
| 151 | +
|
| 152 | + Takes in a dictionary containing either algorithm-specific node ensembles or an aggregated node ensemble |
| 153 | + for a given dataset, along with the corresponding gold standard node table. Computes PR curves for |
| 154 | + each ensemble and plots all curves on a single figure. |
| 155 | +
|
| 156 | + @param node_ensembles: dict mapping each ensemble source (label) to its pre-computed node ensemble DataFrame |
| 157 | + @param node_table: gold standard nodes |
| 158 | + @param output_png: filename to save the precision and recall curves as a .png image |
| 159 | + @param output_file: filename to save the precision, recall, threshold values, average precision, and baseline |
| 160 | + average precision |
| 161 | + """ |
| 162 | + gold_standard_nodes = set(node_table[Evaluation.NODE_ID]) |
| 163 | + |
| 164 | + # make color palette per ensemble label name |
| 165 | + label_names = list(node_ensembles.keys()) |
| 166 | + color_palette = create_palette(label_names) |
| 167 | + |
| 168 | + plt.figure(figsize=(10, 7)) |
| 169 | + |
| 170 | + prc_dfs = [] |
| 171 | + metric_dfs = [] |
| 172 | + |
| 173 | + baseline = None |
| 174 | + |
| 175 | + for label, node_ensemble in node_ensembles.items(): |
| 176 | + if not node_ensemble.empty: |
| 177 | + y_true = [1 if node in gold_standard_nodes else 0 for node in node_ensemble['Node']] |
| 178 | + y_scores = node_ensemble['Frequency'].tolist() |
| 179 | + precision, recall, thresholds = precision_recall_curve(y_true, y_scores) |
| 180 | + # avg precision summarizes a precision-recall curve as the weighted mean of precisions achieved at each threshold |
| 181 | + avg_precision = average_precision_score(y_true, y_scores) |
| 182 | + |
| 183 | + # only set the baseline precision once; |
| 184 | + # it is the same for every algorithm per dataset/gold standard pair |
| 185 | + if baseline is None: |
| 186 | + baseline = np.sum(y_true) / len(y_true) |
| 187 | + plt.axhline(y=baseline, color="black", linestyle='--', label=f'Baseline: {baseline:.4f}') |
| 188 | + |
| 189 | + plt.plot(recall, precision, color=color_palette[label], marker='o', |
| 190 | + label=f'{label.capitalize()} (AP: {avg_precision:.4f})') |
| 191 | + |
| 192 | + # Dropping last elements because scikit-learn adds (1, 0) to precision/recall for plotting, not tied to real thresholds |
| 193 | + # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html#sklearn.metrics.precision_recall_curve:~:text=Returns%3A-,precision,predictions%20with%20score%20%3E%3D%20thresholds%5Bi%5D%20and%20the%20last%20element%20is%200.,-thresholds |
| 194 | + prc_data = { |
| 195 | + 'Threshold': thresholds, |
| 196 | + 'Precision': precision[:-1], |
| 197 | + 'Recall': recall[:-1], |
| 198 | + } |
| 199 | + |
| 200 | + metric_data = { |
| 201 | + 'Average_Precision': [avg_precision], |
| 202 | + } |
| 203 | + |
| 204 | + ensemble_source = label.capitalize() if label != 'ensemble' else "Aggregated" |
| 205 | + prc_data = {'Ensemble_Source': [ensemble_source] * len(thresholds), **prc_data} |
| 206 | + metric_data = {'Ensemble_Source': [ensemble_source], **metric_data} |
| 207 | + |
| 208 | + prc_df = pd.DataFrame.from_dict(prc_data) |
| 209 | + prc_dfs.append(prc_df) |
| 210 | + metric_df = pd.DataFrame.from_dict(metric_data) |
| 211 | + metric_dfs.append(metric_df) |
| 212 | + |
| 213 | + else: |
| 214 | + raise ValueError( |
| 215 | + "Cannot compute PR curve: the ensemble network is empty. " |
| 216 | + "This should not happen unless the input network for pathway reconstruction is empty." |
| 217 | + ) |
| 218 | + |
| 219 | + if 'ensemble' not in label_names: |
| 220 | + plt.title('Precision-Recall Curve Per Algorithm-Specific Ensemble') |
| 221 | + else: |
| 222 | + plt.title('Precision-Recall Curve for Aggregated Ensemble Across Algorithms') |
| 223 | + |
| 224 | + plt.xlim(0, 1) |
| 225 | + plt.ylim(0, 1) |
| 226 | + plt.xlabel('Recall') |
| 227 | + plt.ylabel('Precision') |
| 228 | + plt.legend(loc='lower left', bbox_to_anchor=(1, 0.5)) |
| 229 | + plt.grid(True) |
| 230 | + plt.savefig(output_png, bbox_inches='tight') |
| 231 | + plt.close() |
| 232 | + |
| 233 | + combined_prc_df = pd.concat(prc_dfs, ignore_index=True) |
| 234 | + combined_metrics_df = pd.concat(metric_dfs, ignore_index=True) |
| 235 | + combined_metrics_df["Baseline"] = baseline |
| 236 | + |
| 237 | + # merge dfs and NaN out metric values except for first row of each Ensemble_Source |
| 238 | + complete_df = combined_prc_df.merge(combined_metrics_df, on="Ensemble_Source", how="left") |
| 239 | + not_first_rows = complete_df.duplicated(subset="Ensemble_Source", keep='first') |
| 240 | + complete_df.loc[not_first_rows, ["Average_Precision", "Baseline"]] = None |
| 241 | + complete_df.to_csv(output_file, index=False, sep="\t") |
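A minimal standalone sketch of the per-ensemble precision-recall computation used above, with a hypothetical node ensemble and gold standard set (the toy values are assumptions; only the scikit-learn calls and the baseline definition mirror the method):

# Illustrative sketch: PR curve, average precision, and baseline for one node ensemble.
import pandas as pd
from sklearn.metrics import average_precision_score, precision_recall_curve

node_ensemble = pd.DataFrame({
    'Node': ['A', 'B', 'C', 'D'],
    'Frequency': [0.9, 0.7, 0.4, 0.0],
})
gold_standard_nodes = {'A', 'C'}  # hypothetical gold standard

y_true = [1 if node in gold_standard_nodes else 0 for node in node_ensemble['Node']]
y_scores = node_ensemble['Frequency'].tolist()

precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
avg_precision = average_precision_score(y_true, y_scores)

# Baseline precision: fraction of gold standard nodes among all candidate nodes
baseline = sum(y_true) / len(y_true)  # 0.5 in this toy example
print(avg_precision, baseline)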