diff --git a/examples/go.py b/examples/go.py index 22dd7f4..494d211 100755 --- a/examples/go.py +++ b/examples/go.py @@ -5,6 +5,7 @@ import os import subprocess import re +import tempfile import numpy as np import matplotlib.pyplot as plt @@ -52,12 +53,14 @@ def generate_data(type, N=1000, random_state=None): def embed(data,method): - input_file = 'tapkee_input_data' - output_file = 'tapkee_output_data' - np.savetxt(input_file, data.T,delimiter=',') + input_file = tempfile.NamedTemporaryFile(prefix='tapkee_input') + output_file = tempfile.NamedTemporaryFile(prefix='tapkee_output') + np.savetxt(input_file.name, data.T,delimiter=',') tapkee_binary = 'bin/tapkee' - runner_string = '%s -i %s -o %s -m %s -k 20 --precompute --debug --verbose --transpose-output --benchmark' % (tapkee_binary, input_file, output_file, method) - print('-- To reproduce this use the following command', runner_string) + runner_string = '%s -i %s -o %s -m %s -k 20 --precompute --debug --verbose --transpose-output --benchmark' % ( + tapkee_binary, input_file.name, output_file.name, method + ) + print('-- To reproduce this use the following command `{}`'.format(runner_string)) process = subprocess.run(runner_string, shell=True, capture_output=True, text=True) print(process.stderr) if process.returncode != 0: @@ -70,26 +73,55 @@ def embed(data,method): embedded_data = np.loadtxt(output_file, delimiter=',') - os.remove(input_file) - os.remove(output_file) return embedded_data, used_method def plot(data, embedded_data, colors='m', method=None): fig = plt.figure() fig.set_facecolor('white') - ax = fig.add_subplot(121, projection='3d') - ax.scatter(data[0], data[1], data[2], c=colors, cmap=plt.cm.Spectral, s=5) + ax_original = fig.add_subplot(121, projection='3d') + scatter_original = ax_original.scatter(data[0], data[1], data[2], c=colors, cmap=plt.cm.Spectral, s=5, picker=True) plt.axis('tight') plt.axis('off') plt.title('Original', fontsize=9) - ax = fig.add_subplot(122) - ax.scatter(embedded_data[0], embedded_data[1], c=colors, cmap=plt.cm.Spectral, s=5) + ax_embedding = fig.add_subplot(122) + scatter_embedding = ax_embedding.scatter(embedded_data[0], embedded_data[1], c=colors, cmap=plt.cm.Spectral, s=5, picker=True) plt.axis('tight') plt.axis('off') plt.title('Embedding' + (' with ' + method) if method else '', fontsize=9, wrap=True) + highlighted_points = [] # To store highlighted points + + # Function to highlight points on both plots + def highlight(index): + # Reset previous highlighted points + for point in highlighted_points: + point.remove() + highlighted_points.clear() + + # Highlight the current point on both scatter plots + point1 = ax_original.scatter([data[0][index]], [data[1][index]], [data[2][index]], color='white', s=25, edgecolor='black', zorder=3) + point2 = ax_embedding.scatter([embedded_data[0][index]], [embedded_data[1][index]], color='white', s=25, edgecolor='black', zorder=3) + highlighted_points.append(point1) + highlighted_points.append(point2) + fig.canvas.draw_idle() + + # Event handler for mouse motion + def on_hover(event): + if event.inaxes == ax_original: + cont, ind = scatter_original.contains(event) + elif event.inaxes == ax_embedding: + cont, ind = scatter_embedding.contains(event) + else: + return + + if cont: + index = ind['ind'][0] + highlight(index) + + fig.canvas.mpl_connect('motion_notify_event', on_hover) + plt.show() if __name__ == "__main__":