@@ -49,16 +49,14 @@ export function findKNNGPUCosDistNorm<T>(
49
49
// pair of points, which we sort using KMin data structure to obtain the
50
50
// K nearest neighbors for each point.
51
51
const nearest : NearestEntry [ ] [ ] = new Array ( N ) ;
52
-
53
- const typedArray = vector . toTypedArray ( dataPoints , accessor ) ;
54
- const bigMatrix = tf . tensor ( typedArray , [ N , dim ] ) ;
55
- const bigMatrixTransposed = tf . transpose ( bigMatrix ) ;
56
-
57
52
function step ( resolve : ( result : NearestEntry [ ] [ ] ) => void ) {
58
53
util
59
54
. runAsyncTask (
60
55
'Finding nearest neighbors...' ,
61
56
async ( ) => {
57
+ const typedArray = vector . toTypedArray ( dataPoints , accessor ) ;
58
+ const bigMatrix = tf . tensor ( typedArray , [ N , dim ] ) ;
59
+ const bigMatrixTransposed = tf . transpose ( bigMatrix ) ;
62
60
// 1 - A * A^T.
63
61
const bigMatrixSquared = tf . matMul ( bigMatrix , bigMatrixTransposed ) ;
64
62
const cosDistMatrix = tf . sub ( 1 , bigMatrixSquared ) ;
@@ -68,6 +66,9 @@ export function findKNNGPUCosDistNorm<T>(
68
66
// [ 3 4 ],
69
67
// `.data()` returns [1, 2, 3, 4].
70
68
const partial = await cosDistMatrix . data ( ) ;
69
+ // Discard all tensors and free up the memory.
70
+ bigMatrix . dispose ( ) ;
71
+ bigMatrixTransposed . dispose ( ) ;
71
72
bigMatrixSquared . dispose ( ) ;
72
73
cosDistMatrix . dispose ( ) ;
73
74
for ( let i = 0 ; i < N ; i ++ ) {
@@ -93,15 +94,9 @@ export function findKNNGPUCosDistNorm<T>(
93
94
. then (
94
95
( ) => {
95
96
logging . setModalMessage ( null ! , KNN_MSG_ID ) ;
96
- // Discard all tensors and free up the memory.
97
- bigMatrix . dispose ( ) ;
98
- bigMatrixTransposed . dispose ( ) ;
99
97
resolve ( nearest ) ;
100
98
} ,
101
99
( error ) => {
102
- // Discard all tensors and free up the memory.
103
- bigMatrix . dispose ( ) ;
104
- bigMatrixTransposed . dispose ( ) ;
105
100
// GPU failed. Reverting back to CPU.
106
101
logging . setModalMessage ( null ! , KNN_MSG_ID ) ;
107
102
let distFunc = ( a , b , limit ) => vector . cosDistNorm ( a , b ) ;
0 commit comments