-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathplot.py
62 lines (49 loc) · 1.86 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import itertools
from pathlib import Path

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd

from wgtda import flatten_gene_list
# Load the interaction table and expand every "vertices_set" entry into its
# genes using flatten_gene_list (imported from wgtda).
interactions = pd.read_csv("output/interactions.csv").assign(
    gene_set=lambda frame: frame["vertices_set"].apply(flatten_gene_list)
)
# Keep only rows whose gene_set holds at least two genes; rows with zero or
# one gene cannot contribute a gene pair to the networks.
df = interactions.loc[interactions["gene_set"].apply(len).gt(1)]
def _save_graph(G, title, output_dir="output/network_graphs"):
    """Draw ``G`` with edge widths equal to edge weights and save it as a PNG.

    The image is written to ``<output_dir>/<title>.png``. ``output_dir`` is
    created first if it does not exist, because ``savefig`` does not create
    missing directories. The figure is always closed, even if drawing or
    saving raises, so repeated calls do not accumulate open figures.

    Args:
        G: Graph whose edges all carry a ``"weight"`` attribute.
        title: Plot title, also used as the file name stem.
        output_dir: Directory the PNG is written to.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    fig = plt.figure(figsize=(12, 12))
    try:
        pos = nx.spring_layout(G)
        # Pass the edge list explicitly so each width is guaranteed to belong
        # to the edge at the same position.
        edgelist = list(G.edges())
        edge_widths = [G[u][v]["weight"] for u, v in edgelist]
        nx.draw(
            G,
            pos,
            edgelist=edgelist,
            with_labels=True,
            node_size=50,
            font_size=10,
            font_weight="bold",
            width=edge_widths,
            edge_color="gray",
        )
        plt.title(title)
        plt.savefig(out_dir / f"{title}.png")
    finally:
        plt.close(fig)


def create_and_save_network_graph(
    df,
    betti_number,
    output_dir="output/network_graphs",
    initial_weight=1,
    weight_increment=0.5,
):
    """Build and save the gene co-occurrence network for one Betti number.

    Each row of ``df`` whose ``betti_number`` column equals ``betti_number``
    links every unordered pair of genes in its ``gene_set``. The first time a
    pair is seen, its edge gets ``initial_weight``; each further co-occurrence
    adds ``weight_increment``. The graph is drawn with edge width equal to
    weight and saved as
    ``<output_dir>/Gene Interaction Network (Betti <betti_number>).png``.

    Args:
        df: DataFrame with ``betti_number`` and ``gene_set`` columns. Each
            ``gene_set`` entry is an iterable of gene identifiers.
        betti_number: Value of the ``betti_number`` column to keep.
        output_dir: Directory the PNG is written to; created if missing.
        initial_weight: Weight given to an edge on its first occurrence.
        weight_increment: Amount added to an existing edge per additional
            co-occurrence.

    Returns:
        The ``networkx.Graph`` that was drawn.
    """
    # Keep only interactions of the requested Betti number.
    filtered_df = df[df["betti_number"] == betti_number]

    G = nx.Graph()
    for genes in filtered_df["gene_set"]:
        # Every unordered pair of genes in the same set becomes an edge or
        # strengthens an existing one. combinations() accepts any iterable,
        # so this also works when gene_set holds sets rather than lists.
        for gene1, gene2 in itertools.combinations(genes, 2):
            if G.has_edge(gene1, gene2):
                # NOTE(review): by default a repeat co-occurrence adds less
                # (0.5) than the first one (1). Confirm this is intended.
                G[gene1][gene2]["weight"] += weight_increment
            else:
                G.add_edge(gene1, gene2, weight=initial_weight)

    _save_graph(
        G,
        "Gene Interaction Network (Betti {})".format(betti_number),
        output_dir,
    )
    return G
# Render and save one network for each Betti number of interest.
for betti in (1, 2):
    create_and_save_network_graph(df, betti)