|
| 1 | +import networkx as nx |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +import numpy as np |
| 4 | +from sentence_transformers import SentenceTransformer |
| 5 | +from sklearn.metrics.pairwise import cosine_similarity |
| 6 | +import pandas as pd |
| 7 | +from matplotlib import cm |
| 8 | +from matplotlib.lines import Line2D |
| 9 | + |
# Load the prompt/response table (a tab-separated file, despite the CSV reader)
path = "src/agents/utils/synthetic_conversations/"
df = pd.read_csv(path + "prompts_importance.tsv", delimiter="\t")  # Replace with your actual file name
print(df.columns)

# Cast responses to text, then drop every "$$" and every backslash
# (presumably maths markup that would clutter the hover text -- confirm).
df["response"] = (
    df["response"]
    .astype(str)
    .str.replace("$$", "", regex=False)
    .str.replace("\\", "", regex=False)
)

# Pull the four columns used below out as plain Python lists
paragraphs, messages, prompts, missing_prompts = (
    df[column].tolist()
    for column in ("response", "message", "prompt", "prompt_missing")
)
print(f"Loaded {len(paragraphs)} paragraphs, {len(messages)} messages, {len(prompts)} prompts, and {len(missing_prompts)} missing prompts")
| 21 | + |
# Sentence-embedding model used to vectorise the responses
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Encode every response paragraph, asking for NumPy output
embeddings = model.encode(
    paragraphs,
    convert_to_numpy=True,
)

# Pairwise cosine similarity between all paragraph embeddings
similarity_matrix = cosine_similarity(embeddings)
| 30 | + |
# Create graph
G = nx.Graph()

# Map each distinct message to a viridis colour. dict.fromkeys keeps
# first-appearance order, so the message -> colour assignment (and any legend
# built from it) is identical on every run; iterating a set of strings is not,
# because string hashing is randomised per process.
unique_messages = list(dict.fromkeys(messages))
message_to_color = {
    msg: cm.viridis(rank / len(unique_messages))
    for rank, msg in enumerate(unique_messages)
}

# One node per paragraph, carrying everything the hover/click handlers display
node_colors = [message_to_color[msg] for msg in messages]
for i, paragraph in enumerate(paragraphs):
    G.add_node(
        i,
        text=paragraph,
        message=messages[i],
        color=node_colors[i],
        prompt=prompts[i],
        missing_prompt=missing_prompts[i],
    )

# Connect every pair of paragraphs whose similarity exceeds the threshold.
# Only the strict upper triangle (i < j) is inspected, so each unordered pair
# is considered once and self-similarity is skipped; the selection is done in
# NumPy instead of a Python double loop.
threshold = 0.5
rows, cols = np.triu_indices(len(paragraphs), k=1)
pair_weights = similarity_matrix[rows, cols]
keep = pair_weights > threshold
# .tolist() yields plain ints/floats, so edge endpoints match the integer node ids
G.add_weighted_edges_from(
    zip(rows[keep].tolist(), cols[keep].tolist(), pair_weights[keep].tolist())
)
| 51 | + |
# Lay out and render the graph on a single wide axes
fig, ax = plt.subplots(figsize=(12, 6))
pos = nx.spring_layout(G)  # Node positions keyed by node id
nx.draw(
    G,
    pos,
    ax=ax,
    node_color=node_colors,
    edge_color="white",
    with_labels=False,
)

# Text box placed in axes coordinates below the plot; it starts empty and
# hidden, and is filled in by the hover handler.
hover_text = ax.text(
    0.5,
    -0.1,
    "",
    transform=ax.transAxes,
    ha="center",
    va="top",
    fontsize=10,
    wrap=True,
)
hover_text.set_visible(False)
| 60 | + |
| 61 | +# Function to update hover text and wrap it |
def update_hover_text(ind):
    """Show the message and response of a node in the hover text box.

    Args:
        ind: Mapping whose ``"ind"`` entry is a sequence of node indices;
            only the first index is used.
    """
    node_idx = ind["ind"][0]
    node = G.nodes[node_idx]
    hover_text.set_position((0.5, -0.05))  # Position the text box at the bottom
    # An f-string accepts non-string messages (e.g. NaN read from an empty
    # cell), which "+" concatenation would reject with a TypeError.
    hover_text.set_text(f"Message: {node['message']}\nResponse: {node['text']}")
    hover_text.set_visible(True)
    # Redraw the figure that owns the graph rather than whichever figure
    # pyplot currently considers active.
    fig.canvas.draw_idle()
| 69 | + |
| 70 | +# Mouse hover event |
def hover(event):
    """Show the hover text for the node under the cursor, or hide it.

    The first node found within 0.05 data units of the cursor is passed to
    ``update_hover_text``. When the cursor is not near any node, or is
    outside ``ax``, the text box is hidden.

    Args:
        event: Matplotlib mouse event; ``inaxes``, ``xdata`` and ``ydata``
            are read.
    """
    if event.inaxes == ax:
        for i, (x, y) in pos.items():
            if np.linalg.norm([x - event.xdata, y - event.ydata]) < 0.05:  # Adjust hover sensitivity
                update_hover_text({"ind": [i]})
                return
    # Nothing is hovered. Redraw only if the box was actually showing, so
    # ordinary mouse movement over empty space does not re-render the figure.
    if hover_text.get_visible():
        hover_text.set_visible(False)
        fig.canvas.draw_idle()
| 79 | + |
| 80 | +# Mouse click event |
def on_click(event):
    """Print the stored details of each node within the click radius.

    Clicks that land outside ``ax`` are ignored.

    Args:
        event: Matplotlib mouse event; ``inaxes``, ``xdata`` and ``ydata``
            are read.
    """
    # NOTE(review): block nesting was reconstructed from a flattened view of
    # this file -- confirm the separator line is meant to follow each match.
    if not event.inaxes == ax:
        return
    click_radius = 0.05  # Click sensitivity
    for node_idx, (x, y) in pos.items():
        offset = [x - event.xdata, y - event.ydata]
        if np.linalg.norm(offset) < click_radius:
            attrs = G.nodes[node_idx]
            message = attrs["message"]
            prompt = attrs["prompt"]
            missing_prompt = attrs["missing_prompt"]
            text = attrs["text"]
            print(f"Clicked node {node_idx} \n-- Message: {message}\n-- Response: {text}\n-- Prompt: {prompt}\n-- Missing Prompt: {missing_prompt}")
            print("====================")
| 92 | + |
# Legend: one filled-circle marker per message, in that message's colour
legend_handles = [
    Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=msg)
    for msg, color in message_to_color.items()
]

ax.legend(
    handles=legend_handles,
    title="Messages",
    loc='lower center',
    bbox_to_anchor=(0.3, 0.0),
    borderaxespad=1,
    ncol=1,
    fontsize=10,
    columnspacing=1,
    frameon=False,
)

# Register the mouse handlers on the figure's canvas
canvas = fig.canvas
canvas.mpl_connect("motion_notify_event", hover)
canvas.mpl_connect("button_press_event", on_click)

# Leave room under the axes for the bottom bar, then open the window
plt.subplots_adjust(bottom=0.2)
plt.show()
0 commit comments