Skip to content

Commit 4f0cd50

Browse files
authored
Embedding Projector: reduce knn computation (#6321)
## Motivation for features / changes for t-SNE max number of knn neighbors is 3 * max_perplexity (3 * 100 = 300). For umap max number of neighbors is 100. let's just compute 300 knn neighbors so we can slice the result each time when we tweak perplexity for tSNE. ## Technical description of changes When computing KNN, use an upper threshold of 300 so any future computation with same number of points will be cached. ## Screenshots of UI changes N/A ## Detailed steps to verify changes work correctly (as executed by you) 1. Launch local app 2. Open t-SNE 3. adjust to Perplexity > 25 4. verify knn isn't computed a second time ## Alternate designs / implementations considered
1 parent d1c504f commit 4f0cd50

File tree

1 file changed

+15
-3
lines changed
  • tensorboard/plugins/projector/vz_projector

1 file changed

+15
-3
lines changed

tensorboard/plugins/projector/vz_projector/data.ts

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ export const PCA_SAMPLE_DIM = 200;
9797
const NUM_PCA_COMPONENTS = 10;
9898
/** Id of message box used for umap optimization progress bar. */
9999
const UMAP_MSG_ID = 'umap-optimization';
100+
/** Minimum KNN neighbors threshold */
101+
const MIN_NUM_KNN_NEIGHBORS = 300;
100102
/**
101103
* Reserved metadata attributes used for sequence information
102104
* NOTE: Use "__seq_next__" as "__next__" is deprecated.
@@ -474,16 +476,26 @@ export class DataSet {
474476
);
475477
} else {
476478
const knnGpuEnabled = (await util.hasWebGLSupport()) && !IS_FIREFOX;
479+
const numKnnNeighborsToCompute = Math.max(
480+
nNeighbors,
481+
MIN_NUM_KNN_NEIGHBORS
482+
);
477483
const result = await (knnGpuEnabled
478-
? knn.findKNNGPUCosDistNorm(data, nNeighbors, (d) => d.vector)
484+
? knn.findKNNGPUCosDistNorm(
485+
data,
486+
numKnnNeighborsToCompute,
487+
(d) => d.vector
488+
)
479489
: knn.findKNN(
480490
data,
481-
nNeighbors,
491+
numKnnNeighborsToCompute,
482492
(d) => d.vector,
483493
(a, b) => vector.cosDistNorm(a, b)
484494
));
485495
this.nearest = result;
486-
return Promise.resolve(result);
496+
return Promise.resolve(
497+
result.map((neighbors) => neighbors.slice(0, nNeighbors))
498+
);
487499
}
488500
}
489501
/* Perturb TSNE and update dataset point coordinates. */

0 commit comments

Comments
 (0)