@@ -3162,44 +3162,19 @@ def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
3162
3162
return np .allclose (x , ref , rtol = rtol , atol = atol )
3163
3163
3164
3164
3165
- def _skip_mul_1 (r ):
3166
- if r .owner and r .owner .op == mul :
3167
- not_is_1 = [i for i in r .owner .inputs if not _is_1 (i )]
3168
- if len (not_is_1 ) == 1 :
3169
- return not_is_1 [0 ]
3170
-
3171
-
3172
- def _is_1 (expr ):
3173
- """
3174
-
3175
- Returns
3176
- -------
3177
- bool
3178
- True iff expr is a constant close to 1.
3179
-
3180
- """
3181
- try :
3182
- v = get_underlying_scalar_constant_value (expr )
3183
- return isclose (v , 1 )
3184
- except NotScalarConstantError :
3185
- return False
3186
-
3187
-
3188
3165
logsigm_to_softplus = PatternNodeRewriter (
3189
3166
(log , (sigmoid , "x" )),
3190
3167
(neg , (softplus , (neg , "x" ))),
3191
3168
allow_multiple_clients = True ,
3192
3169
values_eq_approx = values_eq_approx_remove_inf ,
3193
- skip_identities_fn = _skip_mul_1 ,
3194
3170
tracks = [sigmoid ],
3195
3171
get_nodes = get_clients_at_depth1 ,
3196
3172
)
3197
3173
log1msigm_to_softplus = PatternNodeRewriter (
3198
- (log , (sub , dict ( pattern = "y" , constraint = _is_1 ) , (sigmoid , "x" ))),
3174
+ (log , (sub , 1 , (sigmoid , "x" ))),
3199
3175
(neg , (softplus , "x" )),
3200
3176
allow_multiple_clients = True ,
3201
3177
values_eq_approx = values_eq_approx_remove_inf ,
3202
- skip_identities_fn = _skip_mul_1 ,
3203
3178
tracks = [sigmoid ],
3204
3179
get_nodes = get_clients_at_depth2 ,
3205
3180
)
@@ -3396,10 +3371,8 @@ def local_exp_over_1_plus_exp(fgraph, node):
3396
3371
3397
3372
if len (denom_rest ) == 0 :
3398
3373
return [new_num ]
3399
- elif len (denom_rest ) == 1 :
3400
- out = new_num / denom_rest [0 ]
3401
3374
else :
3402
- out = new_num / mul (* denom_rest )
3375
+ out = new_num / variadic_mul (* denom_rest )
3403
3376
3404
3377
copy_stack_trace (node .outputs [0 ], out )
3405
3378
return [out ]
@@ -3769,7 +3742,7 @@ def local_reciprocal_1_plus_exp(fgraph, node):
3769
3742
3770
3743
# 1 - sigmoid(x) -> sigmoid(-x)
3771
3744
local_1msigmoid = PatternNodeRewriter (
3772
- (sub , dict ( pattern = "y" , constraint = _is_1 ) , (sigmoid , "x" )),
3745
+ (sub , 1 , (sigmoid , "x" )),
3773
3746
(sigmoid , (neg , "x" )),
3774
3747
tracks = [sigmoid ],
3775
3748
get_nodes = get_clients_at_depth1 ,
0 commit comments