17
17
from inspect import signature
18
18
from .backend import get_backend , Backend , NumpyBackend , JaxBackend
19
19
20
- __time_tic_toc = time .time ()
20
+ __time_tic_toc = time .perf_counter ()
21
21
22
22
23
23
def tic ():
24
24
r"""Python implementation of Matlab tic() function"""
25
25
global __time_tic_toc
26
- __time_tic_toc = time .time ()
26
+ __time_tic_toc = time .perf_counter ()
27
27
28
28
29
29
def toc (message = "Elapsed time : {} s" ):
30
30
r"""Python implementation of Matlab toc() function"""
31
- t = time .time ()
31
+ t = time .perf_counter ()
32
32
print (message .format (t - __time_tic_toc ))
33
33
return t - __time_tic_toc
34
34
35
35
36
36
def toq ():
37
37
r"""Python implementation of Julia toc() function"""
38
- t = time .time ()
38
+ t = time .perf_counter ()
39
39
return t - __time_tic_toc
40
40
41
41
@@ -251,7 +251,7 @@ def clean_zeros(a, b, M):
251
251
return a2 , b2 , M2
252
252
253
253
254
- def euclidean_distances (X , Y , squared = False ):
254
+ def euclidean_distances (X , Y , squared = False , nx = None ):
255
255
r"""
256
256
Considering the rows of :math:`\mathbf{X}` (and :math:`\mathbf{Y} = \mathbf{X}`) as vectors, compute the
257
257
distance matrix between each pair of vectors.
@@ -270,13 +270,13 @@ def euclidean_distances(X, Y, squared=False):
270
270
-------
271
271
distances : array-like, shape (`n_samples_1`, `n_samples_2`)
272
272
"""
273
-
274
- nx = get_backend (X , Y )
273
+ if nx is None :
274
+ nx = get_backend (X , Y )
275
275
276
276
a2 = nx .einsum ("ij,ij->i" , X , X )
277
277
b2 = nx .einsum ("ij,ij->i" , Y , Y )
278
278
279
- c = - 2 * nx .dot (X , Y . T )
279
+ c = - 2 * nx .dot (X , nx . transpose ( Y ) )
280
280
c += a2 [:, None ]
281
281
c += b2 [None , :]
282
282
@@ -291,11 +291,21 @@ def euclidean_distances(X, Y, squared=False):
291
291
return c
292
292
293
293
294
- def dist (x1 , x2 = None , metric = "sqeuclidean" , p = 2 , w = None ):
294
+ def dist (
295
+ x1 ,
296
+ x2 = None ,
297
+ metric = "sqeuclidean" ,
298
+ p = 2 ,
299
+ w = None ,
300
+ backend = "auto" ,
301
+ nx = None ,
302
+ use_tensor = False ,
303
+ ):
295
304
r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`
296
305
297
306
.. note:: This function is backend-compatible and will work on arrays
298
- from all compatible backends.
307
+ from all compatible backends for the following metrics:
308
+ 'sqeuclidean', 'euclidean', 'cityblock', 'minkowski', 'cosine', 'correlation'.
299
309
300
310
Parameters
301
311
----------
@@ -315,7 +325,17 @@ def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None):
315
325
p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2.
316
326
w : array-like, rank 1
317
327
Weights for the weighted metrics.
318
-
328
+ backend : str, optional
329
+ Backend to use for the computation. If 'auto', the backend is
330
+ automatically selected based on the input data. if 'scipy',
331
+ the ``scipy.spatial.distance.cdist`` function is used (and gradients are
332
+ detached).
333
+ use_tensor : bool, optional
334
+ If true use tensorized computation for the distance matrix which can
335
+ cause memory issues for large datasets. Default is False and the
336
+ parameter is used only for the 'cityblock' and 'minkowski' metrics.
337
+ nx : Backend, optional
338
+ Backend to perform computations on. If omitted, the backend defaults to that of `x1`.
319
339
320
340
Returns
321
341
-------
@@ -324,12 +344,69 @@ def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None):
324
344
distance matrix computed with given metric
325
345
326
346
"""
347
+ if nx is None :
348
+ nx = get_backend (x1 , x2 )
327
349
if x2 is None :
328
350
x2 = x1
329
- if metric == "sqeuclidean" :
330
- return euclidean_distances (x1 , x2 , squared = True )
351
+ if backend == "scipy" : # force scipy backend with cdist function
352
+ x1 = nx .to_numpy (x1 )
353
+ x2 = nx .to_numpy (x2 )
354
+ if isinstance (metric , str ) and metric .endswith ("minkowski" ):
355
+ return nx .from_numpy (cdist (x1 , x2 , metric = metric , p = p , w = w ))
356
+ if w is not None :
357
+ return nx .from_numpy (cdist (x1 , x2 , metric = metric , w = w ))
358
+ return nx .from_numpy (cdist (x1 , x2 , metric = metric ))
359
+ elif metric == "sqeuclidean" :
360
+ return euclidean_distances (x1 , x2 , squared = True , nx = nx )
331
361
elif metric == "euclidean" :
332
- return euclidean_distances (x1 , x2 , squared = False )
362
+ return euclidean_distances (x1 , x2 , squared = False , nx = nx )
363
+ elif metric == "cityblock" :
364
+ if use_tensor :
365
+ return nx .sum (nx .abs (x1 [:, None , :] - x2 [None , :, :]), axis = 2 )
366
+ else :
367
+ M = 0.0
368
+ for i in range (x1 .shape [1 ]):
369
+ M += nx .abs (x1 [:, i ][:, None ] - x2 [:, i ][None , :])
370
+ return M
371
+ elif metric == "minkowski" :
372
+ if w is None :
373
+ if use_tensor :
374
+ return nx .power (
375
+ nx .sum (
376
+ nx .power (nx .abs (x1 [:, None , :] - x2 [None , :, :]), p ), axis = 2
377
+ ),
378
+ 1 / p ,
379
+ )
380
+ else :
381
+ M = 0.0
382
+ for i in range (x1 .shape [1 ]):
383
+ M += nx .abs (x1 [:, i ][:, None ] - x2 [:, i ][None , :]) ** p
384
+ return M ** (1 / p )
385
+ else :
386
+ if use_tensor :
387
+ return nx .power (
388
+ nx .sum (
389
+ w [None , None , :]
390
+ * nx .power (nx .abs (x1 [:, None , :] - x2 [None , :, :]), p ),
391
+ axis = 2 ,
392
+ ),
393
+ 1 / p ,
394
+ )
395
+ else :
396
+ M = 0.0
397
+ for i in range (x1 .shape [1 ]):
398
+ M += w [i ] * nx .abs (x1 [:, i ][:, None ] - x2 [:, i ][None , :]) ** p
399
+ return M ** (1 / p )
400
+ elif metric == "cosine" :
401
+ nx1 = nx .sqrt (nx .einsum ("ij,ij->i" , x1 , x1 ))
402
+ nx2 = nx .sqrt (nx .einsum ("ij,ij->i" , x2 , x2 ))
403
+ return 1.0 - (nx .dot (x1 , nx .transpose (x2 )) / nx1 [:, None ] / nx2 [None , :])
404
+ elif metric == "correlation" :
405
+ x1 = x1 - nx .mean (x1 , axis = 1 )[:, None ]
406
+ x2 = x2 - nx .mean (x2 , axis = 1 )[:, None ]
407
+ nx1 = nx .sqrt (nx .einsum ("ij,ij->i" , x1 , x1 ))
408
+ nx2 = nx .sqrt (nx .einsum ("ij,ij->i" , x2 , x2 ))
409
+ return 1.0 - (nx .dot (x1 , nx .transpose (x2 )) / nx1 [:, None ] / nx2 [None , :])
333
410
else :
334
411
if not get_backend (x1 , x2 ).__name__ == "numpy" :
335
412
raise NotImplementedError ()
0 commit comments