Skip to content

Commit

Permalink
Add point highlight
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed May 19, 2024
1 parent 4e6ab50 commit 0907de6
Showing 1 changed file with 43 additions and 11 deletions.
54 changes: 43 additions & 11 deletions examples/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import subprocess
import re
import tempfile

import numpy as np
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -52,12 +53,14 @@ def generate_data(type, N=1000, random_state=None):


def embed(data,method):
input_file = 'tapkee_input_data'
output_file = 'tapkee_output_data'
np.savetxt(input_file, data.T,delimiter=',')
input_file = tempfile.NamedTemporaryFile(prefix='tapkee_input')
output_file = tempfile.NamedTemporaryFile(prefix='tapkee_output')
np.savetxt(input_file.name, data.T,delimiter=',')
tapkee_binary = 'bin/tapkee'
runner_string = '%s -i %s -o %s -m %s -k 20 --precompute --debug --verbose --transpose-output --benchmark' % (tapkee_binary, input_file, output_file, method)
print('-- To reproduce this use the following command', runner_string)
runner_string = '%s -i %s -o %s -m %s -k 20 --precompute --debug --verbose --transpose-output --benchmark' % (
tapkee_binary, input_file.name, output_file.name, method
)
print('-- To reproduce this use the following command `{}`'.format(runner_string))
process = subprocess.run(runner_string, shell=True, capture_output=True, text=True)
print(process.stderr)
if process.returncode != 0:
Expand All @@ -70,26 +73,55 @@ def embed(data,method):


embedded_data = np.loadtxt(output_file, delimiter=',')
os.remove(input_file)
os.remove(output_file)
return embedded_data, used_method

def plot(data, embedded_data, colors='m', method=None):
fig = plt.figure()
fig.set_facecolor('white')

ax = fig.add_subplot(121, projection='3d')
ax.scatter(data[0], data[1], data[2], c=colors, cmap=plt.cm.Spectral, s=5)
ax_original = fig.add_subplot(121, projection='3d')
scatter_original = ax_original.scatter(data[0], data[1], data[2], c=colors, cmap=plt.cm.Spectral, s=5, picker=True)
plt.axis('tight')
plt.axis('off')
plt.title('Original', fontsize=9)

ax = fig.add_subplot(122)
ax.scatter(embedded_data[0], embedded_data[1], c=colors, cmap=plt.cm.Spectral, s=5)
ax_embedding = fig.add_subplot(122)
scatter_embedding = ax_embedding.scatter(embedded_data[0], embedded_data[1], c=colors, cmap=plt.cm.Spectral, s=5, picker=True)
plt.axis('tight')
plt.axis('off')
plt.title('Embedding' + (' with ' + method) if method else '', fontsize=9, wrap=True)

highlighted_points = [] # To store highlighted points

# Function to highlight points on both plots
def highlight(index):
# Reset previous highlighted points
for point in highlighted_points:
point.remove()
highlighted_points.clear()

# Highlight the current point on both scatter plots
point1 = ax_original.scatter([data[0][index]], [data[1][index]], [data[2][index]], color='white', s=25, edgecolor='black', zorder=3)
point2 = ax_embedding.scatter([embedded_data[0][index]], [embedded_data[1][index]], color='white', s=25, edgecolor='black', zorder=3)
highlighted_points.append(point1)
highlighted_points.append(point2)
fig.canvas.draw_idle()

# Event handler for mouse motion
def on_hover(event):
if event.inaxes == ax_original:
cont, ind = scatter_original.contains(event)
elif event.inaxes == ax_embedding:
cont, ind = scatter_embedding.contains(event)
else:
return

if cont:
index = ind['ind'][0]
highlight(index)

fig.canvas.mpl_connect('motion_notify_event', on_hover)

plt.show()

if __name__ == "__main__":
Expand Down

0 comments on commit 0907de6

Please sign in to comment.