Skip to content

Commit 412f18d

Browse files
committed
1D blob handling in MSRA/Xavier fillers
1 parent 379a3ba commit 412f18d

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

include/caffe/filler.hpp

+8-4
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,10 @@ class XavierFiller : public Filler<Dtype> {
148148
virtual void Fill(Blob<Dtype>* blob) {
149149
CHECK(blob->count());
150150
int fan_in = blob->count() / blob->shape(0);
151-
// Compatible for ND Convolution
152-
int fan_out = blob->count() / blob->shape(1);
151+
// Compatibility with ND blobs
152+
int fan_out = blob->num_axes() > 1 ?
153+
blob->count() / blob->shape(1) :
154+
blob->count();
153155
Dtype n = fan_in; // default to fan_in
154156
if (this->filler_param_.variance_norm() ==
155157
FillerParameter_VarianceNorm_AVERAGE) {
@@ -191,8 +193,10 @@ class MSRAFiller : public Filler<Dtype> {
191193
virtual void Fill(Blob<Dtype>* blob) {
192194
CHECK(blob->count());
193195
int fan_in = blob->count() / blob->shape(0);
194-
// Compatible for ND Convolution
195-
int fan_out = blob->count() / blob->shape(1);
196+
// Compatibility with ND blobs
197+
int fan_out = blob->num_axes() > 1 ?
198+
blob->count() / blob->shape(1) :
199+
blob->count();
196200
Dtype n = fan_in; // default to fan_in
197201
if (this->filler_param_.variance_norm() ==
198202
FillerParameter_VarianceNorm_AVERAGE) {

0 commit comments

Comments
 (0)