@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212See the License for the specific language governing permissions and 
1313limitations under the License. 
1414==============================================================================*/ 
15- import  { findKNN ,  findKNNGPUCosDistNorm ,  NearestEntry ,  TEST_ONLY }  from  './knn' ; 
16- import  { cosDistNorm ,   unit }  from  './vector' ; 
15+ import  { findKNN ,  findKNNGPUCosDist ,  NearestEntry ,  TEST_ONLY }  from  './knn' ; 
16+ import  { cosDist }  from  './vector' ; 
1717
1818describe ( 'projector knn test' ,  ( )  =>  { 
1919  function  getIndices ( nearest : NearestEntry [ ] [ ] ) : number [ ] [ ]  { 
@@ -22,22 +22,16 @@ describe('projector knn test', () => {
2222    } ) ; 
2323  } 
2424
25-   function  unitVector ( vector : Float32Array ) : Float32Array  { 
26-     // `unit` method replaces the vector in-place. 
27-     unit ( vector ) ; 
28-     return  vector ; 
29-   } 
30- 
31-   describe ( '#findKNNGPUCosDistNorm' ,  ( )  =>  { 
25+   describe ( '#findKNNGPUCosDist' ,  ( )  =>  { 
3226    it ( 'finds n-nearest neighbor for each item' ,  async  ( )  =>  { 
33-       const  values  =  await  findKNNGPUCosDistNorm ( 
27+       const  values  =  await  findKNNGPUCosDist ( 
3428        [ 
35-           { a : unitVector ( new  Float32Array ( [ 1 ,  2 ,  0 ] ) ) } , 
36-           { a : unitVector ( new  Float32Array ( [ 1 ,  1 ,  3 ] ) ) } , 
37-           { a : unitVector ( new  Float32Array ( [ 100 ,  30 ,  0 ] ) ) } , 
38-           { a : unitVector ( new  Float32Array ( [ 95 ,  23 ,  3 ] ) ) } , 
39-           { a : unitVector ( new  Float32Array ( [ 100 ,  10 ,  0 ] ) ) } , 
40-           { a : unitVector ( new  Float32Array ( [ 95 ,  23 ,  100 ] ) ) } , 
29+           { a : new  Float32Array ( [ 1 ,  2 ,  0 ] ) } , 
30+           { a : new  Float32Array ( [ 1 ,  1 ,  3 ] ) } , 
31+           { a : new  Float32Array ( [ 100 ,  30 ,  0 ] ) } , 
32+           { a : new  Float32Array ( [ 95 ,  23 ,  3 ] ) } , 
33+           { a : new  Float32Array ( [ 100 ,  10 ,  0 ] ) } , 
34+           { a : new  Float32Array ( [ 95 ,  23 ,  100 ] ) } , 
4135        ] , 
4236        4 , 
4337        ( data )  =>  data . a 
@@ -54,11 +48,8 @@ describe('projector knn test', () => {
5448    } ) ; 
5549
5650    it ( 'returns less than N when number of item is lower' ,  async  ( )  =>  { 
57-       const  values  =  await  findKNNGPUCosDistNorm ( 
58-         [ 
59-           unitVector ( new  Float32Array ( [ 1 ,  2 ,  0 ] ) ) , 
60-           unitVector ( new  Float32Array ( [ 1 ,  1 ,  3 ] ) ) , 
61-         ] , 
51+       const  values  =  await  findKNNGPUCosDist ( 
52+         [ new  Float32Array ( [ 1 ,  2 ,  0 ] ) ,  new  Float32Array ( [ 1 ,  1 ,  3 ] ) ] , 
6253        4 , 
6354        ( a )  =>  a 
6455      ) ; 
@@ -68,10 +59,8 @@ describe('projector knn test', () => {
6859
6960    it ( 'splits a large data into one that would fit into GPU memory' ,  async  ( )  =>  { 
7061      const  size  =  TEST_ONLY . OPTIMAL_GPU_BLOCK_SIZE  +  5 ; 
71-       const  data  =  new  Array ( size ) . fill ( 
72-         unitVector ( new  Float32Array ( [ 1 ,  1 ,  1 ] ) ) 
73-       ) ; 
74-       const  values  =  await  findKNNGPUCosDistNorm ( data ,  1 ,  ( a )  =>  a ) ; 
62+       const  data  =  new  Array ( size ) . fill ( new  Float32Array ( [ 1 ,  1 ,  1 ] ) ) ; 
63+       const  values  =  await  findKNNGPUCosDist ( data ,  1 ,  ( a )  =>  a ) ; 
7564
7665      expect ( getIndices ( values ) ) . toEqual ( [ 
7766        // Since distance to the diagonal entries (distance to self is 0) is 
@@ -84,25 +73,25 @@ describe('projector knn test', () => {
8473  } ) ; 
8574
8675  describe ( '#findKNN' ,  ( )  =>  { 
87-     // Covered by equality tests below (#findKNNGPUCosDistNorm  == #findKNN). 
76+     // Covered by equality tests below (#findKNNGPUCosDist  == #findKNN). 
8877  } ) ; 
8978
90-   describe ( '#findKNNGPUCosDistNorm  and #findKNN' ,  ( )  =>  { 
79+   describe ( '#findKNNGPUCosDist  and #findKNN' ,  ( )  =>  { 
9180    it ( 'returns same value when dist metrics are cosine' ,  async  ( )  =>  { 
9281      const  data  =  [ 
93-         unitVector ( new  Float32Array ( [ 1 ,  2 ,  0 ] ) ) , 
94-         unitVector ( new  Float32Array ( [ 1 ,  1 ,  3 ] ) ) , 
95-         unitVector ( new  Float32Array ( [ 100 ,  30 ,  0 ] ) ) , 
96-         unitVector ( new  Float32Array ( [ 95 ,  23 ,  3 ] ) ) , 
97-         unitVector ( new  Float32Array ( [ 100 ,  10 ,  0 ] ) ) , 
98-         unitVector ( new  Float32Array ( [ 95 ,  23 ,  100 ] ) ) , 
82+         new  Float32Array ( [ 1 ,  2 ,  0 ] ) , 
83+         new  Float32Array ( [ 1 ,  1 ,  3 ] ) , 
84+         new  Float32Array ( [ 100 ,  30 ,  0 ] ) , 
85+         new  Float32Array ( [ 95 ,  23 ,  3 ] ) , 
86+         new  Float32Array ( [ 100 ,  10 ,  0 ] ) , 
87+         new  Float32Array ( [ 95 ,  23 ,  100 ] ) , 
9988      ] ; 
100-       const  findKnnGpuCosVal  =  await  findKNNGPUCosDistNorm ( data ,  2 ,  ( a )  =>  a ) ; 
89+       const  findKnnGpuCosVal  =  await  findKNNGPUCosDist ( data ,  2 ,  ( a )  =>  a ) ; 
10190      const  findKnnVal  =  await  findKNN ( 
10291        data , 
10392        2 , 
10493        ( a )  =>  a , 
105-         ( a ,  b ,  limit )  =>  cosDistNorm ( a ,  b ) 
94+         ( a ,  b ,  limit )  =>  cosDist ( a ,  b ) 
10695      ) ; 
10796
10897      // Floating point precision makes it hard to test. Just assert indices. 
@@ -112,15 +101,15 @@ describe('projector knn test', () => {
112101    it ( 'splits a large data without the result being wrong' ,  async  ( )  =>  { 
113102      const  size  =  TEST_ONLY . OPTIMAL_GPU_BLOCK_SIZE  +  5 ; 
114103      const  data  =  Array . from ( new  Array ( size ) ) . map ( ( _ ,  index )  =>  { 
115-         return  unitVector ( new  Float32Array ( [ index  +  1 ,  index  +  1 ] ) ) ; 
104+         return  new  Float32Array ( [ index  +  1 ,  index  +  2 ] ) ; 
116105      } ) ; 
117106
118-       const  findKnnGpuCosVal  =  await  findKNNGPUCosDistNorm ( data ,  2 ,  ( a )  =>  a ) ; 
107+       const  findKnnGpuCosVal  =  await  findKNNGPUCosDist ( data ,  2 ,  ( a )  =>  a ) ; 
119108      const  findKnnVal  =  await  findKNN ( 
120109        data , 
121110        2 , 
122111        ( a )  =>  a , 
123-         ( a ,  b ,  limit )  =>  cosDistNorm ( a ,  b ) 
112+         ( a ,  b ,  limit )  =>  cosDist ( a ,  b ) 
124113      ) ; 
125114
126115      expect ( getIndices ( findKnnGpuCosVal ) ) . toEqual ( getIndices ( findKnnVal ) ) ; 
0 commit comments