File tree Expand file tree Collapse file tree 1 file changed +8
-4
lines changed Expand file tree Collapse file tree 1 file changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -148,8 +148,10 @@ class XavierFiller : public Filler<Dtype> {
148148 virtual void Fill (Blob<Dtype>* blob) {
149149 CHECK (blob->count ());
150150 int fan_in = blob->count () / blob->shape (0 );
151- // Compatible for ND Convolution
152- int fan_out = blob->count () / blob->shape (1 );
151+ // Compatibility with ND blobs
152+ int fan_out = blob->num_axes () > 1 ?
153+ blob->count () / blob->shape (1 ) :
154+ blob->count ();
153155 Dtype n = fan_in; // default to fan_in
154156 if (this ->filler_param_ .variance_norm () ==
155157 FillerParameter_VarianceNorm_AVERAGE) {
@@ -191,8 +193,10 @@ class MSRAFiller : public Filler<Dtype> {
191193 virtual void Fill (Blob<Dtype>* blob) {
192194 CHECK (blob->count ());
193195 int fan_in = blob->count () / blob->shape (0 );
194- // Compatible for ND Convolution
195- int fan_out = blob->count () / blob->shape (1 );
196+ // Compatibility with ND blobs
197+ int fan_out = blob->num_axes () > 1 ?
198+ blob->count () / blob->shape (1 ) :
199+ blob->count ();
196200 Dtype n = fan_in; // default to fan_in
197201 if (this ->filler_param_ .variance_norm () ==
198202 FillerParameter_VarianceNorm_AVERAGE) {
You can’t perform that action at this time.
0 commit comments