Skip to content

Commit b38f40a

Browse files
committed
Add cosine similarity
1 parent 8fd5d41 commit b38f40a

File tree

1 file changed

+105
-57
lines changed

1 file changed

+105
-57
lines changed

examples/compute_pairwise_distances.py

+105-57
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
from matplotlib import pyplot as plt
1818

1919
from movement import sample_data
20-
from movement.kinematics import compute_pairwise_distances
20+
from movement.kinematics import (
21+
compute_forward_vector,
22+
compute_pairwise_distances,
23+
)
2124

2225
# %%
2326
# Load sample dataset
@@ -60,6 +63,8 @@
6063
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
6164
# Get reference length
6265

66+
# Measure the long side of the box in pixels
67+
# Note the lens is a bit distorted
6368
# Should I use diagonal?
6469
start_point = np.array([[209, 382]])
6570
end_point = np.array([[213, 1022]])
@@ -84,37 +89,6 @@
8489
ax.set_ylabel("y (pixels)")
8590
ax.set_title("Reference length")
8691

87-
# Measure the long side of the box in pixels
88-
89-
# Note the lens is a bit distorted
90-
91-
92-
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
93-
# Plot keypoints on sample frames
94-
95-
# time_sel = [5.0, 50.0, 100.0]
96-
97-
# # get colormap tab20
98-
# cmap = plt.get_cmap("tab20")
99-
100-
# fig, axs = plt.subplots(len(time_sel), 1, figsize=(10, 15))
101-
# for k in range(len(time_sel)):
102-
# for kpt_i, kpt in enumerate(ds.coords["keypoints"].data):
103-
# axs[k].scatter(
104-
# x=ds.position.sel(keypoints=kpt, space="x", time=time_sel[k]),
105-
# y=ds.position.sel(keypoints=kpt, space="y", time=time_sel[k]),
106-
# s=10,
107-
# label=f"{kpt}",
108-
# color=cmap(kpt_i),
109-
# )
110-
111-
# axs[k].axis("equal")
112-
# axs[k].invert_yaxis()
113-
# axs[k].set_xlabel("x (pixels)")
114-
# axs[k].set_ylabel("y (pixels)")
115-
# axs[k].set_title(f"Keypoints at {time_sel[k]} s")
116-
# # axs[k].legend(bbox_to_anchor=(1.1, 1.05))
117-
11892

11993
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
12094
# Compute distances between keypoints on different individuals
@@ -133,15 +107,15 @@
133107
# from individual 1 to all keypoints on individual 2
134108
print(inter_individual_kpt_distances.shape) # inter_individual_distances
135109

136-
# normalise with reference length?
137-
inter_individual_kpt_distances_norm = (
138-
inter_individual_kpt_distances / reference_length
139-
)
110+
# # normalise with reference length?
111+
# inter_individual_kpt_distances_norm = (
112+
# inter_individual_kpt_distances / reference_length
113+
# )
140114

141115
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
142116
# Plot matrix of distances and keypoints
143117
# Show different patterns / positions between the two animals
144-
# Note that the colorbars vary across frames!
118+
# Note that the colorbars vary across plots (i.e. across frames)!
145119

146120
time_sel = [50.0, 100.0, 250.0]
147121

@@ -170,12 +144,6 @@
170144
label=f"{kpt}",
171145
color=cmap(kpt_i),
172146
)
173-
# # connect matching keypoints
174-
# axs[0].plot(
175-
# ds.position.sel(keypoints=kpt, space="x", time=time_sel[k]),
176-
# ds.position.sel(keypoints=kpt, space="y", time=time_sel[k]),
177-
# 'r'
178-
# )
179147

180148
# add text per individual
181149
for ind in ds.coords["individuals"].data:
@@ -229,7 +197,11 @@
229197
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
230198
# To get distance between homologous keypoints
231199
# get the diagonal of the previous matrix at each frame
232-
inter_individual_same_kpts = np.diagonal(inter_individual_kpt_distances)
200+
inter_individual_same_kpts = np.diagonal(
201+
inter_individual_kpt_distances,
202+
axis1=1,
203+
axis2=2,
204+
)
233205
print(inter_individual_same_kpts.shape) # (59999, 12)
234206

235207

@@ -240,45 +212,67 @@
240212
inter_individual_same_kpts[:, k_i],
241213
)
242214

215+
# # plot matrix as sparse matrix?
216+
# # plot vectors on top of a given frame?
217+
# for k in range(len(time_sel)):
218+
# fig, axs = plt.subplots(1, 2, figsize=(13, 5))
219+
# fig.subplots_adjust(wspace=0.5)
220+
221+
# # plot keypoints
222+
# for kpt_i, kpt in enumerate(ds.coords["keypoints"].data):
223+
# axs[0].scatter(
224+
# x=ds.position.sel(keypoints=kpt, space="x", time=time_sel[k]),
225+
# y=ds.position.sel(keypoints=kpt, space="y", time=time_sel[k]),
226+
# s=10,
227+
# label=f"{kpt}",
228+
# color=cmap(kpt_i),
229+
# )
230+
# # connect matching keypoints
231+
# axs[0].plot(
232+
# ds.position.sel(keypoints=kpt, space="x", time=time_sel[k]),
233+
# ds.position.sel(keypoints=kpt, space="y", time=time_sel[k]),
234+
# "r",
235+
# )
243236

244237
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
245238

246-
# To get specific keypoints between individuals
239+
# To get distance between specific keypoints on different individuals
247240
# e.g. snout of individual 1 to tail base of individual 2
248-
# yo can select the relevant keypoints along the dimensions
241+
# you can select the relevant keypoint coordinates along the dimensions
249242
# "individual1" and "individual2"
250243

251244
distance_snout_1_to_tail_2 = inter_individual_kpt_distances.sel(
252245
individual1="snout", individual2="tailbase"
253246
)
254247

255-
# plot distance over time
248+
# plot distance from snout 1 to tailbase 2 over time
249+
# plot in a short time window?
256250
fig, ax = plt.subplots()
257251
ax.plot(
258252
distance_snout_1_to_tail_2.time, # seconds
259253
distance_snout_1_to_tail_2 / reference_length,
260254
)
261255
ax.set_xlabel("time (seconds)")
262-
ax.set_ylabel("distance normalised")
256+
ax.set_ylabel("distance snout-to-tail normalised")
257+
263258

264-
# plot vectors on top of a given frame?
265259
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
266-
# Compute distances between the keypoints of the same individual
260+
# Compute distances between the keypoints on the same individual
267261
# compute average bodylength = snout to tailbase
268262

269263
distance_snout_to_tailbase_all = compute_pairwise_distances(
270264
ds.position,
271265
dim="keypoints",
272266
pairs={
273267
"snout": "tailbase",
274-
# this will set the dims of the output (individuals will be the
275-
# coordinates)
268+
# this will set the dims of the output
269+
# (individuals will be the coordinates)
276270
},
277271
) # pixels
278272

279273
print(distance_snout_to_tailbase_all) # dimensions are snout and tailbase!
280274

281-
# within individual
275+
# compute distances within individual
282276
bodylength_individual_1 = distance_snout_to_tailbase_all.sel(
283277
snout="individual1",
284278
tailbase="individual1",
@@ -289,7 +283,7 @@
289283
tailbase="individual2",
290284
)
291285

292-
# across individuals
286+
# compute distances across individuals
293287
# (an alternative way to the above)
294288
snout_1_to_tail_2 = distance_snout_to_tailbase_all.sel(
295289
snout="individual1",
@@ -300,13 +294,14 @@
300294
tailbase="individual1",
301295
)
302296

297+
# check that this approach is equivalent to the previous one
303298
np.testing.assert_almost_equal(
304299
snout_1_to_tail_2.data, distance_snout_1_to_tail_2.data
305300
)
306301

307302
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
308303
# Plot bodylength over time
309-
304+
# as a histogram instead?
310305
for b_i, bodylength_data_array in enumerate(
311306
[
312307
bodylength_individual_1,
@@ -330,7 +325,60 @@
330325
ax.set_ylabel("length (pixels)")
331326

332327
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
333-
# Try usage of 'all' and plot matrix of four quadrants
328+
# Try usage of 'all' and plot distance matrix with four quadrants
329+
330+
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
331+
# Compute distances between centroids
332+
333+
distances_between_centroids = compute_pairwise_distances(
334+
ds.centroid,
335+
dim="individuals",
336+
pairs={
337+
"individual1": "individual2",
338+
},
339+
)
340+
341+
print(distances_between_centroids.shape) # (59999,)
342+
343+
# histogram
344+
fig, ax = plt.subplots()
345+
ax.hist(
346+
distances_between_centroids,
347+
)
348+
ax.set_xlabel("distance (pixels)")
349+
ax.set_ylabel("frames") # make it relative to the total number of frames?
350+
ax.set_title("Distances between centroids")
351+
352+
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
353+
# Try a different metric, e.g. cosine distance
354+
# https://en.wikipedia.org/wiki/Cosine_similarity
355+
356+
# compute forward vector per individual
357+
358+
ds["head_vector"] = compute_forward_vector(
359+
ds.position,
360+
left_keypoint="leftear",
361+
right_keypoint="rightear",
362+
camera_view="top_down",
363+
)
334364

365+
# compute cosine distance between forward vectors
366+
# (cosine distance = 1 - dot product of the unit vectors)
367+
cosine_distance_head_vectors = compute_pairwise_distances(
368+
ds.head_vector,
369+
dim="individuals",
370+
pairs={
371+
"individual1": "individual2",
372+
},
373+
metric="cosine",
374+
)
375+
376+
# plot histogram
377+
# most of the time the vectors are antiparallel?
378+
fig, ax = plt.subplots() # figsize=(3, 3))
379+
ax.hist(
380+
cosine_distance_head_vectors,
381+
)
382+
ax.set_xlabel("cosine distance")
383+
ax.set_ylabel("frames")
335384
# %%
336-
# Try a different metric

0 commit comments

Comments
 (0)