File tree 2 files changed +18
-3
lines changed
2 files changed +18
-3
lines changed Original file line number Diff line number Diff line change @@ -726,13 +726,17 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
726
726
* @param src Path to origin model from Caffe framework contains single
727
727
* precision floating point weights (usually has `.caffemodel` extension).
728
728
* @param dst Path to destination model with updated weights.
729
+ * @param layersTypes Set of layers types which parameters will be converted.
730
+ * By default, converts only Convolutional and Fully-Connected layers'
731
+ * weights.
729
732
*
730
733
* @note Shrinked model has no origin float32 weights so it can't be used
731
734
* in origin Caffe framework anymore. However the structure of data
732
735
* is taken from NVidia's Caffe fork: https://github.com/NVIDIA/caffe.
733
736
* So the resulting model may be used there.
734
737
*/
735
- CV_EXPORTS_W void shrinkCaffeModel (const String& src, const String& dst);
738
+ CV_EXPORTS_W void shrinkCaffeModel (const String& src, const String& dst,
739
+ const std::vector<String>& layersTypes = std::vector<String>());
736
740
737
741
/* * @brief Performs non maximum suppression given boxes and corresponding scores.
738
742
Original file line number Diff line number Diff line change @@ -17,16 +17,27 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
17
17
18
18
#ifdef HAVE_PROTOBUF
19
19
20
- void shrinkCaffeModel (const String& src, const String& dst)
20
+ void shrinkCaffeModel (const String& src, const String& dst, const std::vector<String>& layersTypes )
21
21
{
22
22
CV_TRACE_FUNCTION ();
23
23
24
+ std::vector<String> types (layersTypes);
25
+ if (types.empty ())
26
+ {
27
+ types.push_back (" Convolution" );
28
+ types.push_back (" InnerProduct" );
29
+ }
30
+
24
31
caffe::NetParameter net;
25
32
ReadNetParamsFromBinaryFileOrDie (src.c_str (), &net);
26
33
27
34
for (int i = 0 ; i < net.layer_size (); ++i)
28
35
{
29
36
caffe::LayerParameter* lp = net.mutable_layer (i);
37
+ if (std::find (types.begin (), types.end (), lp->type ()) == types.end ())
38
+ {
39
+ continue ;
40
+ }
30
41
for (int j = 0 ; j < lp->blobs_size (); ++j)
31
42
{
32
43
caffe::BlobProto* blob = lp->mutable_blobs (j);
@@ -54,7 +65,7 @@ void shrinkCaffeModel(const String& src, const String& dst)
54
65
55
66
#else
56
67
57
- void shrinkCaffeModel (const String& src, const String& dst)
68
+ void shrinkCaffeModel (const String& src, const String& dst, const std::vector<String>& types )
58
69
{
59
70
CV_Error (cv::Error::StsNotImplemented, " libprotobuf required to import data from Caffe models" );
60
71
}
You can’t perform that action at this time.
0 commit comments