Skip to content

Commit 3f273b6

Browse files
committed
similarity graph - responses of testbench
1 parent 0cd1d3a commit 3f273b6

File tree

1 file changed

+106
-0
lines changed

1 file changed

+106
-0
lines changed
+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import networkx as nx
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
from sentence_transformers import SentenceTransformer
5+
from sklearn.metrics.pairwise import cosine_similarity
6+
import pandas as pd
7+
from matplotlib import cm
8+
from matplotlib.lines import Line2D
9+
10+
# Load the tab-separated input table; the columns used below are
# "response", "message", "prompt" and "prompt_missing".
path = "src/agents/utils/synthetic_conversations/"
df = pd.read_csv(f"{path}prompts_importance.tsv", sep="\t")  # Replace with your actual file name
print(df.columns)

# Strip "$$" markers and backslashes from the responses (literal matches,
# not regular expressions).
cleaned_responses = df["response"].astype(str)
for unwanted in ("$$", "\\"):
    cleaned_responses = cleaned_responses.str.replace(unwanted, "", regex=False)
df["response"] = cleaned_responses

# Pull the columns out as plain Python lists, all indexed by row position.
paragraphs, messages, prompts, missing_prompts = (
    df[column].tolist()
    for column in ("response", "message", "prompt", "prompt_missing")
)
print(f"Loaded {len(paragraphs)} paragraphs, {len(messages)} messages, {len(prompts)} prompts, and {len(missing_prompts)} missing prompts")
21+
22+
# Load embedding model
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Embed every response, then compare all pairs: similarity_matrix[i, j] is
# the cosine similarity between the embeddings of paragraphs i and j.
embeddings = model.encode(paragraphs, convert_to_numpy=True)
similarity_matrix = cosine_similarity(embeddings)

# Create graph
G = nx.Graph()

# Map each distinct message to a viridis colour. dict.fromkeys keeps the
# first-seen order, so the colour assignment (and the legend order derived
# from this dict) is the same on every run; iterating a set of strings is
# not stable across interpreter runs because of string hash randomisation.
unique_messages = list(dict.fromkeys(messages))
message_to_color = {
    msg: cm.viridis(rank / len(unique_messages))
    for rank, msg in enumerate(unique_messages)
}

# Add one node per response, coloured by the message it belongs to.
# node_colors is kept in node order so it can be passed to the drawing call.
node_colors = []
for i, paragraph in enumerate(paragraphs):
    msg = messages[i]
    color = message_to_color[msg]
    G.add_node(
        i,
        text=paragraph,
        message=msg,
        color=color,
        prompt=prompts[i],
        missing_prompt=missing_prompts[i],
    )
    node_colors.append(color)

# Connect every pair of responses whose similarity exceeds the threshold.
# Only the upper triangle (j > i) is visited: the graph is undirected and
# self-loops are not wanted.
threshold = 0.5
n_paragraphs = len(paragraphs)
for i in range(n_paragraphs):
    for j in range(i + 1, n_paragraphs):
        similarity = similarity_matrix[i, j]
        if similarity > threshold:
            G.add_edge(i, j, weight=similarity)
51+
52+
# Lay out and render the graph on a single wide axes.
fig, ax = plt.subplots(figsize=(12, 6))
pos = nx.spring_layout(G)  # Maps node index -> (x, y) position
nx.draw(
    G,
    pos=pos,
    ax=ax,
    with_labels=False,
    node_color=node_colors,
    edge_color="white",
)

# Text box used by the hover handler; it is placed in axes coordinates just
# below the plot and starts out hidden.
hover_text = ax.text(
    0.5,
    -0.1,
    "",
    transform=ax.transAxes,
    ha="center",
    va="top",
    fontsize=10,
    wrap=True,
    visible=False,
)
60+
61+
# Show the hovered node's message and response in the hover text box
62+
def update_hover_text(ind):
    """Show the message and response of a node in the hover text box.

    Args:
        ind: Mapping whose ``"ind"`` entry is a sequence of node indices;
            only the first index is used.
    """
    node = G.nodes[ind["ind"][0]]
    message, response = node["message"], node["text"]
    hover_text.set_position((0.5, -0.05))  # Position the text box at the bottom
    # An f-string formats any value, whereas ``+`` concatenation raises
    # TypeError when an attribute is not a str (e.g. a missing value loaded
    # as a float).
    hover_text.set_text(f"Message: {message}\nResponse: {response}")
    hover_text.set_visible(True)
    # Redraw this figure explicitly rather than whichever figure is current.
    fig.canvas.draw_idle()
69+
70+
# Mouse hover event
71+
def hover(event):
    """Show the text box for the node under the cursor, or hide it.

    The first node whose position lies within 0.05 data units of the cursor
    is displayed. When the cursor is over no node, or outside the axes, the
    text box is hidden.

    Args:
        event: Mouse event; its ``inaxes``, ``xdata`` and ``ydata``
            attributes are read.
    """
    if event.inaxes == ax:
        for node_idx, (x, y) in pos.items():
            if np.linalg.norm([x - event.xdata, y - event.ydata]) < 0.05:  # Adjust hover sensitivity
                update_hover_text({"ind": [node_idx]})
                return
    # Motion events fire continuously; only redraw when the text box actually
    # changes from visible to hidden.
    if hover_text.get_visible():
        hover_text.set_visible(False)
        fig.canvas.draw_idle()
79+
80+
# Mouse click event
81+
def on_click(event):
    """Print the stored details of every node within click range.

    A node is in range when its position lies within 0.05 data units of the
    click. Clicks outside the axes are ignored.

    Args:
        event: Mouse event; its ``inaxes``, ``xdata`` and ``ydata``
            attributes are read.
    """
    if event.inaxes != ax:
        return
    for node_idx, (x, y) in pos.items():
        distance = np.linalg.norm([x - event.xdata, y - event.ydata])
        if distance < 0.05:  # Click sensitivity
            attrs = G.nodes[node_idx]
            message, text, prompt, missing_prompt = (
                attrs[key]
                for key in ("message", "text", "prompt", "missing_prompt")
            )
            print(f"Clicked node {node_idx} \n-- Message: {message}\n-- Response: {text}\n-- Prompt: {prompt}\n-- Missing Prompt: {missing_prompt}")
            print("====================")
92+
93+
# Build one round legend marker per message, filled with that message's colour.
legend_handles = [
    Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=msg)
    for msg, color in message_to_color.items()
]

ax.legend(
    handles=legend_handles,
    title="Messages",
    loc='lower center',
    bbox_to_anchor=(0.3, 0.0),
    borderaxespad=1,
    ncol=1,
    fontsize=10,
    columnspacing=1,
    frameon=False,
)

# Register the mouse handlers on the figure canvas.
for event_name, handler in (
    ("motion_notify_event", hover),
    ("button_press_event", on_click),
):
    fig.canvas.mpl_connect(event_name, handler)

plt.subplots_adjust(bottom=0.2)  # Add space for the bottom bar
plt.show()

0 commit comments

Comments
 (0)