-
Notifications
You must be signed in to change notification settings - Fork 125
/
Copy patharcface.cpp
73 lines (63 loc) · 2.09 KB
/
arcface.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
/**
* Copyright (c) 2021-2022 Hailo Technologies Ltd. All rights reserved.
* Distributed under the LGPL license (https://www.gnu.org/licenses/old-licenses/lgpl-2.1.txt)
**/
#include <vector>
#include "common/tensors.hpp"
#include "common/math.hpp"
#include "arcface.hpp"
#include "hailo_tracker.hpp"
#include "hailo_xtensor.hpp"
#include "xtensor/xadapt.hpp"
#include "xtensor/xarray.hpp"
// Name of the output tensor that arcface() reads (via roi->get_tensor) for each
// entry point below. The RGB and NV12 entry points use the same layer name; the
// RGBA entry point reads the "_rgbx" variant of the network.
#define OUTPUT_LAYER_NAME_RGB "arcface_mobilefacenet/fc1"
#define OUTPUT_LAYER_NAME_RGBA "arcface_mobilefacenet_rgbx/fc1"
#define OUTPUT_LAYER_NAME_NV12 "arcface_mobilefacenet/fc1"
// Base tracker name. arcface() appends "_" + the ROI's stream id to it and uses
// the result as the tracker name in HailoTracker calls.
// NOTE(review): mutable global with external linkage; confirm nothing else in
// the link unit defines a symbol with the same name.
std::string tracker_name = "hailo_face_tracker";
/**
 * @brief Decode the embedding tensor of an ROI and publish it as a HailoMatrix.
 *
 * The previous matrix is replaced. If the ROI carries a track id, the old
 * matrix is removed from that track and the new one is added to it. Otherwise
 * the old HAILO_MATRIX objects are removed from the ROI and the new one is
 * attached to the ROI directly.
 *
 * @param roi         ROI whose tensors were produced by the ArcFace network.
 *                    The function returns without doing anything when it has
 *                    no tensors.
 * @param layer_name  Name of the output tensor holding the embedding.
 */
void arcface(HailoROIPtr roi, std::string layer_name)
{
    // Nothing to decode for this ROI.
    if (!roi->has_tensors())
        return;

    // Trackers are addressed per stream: "<tracker_name>_<stream id>".
    const std::string tracker_key = tracker_name + "_" + roi->get_stream_id();
    const auto track_ids = hailo_common::get_hailo_track_id(roi);
    const bool is_tracked = !track_ids.empty();

    // Drop the matrix left over from a previous run before producing a new one.
    if (is_tracked)
        HailoTracker::GetInstance().remove_matrices_from_track(tracker_key, track_ids[0]->get_id());
    else
        roi->remove_objects_typed(HAILO_MATRIX);

    // Tensor -> float xarray -> normalized vector (common::vector_normalization) -> HailoMatrix.
    auto raw_tensor = roi->get_tensor(layer_name);
    xt::xarray<float> raw_embedding = common::get_xtensor_float(raw_tensor);
    auto unit_embedding = common::vector_normalization(raw_embedding);
    HailoMatrixPtr embedding_matrix = hailo_common::create_matrix_ptr(unit_embedding);

    // Attach the new matrix where the old one was removed from.
    if (is_tracked)
        HailoTracker::GetInstance().add_object_to_track(tracker_key,
                                                        track_ids[0]->get_id(),
                                                        embedding_matrix);
    else
        roi->add_object(embedding_matrix);
}
/** @brief ArcFace post-process for the RGB network; reads OUTPUT_LAYER_NAME_RGB. */
void arcface_rgb(HailoROIPtr roi)
{
    const std::string output_layer(OUTPUT_LAYER_NAME_RGB);
    arcface(roi, output_layer);
}
/** @brief ArcFace post-process for the RGBA network; reads OUTPUT_LAYER_NAME_RGBA. */
void arcface_rgba(HailoROIPtr roi)
{
    const std::string output_layer(OUTPUT_LAYER_NAME_RGBA);
    arcface(roi, output_layer);
}
/** @brief ArcFace post-process for the NV12 network; reads OUTPUT_LAYER_NAME_NV12. */
void arcface_nv12(HailoROIPtr roi)
{
    const std::string output_layer(OUTPUT_LAYER_NAME_NV12);
    arcface(roi, output_layer);
}
/**
 * @brief Default post-process entry point.
 *
 * Behaves exactly like arcface_rgb(): it decodes OUTPUT_LAYER_NAME_RGB.
 */
void filter(HailoROIPtr roi)
{
    arcface_rgb(roi);
}