Skip to content

Commit f561682

Browse files
committed
Merge pull request BVLC#3755 from shelhamer/fix-upgrade-proto
Fix Upgrade Net Tools
2 parents 358b60c + 7eaeb3a commit f561682

4 files changed

+41
-34
lines changed

src/caffe/util/upgrade_proto.cpp

+34-25
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
namespace caffe {
1414

1515
bool NetNeedsUpgrade(const NetParameter& net_param) {
16-
return NetNeedsV0ToV1Upgrade(net_param) || NetNeedsV1ToV2Upgrade(net_param);
16+
return NetNeedsV0ToV1Upgrade(net_param) || NetNeedsV1ToV2Upgrade(net_param)
17+
|| NetNeedsDataUpgrade(net_param) || NetNeedsInputUpgrade(net_param);
1718
}
1819

1920
bool UpgradeNetAsNeeded(const string& param_file, NetParameter* param) {
@@ -655,12 +656,14 @@ void UpgradeNetDataTransformation(NetParameter* net_param) {
655656
}
656657

657658
bool UpgradeV1Net(const NetParameter& v1_net_param, NetParameter* net_param) {
658-
bool is_fully_compatible = true;
659659
if (v1_net_param.layer_size() > 0) {
660-
LOG(ERROR) << "Input NetParameter to be upgraded already specifies 'layer' "
661-
<< "fields; these will be ignored for the upgrade.";
662-
is_fully_compatible = false;
660+
LOG(FATAL) << "Refusing to upgrade inconsistent NetParameter input; "
661+
<< "the definition includes both 'layer' and 'layers' fields. "
662+
<< "The current format defines 'layer' fields with string type like "
663+
<< "layer { type: 'Layer' ... } and not layers { type: LAYER ... }. "
664+
<< "Manually switch the definition to 'layer' format to continue.";
663665
}
666+
bool is_fully_compatible = true;
664667
net_param->CopyFrom(v1_net_param);
665668
net_param->clear_layers();
666669
net_param->clear_layer();
@@ -952,29 +955,35 @@ bool NetNeedsInputUpgrade(const NetParameter& net_param) {
952955
}
953956

954957
void UpgradeNetInput(NetParameter* net_param) {
955-
LayerParameter* layer_param = net_param->add_layer();
956-
layer_param->set_name("input");
957-
layer_param->set_type("Input");
958-
InputParameter* input_param = layer_param->mutable_input_param();
958+
// Collect inputs and convert to Input layer definitions.
959+
// If the NetParameter holds an input alone, without shape/dim, then
960+
// it's a legacy caffemodel and simply stripping the input field is enough.
959961
bool has_shape = net_param->input_shape_size() > 0;
960-
// Convert input fields into a layer.
961-
for (int i = 0; i < net_param->input_size(); ++i) {
962-
layer_param->add_top(net_param->input(i));
963-
if (has_shape) {
964-
input_param->add_shape()->CopyFrom(net_param->input_shape(i));
965-
} else {
966-
// Turn legacy input dimensions into shape.
967-
BlobShape* shape = input_param->add_shape();
968-
int first_dim = i*4;
969-
int last_dim = first_dim + 4;
970-
for (int j = first_dim; j < last_dim; j++) {
971-
shape->add_dim(net_param->input_dim(j));
962+
bool has_dim = net_param->input_dim_size() > 0;
963+
if (has_shape || has_dim) {
964+
LayerParameter* layer_param = net_param->add_layer();
965+
layer_param->set_name("input");
966+
layer_param->set_type("Input");
967+
InputParameter* input_param = layer_param->mutable_input_param();
968+
// Convert input fields into a layer.
969+
for (int i = 0; i < net_param->input_size(); ++i) {
970+
layer_param->add_top(net_param->input(i));
971+
if (has_shape) {
972+
input_param->add_shape()->CopyFrom(net_param->input_shape(i));
973+
} else {
974+
// Turn legacy input dimensions into shape.
975+
BlobShape* shape = input_param->add_shape();
976+
int first_dim = i*4;
977+
int last_dim = first_dim + 4;
978+
for (int j = first_dim; j < last_dim; j++) {
979+
shape->add_dim(net_param->input_dim(j));
980+
}
972981
}
973982
}
974-
}
975-
// Swap input layer to beginning of net to satisfy layer dependencies.
976-
for (int i = net_param->layer_size() - 1; i > 0; --i) {
977-
net_param->mutable_layer(i-1)->Swap(net_param->mutable_layer(i));
983+
// Swap input layer to beginning of net to satisfy layer dependencies.
984+
for (int i = net_param->layer_size() - 1; i > 0; --i) {
985+
net_param->mutable_layer(i-1)->Swap(net_param->mutable_layer(i));
986+
}
978987
}
979988
// Clear inputs.
980989
net_param->clear_input();

tools/upgrade_net_proto_binary.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using std::ofstream;
1616
using namespace caffe; // NOLINT(build/namespaces)
1717

1818
int main(int argc, char** argv) {
19+
FLAGS_alsologtostderr = 1; // Print output to stderr (while still logging)
1920
::google::InitGoogleLogging(argv[0]);
2021
if (argc != 3) {
2122
LOG(ERROR) << "Usage: "
@@ -39,11 +40,11 @@ int main(int argc, char** argv) {
3940
<< "see details above.";
4041
}
4142
} else {
42-
LOG(ERROR) << "File already in V1 proto format: " << argv[1];
43+
LOG(ERROR) << "File already in latest proto format: " << input_filename;
4344
}
4445

4546
WriteProtoToBinaryFile(net_param, argv[2]);
4647

47-
LOG(ERROR) << "Wrote upgraded NetParameter binary proto to " << argv[2];
48+
LOG(INFO) << "Wrote upgraded NetParameter binary proto to " << argv[2];
4849
return !success;
4950
}

tools/upgrade_net_proto_text.cpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using std::ofstream;
1616
using namespace caffe; // NOLINT(build/namespaces)
1717

1818
int main(int argc, char** argv) {
19+
FLAGS_alsologtostderr = 1; // Print output to stderr (while still logging)
1920
::google::InitGoogleLogging(argv[0]);
2021
if (argc != 3) {
2122
LOG(ERROR) << "Usage: "
@@ -31,7 +32,6 @@ int main(int argc, char** argv) {
3132
return 2;
3233
}
3334
bool need_upgrade = NetNeedsUpgrade(net_param);
34-
bool need_data_upgrade = NetNeedsDataUpgrade(net_param);
3535
bool success = true;
3636
if (need_upgrade) {
3737
success = UpgradeNetAsNeeded(input_filename, &net_param);
@@ -43,13 +43,9 @@ int main(int argc, char** argv) {
4343
LOG(ERROR) << "File already in latest proto format: " << input_filename;
4444
}
4545

46-
if (need_data_upgrade) {
47-
UpgradeNetDataTransformation(&net_param);
48-
}
49-
5046
// Save new format prototxt.
5147
WriteProtoToTextFile(net_param, argv[2]);
5248

53-
LOG(ERROR) << "Wrote upgraded NetParameter text proto to " << argv[2];
49+
LOG(INFO) << "Wrote upgraded NetParameter text proto to " << argv[2];
5450
return !success;
5551
}

tools/upgrade_solver_proto_text.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using std::ofstream;
1616
using namespace caffe; // NOLINT(build/namespaces)
1717

1818
int main(int argc, char** argv) {
19+
FLAGS_alsologtostderr = 1; // Print output to stderr (while still logging)
1920
::google::InitGoogleLogging(argv[0]);
2021
if (argc != 3) {
2122
LOG(ERROR) << "Usage: upgrade_solver_proto_text "
@@ -45,6 +46,6 @@ int main(int argc, char** argv) {
4546
// Save new format prototxt.
4647
WriteProtoToTextFile(solver_param, argv[2]);
4748

48-
LOG(ERROR) << "Wrote upgraded SolverParameter text proto to " << argv[2];
49+
LOG(INFO) << "Wrote upgraded SolverParameter text proto to " << argv[2];
4950
return !success;
5051
}

0 commit comments

Comments
 (0)