@@ -22,17 +22,9 @@ export type NearestEntry = {
2222  index : number ; 
2323  dist : number ; 
2424} ; 
25- /** 
26-  * Optimal size for the height of the matrix when doing computation on the GPU 
27-  * using WebGL. This was found experimentally. 
28-  * 
29-  * This also guarantees that for computing pair-wise distance for up to 10K 
30-  * vectors, no more than 40MB will be allocated in the GPU. Without the 
31-  * allocation limit, we can freeze the graphics of the whole OS. 
32-  */ 
33- const  OPTIMAL_GPU_BLOCK_SIZE  =  256 ; 
34- /** Id of message box used for knn gpu progress bar. */ 
35- const  KNN_GPU_MSG_ID  =  'knn-gpu' ; 
25+ 
26+ /** Id of message box used for knn. */ 
27+ const  KNN_MSG_ID  =  'knn' ; 
3628
3729/** 
3830 * Returns the K nearest neighbors for each vector where the distance 
@@ -52,105 +44,63 @@ export function findKNNGPUCosDistNorm<T>(
5244  const  N  =  dataPoints . length ; 
5345  const  dim  =  accessor ( dataPoints [ 0 ] ) . length ; 
5446  // The goal is to compute a large matrix multiplication A*A.T where A is of 
55-   // size NxD and A.T is its transpose. This results in a NxN matrix which 
56-   // could be too big to store on the GPU memory. To avoid memory overflow, we 
57-   // compute multiple A*partial_A.T where partial_A is of size BxD (B is much 
58-   // smaller than N). This results in storing only NxB size matrices on the GPU 
59-   // at a given time. 
47+   // size NxD and A.T is its transpose. This results in a NxN matrix. 
6048  // A*A.T gives us an NxN matrix holding the cosine similarity between every 
6149  // pair of points; we turn it into cosine distance (1 - similarity) and use 
6250  // the KMin data structure to obtain the K nearest neighbors for each point. 
6351  const  nearest : NearestEntry [ ] [ ]  =  new  Array ( N ) ; 
64-   let  numPieces  =  Math . ceil ( N  /  OPTIMAL_GPU_BLOCK_SIZE ) ; 
65-   const  actualPieceSize  =  Math . floor ( N  /  numPieces ) ; 
66-   const  modulo  =  N  %  actualPieceSize ; 
67-   numPieces  +=  modulo  ? 1  : 0 ; 
68-   let  offset  =  0 ; 
69-   let  progress  =  0 ; 
70-   let  progressDiff  =  1  /  ( 2  *  numPieces ) ; 
71-   let  piece  =  0 ; 
72- 
73-   const  typedArray  =  vector . toTypedArray ( dataPoints ,  accessor ) ; 
74-   const  bigMatrix  =  tf . tensor ( typedArray ,  [ N ,  dim ] ) ; 
75-   const  bigMatrixTransposed  =  tf . transpose ( bigMatrix ) ; 
76-   // 1 - A * A^T. 
77-   const  bigMatrixSquared  =  tf . matMul ( bigMatrix ,  bigMatrixTransposed ) ; 
78-   const  cosDistMatrix  =  tf . sub ( 1 ,  bigMatrixSquared ) ; 
79- 
80-   let  maybePaddedCosDistMatrix  =  cosDistMatrix ; 
81-   if  ( actualPieceSize  *  numPieces  >  N )  { 
82-     // Expect the input to be rank 2 (though it is not typed that way) so we 
83-     // want to pad the first dimension so we split very evenly (all splitted 
84-     // tensor have exactly the same dimesion). 
85-     const  padding : Array < [ number ,  number ] >  =  [ 
86-       [ 0 ,  actualPieceSize  *  numPieces  -  N ] , 
87-       [ 0 ,  0 ] , 
88-     ] ; 
89-     maybePaddedCosDistMatrix  =  tf . pad ( cosDistMatrix ,  padding ) ; 
90-   } 
91-   const  splits  =  tf . split ( 
92-     maybePaddedCosDistMatrix , 
93-     new  Array ( numPieces ) . fill ( actualPieceSize ) , 
94-     0 
95-   ) ; 
96- 
9752  function  step ( resolve : ( result : NearestEntry [ ] [ ] )  =>  void )  { 
98-     let  progressMsg  = 
99-       'Finding nearest neighbors: '  +  ( progress  *  100 ) . toFixed ( )  +  '%' ; 
10053    util 
10154      . runAsyncTask ( 
102-         progressMsg , 
55+         'Finding nearest neighbors...' , 
10356        async  ( )  =>  { 
57+           const  cosSimilarityMatrix  =  tf . tidy ( ( )  =>  { 
58+             const  typedArray  =  vector . toTypedArray ( dataPoints ,  accessor ) ; 
59+             const  bigMatrix  =  tf . tensor ( typedArray ,  [ N ,  dim ] ) ; 
60+             const  bigMatrixTransposed  =  tf . transpose ( bigMatrix ) ; 
61+             // A * A^T. 
62+             return  tf . matMul ( bigMatrix ,  bigMatrixTransposed ) ; 
63+           } ) ; 
10464          // `.data()` returns flattened Float32Array of N * N dimension. 
10565          // For matrix of 
10666          // [ 1  2 ] 
10767          // [ 3  4 ], 
10868          // `.data()` returns [1, 2, 3, 4]. 
109-           const  partial  =  await  splits [ piece ] . data ( ) ; 
110-           progress  +=  progressDiff ; 
111-           for  ( let  i  =  0 ;  i  <  actualPieceSize ;  i ++ )  { 
69+           let  partial ; 
70+           try  { 
71+             partial  =  await  cosSimilarityMatrix . data ( ) ; 
72+           }  finally  { 
73+             // Discard all tensors and free up the memory. 
74+             cosSimilarityMatrix . dispose ( ) ; 
75+           } 
76+           for  ( let  i  =  0 ;  i  <  N ;  i ++ )  { 
11277            let  kMin  =  new  KMin < NearestEntry > ( k ) ; 
113-             let  iReal  =  offset  +  i ; 
114-             if  ( iReal  >=  N )  break ; 
11578            for  ( let  j  =  0 ;  j  <  N ;  j ++ )  { 
11679              // Skip diagonal entries. 
117-               if  ( j  ===  iReal )  { 
80+               if  ( j  ===  i )  { 
11881                continue ; 
11982              } 
12083              // Access the i-th row at the `j`-th column. 
12184              // Each row has N entries and j-th index has cosine similarity 
122-               // between iReal  vs. j-th vectors. 
123-               const  cosDist  =  partial [ i  *  N  +  j ] ; 
85+               // between i-th  vs. j-th vectors. 
86+               const  cosDist  =  1   -   partial [ i  *  N  +  j ] ; 
12487              if  ( cosDist  >=  0 )  { 
12588                kMin . add ( cosDist ,  { index : j ,  dist : cosDist } ) ; 
12689              } 
12790            } 
128-             nearest [ iReal ]  =  kMin . getMinKItems ( ) ; 
91+             nearest [ i ]  =  kMin . getMinKItems ( ) ; 
12992          } 
130-           progress  +=  progressDiff ; 
131-           offset  +=  actualPieceSize ; 
132-           piece ++ ; 
13393        } , 
134-         KNN_GPU_MSG_ID 
94+         KNN_MSG_ID 
13595      ) 
13696      . then ( 
13797        ( )  =>  { 
138-           if  ( piece  <  numPieces )  { 
139-             step ( resolve ) ; 
140-           }  else  { 
141-             logging . setModalMessage ( null ! ,  KNN_GPU_MSG_ID ) ; 
142-             // Discard all tensors and free up the memory. 
143-             bigMatrix . dispose ( ) ; 
144-             bigMatrixTransposed . dispose ( ) ; 
145-             bigMatrixSquared . dispose ( ) ; 
146-             cosDistMatrix . dispose ( ) ; 
147-             splits . forEach ( ( split )  =>  split . dispose ( ) ) ; 
148-             resolve ( nearest ) ; 
149-           } 
98+           logging . setModalMessage ( null ! ,  KNN_MSG_ID ) ; 
99+           resolve ( nearest ) ; 
150100        } , 
151101        ( error )  =>  { 
152102          // GPU failed. Reverting back to CPU. 
153-           logging . setModalMessage ( null ! ,  KNN_GPU_MSG_ID ) ; 
103+           logging . setModalMessage ( null ! ,  KNN_MSG_ID ) ; 
154104          let  distFunc  =  ( a ,  b ,  limit )  =>  vector . cosDistNorm ( a ,  b ) ; 
155105          findKNN ( dataPoints ,  k ,  accessor ,  distFunc ) . then ( ( nearest )  =>  { 
156106            resolve ( nearest ) ; 
@@ -212,47 +162,12 @@ export function findKNN<T>(
212162      for  ( let  i  =  0 ;  i  <  N ;  i ++ )  { 
213163        nearest [ i ]  =  kMin [ i ] . getMinKItems ( ) ; 
214164      } 
165+       logging . setModalMessage ( null ! ,  KNN_MSG_ID ) ; 
215166      return  nearest ; 
216-     } 
167+     } , 
168+     KNN_MSG_ID 
217169  ) ; 
218170} 
219- /** Calculates the minimum distance between a search point and a rectangle. */ 
220- function  minDist ( 
221-   point : [ number ,  number ] , 
222-   x1 : number , 
223-   y1 : number , 
224-   x2 : number , 
225-   y2 : number 
226- )  { 
227-   let  x  =  point [ 0 ] ; 
228-   let  y  =  point [ 1 ] ; 
229-   let  dx1  =  x  -  x1 ; 
230-   let  dx2  =  x  -  x2 ; 
231-   let  dy1  =  y  -  y1 ; 
232-   let  dy2  =  y  -  y2 ; 
233-   if  ( dx1  *  dx2  <=  0 )  { 
234-     // x is between x1 and x2 
235-     if  ( dy1  *  dy2  <=  0 )  { 
236-       // (x,y) is inside the rectangle 
237-       return  0 ;  // return 0 as point is in rect 
238-     } 
239-     return  Math . min ( Math . abs ( dy1 ) ,  Math . abs ( dy2 ) ) ; 
240-   } 
241-   if  ( dy1  *  dy2  <=  0 )  { 
242-     // y is between y1 and y2 
243-     // We know it is already inside the rectangle 
244-     return  Math . min ( Math . abs ( dx1 ) ,  Math . abs ( dx2 ) ) ; 
245-   } 
246-   let  corner : [ number ,  number ] ; 
247-   if  ( x  >  x2 )  { 
248-     // Upper-right vs lower-right. 
249-     corner  =  y  >  y2  ? [ x2 ,  y2 ]  : [ x2 ,  y1 ] ; 
250-   }  else  { 
251-     // Upper-left vs lower-left. 
252-     corner  =  y  >  y2  ? [ x1 ,  y2 ]  : [ x1 ,  y1 ] ; 
253-   } 
254-   return  Math . sqrt ( vector . dist22D ( [ x ,  y ] ,  corner ) ) ; 
255- } 
256171/** 
257172 * Returns the nearest neighbors of a particular point. 
258173 * 
@@ -281,5 +196,3 @@ export function findKNNofPoint<T>(
281196  } 
282197  return  kMin . getMinKItems ( ) ; 
283198} 
284- 
285- export  const  TEST_ONLY  =  { OPTIMAL_GPU_BLOCK_SIZE } ; 
0 commit comments