File tree 1 file changed +8
-4
lines changed
1 file changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -148,8 +148,10 @@ class XavierFiller : public Filler<Dtype> {
148
148
virtual void Fill (Blob<Dtype>* blob) {
149
149
CHECK (blob->count ());
150
150
int fan_in = blob->count () / blob->shape (0 );
151
- // Compatible for ND Convolution
152
- int fan_out = blob->count () / blob->shape (1 );
151
+ // Compatibility with ND blobs
152
+ int fan_out = blob->num_axes () > 1 ?
153
+ blob->count () / blob->shape (1 ) :
154
+ blob->count ();
153
155
Dtype n = fan_in; // default to fan_in
154
156
if (this ->filler_param_ .variance_norm () ==
155
157
FillerParameter_VarianceNorm_AVERAGE) {
@@ -191,8 +193,10 @@ class MSRAFiller : public Filler<Dtype> {
191
193
virtual void Fill (Blob<Dtype>* blob) {
192
194
CHECK (blob->count ());
193
195
int fan_in = blob->count () / blob->shape (0 );
194
- // Compatible for ND Convolution
195
- int fan_out = blob->count () / blob->shape (1 );
196
+ // Compatibility with ND blobs
197
+ int fan_out = blob->num_axes () > 1 ?
198
+ blob->count () / blob->shape (1 ) :
199
+ blob->count ();
196
200
Dtype n = fan_in; // default to fan_in
197
201
if (this ->filler_param_ .variance_norm () ==
198
202
FillerParameter_VarianceNorm_AVERAGE) {
You can’t perform that action at this time.
0 commit comments