Skip to content

Commit bc348eb

Browse files
committed
Merge pull request opencv#9963 from dkurt:fix_caffe_shrinker
2 parents 6e4dacc + e1ebc4e commit bc348eb

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

modules/dnn/include/opencv2/dnn/dnn.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,13 +726,17 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
726726
* @param src Path to origin model from Caffe framework contains single
727727
* precision floating point weights (usually has `.caffemodel` extension).
728728
* @param dst Path to destination model with updated weights.
729+
* @param layersTypes Set of layers types which parameters will be converted.
730+
* By default, converts only Convolutional and Fully-Connected layers'
731+
* weights.
729732
*
730733
* @note Shrinked model has no origin float32 weights so it can't be used
731734
* in origin Caffe framework anymore. However the structure of data
732735
* is taken from NVidia's Caffe fork: https://github.com/NVIDIA/caffe.
733736
* So the resulting model may be used there.
734737
*/
735-
CV_EXPORTS_W void shrinkCaffeModel(const String& src, const String& dst);
738+
CV_EXPORTS_W void shrinkCaffeModel(const String& src, const String& dst,
739+
const std::vector<String>& layersTypes = std::vector<String>());
736740

737741
/** @brief Performs non maximum suppression given boxes and corresponding scores.
738742

modules/dnn/src/caffe/caffe_shrinker.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,27 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
1717

1818
#ifdef HAVE_PROTOBUF
1919

20-
void shrinkCaffeModel(const String& src, const String& dst)
20+
void shrinkCaffeModel(const String& src, const String& dst, const std::vector<String>& layersTypes)
2121
{
2222
CV_TRACE_FUNCTION();
2323

24+
std::vector<String> types(layersTypes);
25+
if (types.empty())
26+
{
27+
types.push_back("Convolution");
28+
types.push_back("InnerProduct");
29+
}
30+
2431
caffe::NetParameter net;
2532
ReadNetParamsFromBinaryFileOrDie(src.c_str(), &net);
2633

2734
for (int i = 0; i < net.layer_size(); ++i)
2835
{
2936
caffe::LayerParameter* lp = net.mutable_layer(i);
37+
if (std::find(types.begin(), types.end(), lp->type()) == types.end())
38+
{
39+
continue;
40+
}
3041
for (int j = 0; j < lp->blobs_size(); ++j)
3142
{
3243
caffe::BlobProto* blob = lp->mutable_blobs(j);
@@ -54,7 +65,7 @@ void shrinkCaffeModel(const String& src, const String& dst)
5465

5566
#else
5667

57-
void shrinkCaffeModel(const String& src, const String& dst)
68+
void shrinkCaffeModel(const String& src, const String& dst, const std::vector<String>& types)
5869
{
5970
CV_Error(cv::Error::StsNotImplemented, "libprotobuf required to import data from Caffe models");
6071
}

0 commit comments

Comments
 (0)