Skip to content

Commit ced33f6

Browse files
authored
Merge pull request #1195 from WardBrian/vectorize-abs-deprecate-fabs
Vectorize abs, deprecate fabs
2 parents 4be341e + 2bb13f0 commit ced33f6

31 files changed

+644
-512
lines changed

src/analysis_and_optimization/Partial_evaluator.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) =
335335
| ( "log"
336336
, [ { pattern=
337337
FunApp
338-
( StanLib ("fabs", FnPlain, mem1)
338+
( StanLib (("fabs" | "abs"), FnPlain, mem1)
339339
, [ { pattern=
340340
FunApp
341341
( StanLib ("determinant", FnPlain, mem2)

src/frontend/Canonicalize.ml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,6 @@ let rec replace_deprecated_expr
3535
let expr =
3636
match expr with
3737
| GetLP -> GetTarget
38-
| FunApp (StanLib FnPlain, {name= "abs"; id_loc}, [e])
39-
when Middle.UnsizedType.is_real_type e.emeta.type_ ->
40-
FunApp
41-
( StanLib FnPlain
42-
, {name= "fabs"; id_loc}
43-
, [replace_deprecated_expr deprecated_userdefined e] )
4438
| FunApp (StanLib FnPlain, {name= "if_else"; _}, [c; t; e]) ->
4539
Paren
4640
(replace_deprecated_expr deprecated_userdefined

src/frontend/Deprecation_analysis.ml

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ let deprecated_functions =
66
String.Map.of_alist_exn
77
[ ("multiply_log", ("lmultiply", "2.32.0"))
88
; ("binomial_coefficient_log", ("lchoose", "2.32.0"))
9-
; ("cov_exp_quad", ("gp_exp_quad_cov", "2.32.0")) ]
9+
; ("cov_exp_quad", ("gp_exp_quad_cov", "2.32.0"))
10+
; ("fabs", ("abs", "2.33.0")) ]
1011

1112
let deprecated_odes =
1213
String.Map.of_alist_exn
@@ -26,7 +27,7 @@ let deprecated_distributions =
2627
| Lpmf -> Some (name ^ "_log", name ^ "_lpmf")
2728
| Cdf -> Some (name ^ "_cdf_log", name ^ "_lcdf")
2829
| Ccdf -> Some (name ^ "_ccdf_log", name ^ "_lccdf")
29-
| Rng | Log | UnaryVectorized -> None ) ) ) )
30+
| Rng | Log | UnaryVectorized _ -> None ) ) ) )
3031

3132
let stan_lib_deprecations =
3233
Map.merge_skewed deprecated_distributions deprecated_functions
@@ -100,14 +101,6 @@ let rec collect_deprecated_expr (acc : (Location_span.t * string) list)
100101
({expr; emeta} : (typed_expr_meta, fun_kind) expr_with) :
101102
(Location_span.t * string) list =
102103
match expr with
103-
| FunApp (StanLib FnPlain, {name= "abs"; _}, [e])
104-
when Middle.UnsizedType.is_real_type e.emeta.type_ ->
105-
collect_deprecated_expr
106-
( acc
107-
@ [ ( emeta.loc
108-
, "Use of the `abs` function with real-valued arguments is \
109-
deprecated; use function `fabs` instead." ) ] )
110-
e
111104
| FunApp (StanLib FnPlain, {name= "if_else"; _}, l) ->
112105
acc
113106
@ [ ( emeta.loc

src/middle/Stan_math_signatures.ml

Lines changed: 82 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,19 @@ type dimensionality =
3737
let rec bare_array_type (t, i) =
3838
match i with 0 -> t | j -> UnsizedType.UArray (bare_array_type (t, j - 1))
3939

40+
let bare_types =
41+
[ UnsizedType.UInt; UReal; UComplex; UVector; URowVector; UMatrix
42+
; UComplexVector; UComplexRowVector; UComplexMatrix ]
43+
44+
let vector_types = [UnsizedType.UReal; UArray UReal; UVector; URowVector]
45+
let primitive_types = [UnsizedType.UInt; UReal]
46+
47+
let complex_types =
48+
[UnsizedType.UComplex; UComplexVector; UComplexRowVector; UComplexMatrix]
49+
50+
let all_vector_types =
51+
[UnsizedType.UReal; UArray UReal; UVector; URowVector; UInt; UArray UInt]
52+
4053
let rec expand_arg = function
4154
| DInt -> [UnsizedType.UInt]
4255
| DReal -> [UReal]
@@ -57,21 +70,21 @@ let rec expand_arg = function
5770
concat_map all_base ~f:(fun a ->
5871
map (range 0 8) ~f:(fun i -> bare_array_type (a, i)) ))
5972
| DDeepComplexVectorized ->
60-
let all_base =
61-
[UnsizedType.UComplex; UComplexRowVector; UComplexVector; UComplexMatrix]
62-
in
6373
List.(
64-
concat_map all_base ~f:(fun a ->
74+
concat_map complex_types ~f:(fun a ->
6575
map (range 0 8) ~f:(fun i -> bare_array_type (a, i)) ))
6676

77+
type return_behavior = SameAsArg | IntsToReals | ComplexToReals
78+
[@@deriving show {with_path= false}]
79+
6780
type fkind =
6881
| Lpmf
6982
| Lpdf
70-
| Log [@printer fun fmt _ -> fprintf fmt "log (deprecated)"]
83+
| Log [@printer fun fmt _ -> fprintf fmt "Log (deprecated)"]
7184
| Rng
7285
| Cdf
7386
| Ccdf
74-
| UnaryVectorized
87+
| UnaryVectorized of return_behavior
7588
[@@deriving show {with_path= false}]
7689

7790
type fun_arg = UnsizedType.autodifftype * UnsizedType.t
@@ -199,7 +212,7 @@ let mk_declarative_sig (fnkinds, name, args, mem_pattern) =
199212
| Rng -> ["_rng"]
200213
| Cdf -> ["_cdf"; "_cdf_log"; "_lcdf"]
201214
| Ccdf -> ["_ccdf_log"; "_lccdf"]
202-
| UnaryVectorized -> [""] in
215+
| UnaryVectorized _ -> [""] in
203216
let add_ints = function DVReal -> DIntAndReals | x -> x in
204217
let all_expanded args = all_combinations (List.map ~f:expand_arg args) in
205218
let promoted_dim = function
@@ -208,7 +221,11 @@ let mk_declarative_sig (fnkinds, name, args, mem_pattern) =
208221
| _ -> UReal in
209222
let find_rt rt args = function
210223
| Rng -> UnsizedType.ReturnType (rng_return_type rt args)
211-
| UnaryVectorized -> ReturnType (ints_to_real (List.hd_exn args))
224+
| UnaryVectorized SameAsArg -> ReturnType (List.hd_exn args)
225+
| UnaryVectorized IntsToReals ->
226+
ReturnType (ints_to_real (List.hd_exn args))
227+
| UnaryVectorized ComplexToReals ->
228+
ReturnType (complex_to_real (List.hd_exn args))
212229
| _ -> ReturnType UReal in
213230
let create_from_fk_args fk arglists =
214231
List.concat_map arglists ~f:(fun args ->
@@ -222,7 +239,6 @@ let mk_declarative_sig (fnkinds, name, args, mem_pattern) =
222239
let name = name ^ "_rng" in
223240
List.map (all_expanded args) ~f:(fun args ->
224241
(name, find_rt rt args Rng, args, mem_pattern) )
225-
| UnaryVectorized -> create_from_fk_args UnaryVectorized (all_expanded args)
226242
| fk -> create_from_fk_args fk (all_expanded args) in
227243
List.concat_map fnkinds ~f:add_fnkind
228244
|> List.filter ~f:(fun (n, _, _, _) -> not (Set.mem missing_math_functions n))
@@ -344,59 +360,64 @@ let distributions =
344360
; ([Lpdf], "wishart_cholesky", [DMatrix; DReal; DMatrix], SoA)
345361
; ([Lpdf; Log], "wishart", [DMatrix; DReal; DMatrix], SoA) ]
346362

363+
let basic_vectorized = UnaryVectorized IntsToReals
364+
347365
let math_sigs =
348-
[ ([UnaryVectorized], "acos", [DDeepVectorized], Common.Helpers.SoA)
349-
; ([UnaryVectorized], "acosh", [DDeepVectorized], SoA)
350-
; ([UnaryVectorized], "asin", [DDeepVectorized], SoA)
351-
; ([UnaryVectorized], "asinh", [DDeepVectorized], SoA)
352-
; ([UnaryVectorized], "atan", [DDeepVectorized], SoA)
353-
; ([UnaryVectorized], "atanh", [DDeepVectorized], SoA)
354-
; ([UnaryVectorized], "cbrt", [DDeepVectorized], SoA)
355-
; ([UnaryVectorized], "ceil", [DDeepVectorized], SoA)
356-
; ([UnaryVectorized], "cos", [DDeepVectorized], SoA)
357-
; ([UnaryVectorized], "cosh", [DDeepVectorized], SoA)
358-
; ([UnaryVectorized], "digamma", [DDeepVectorized], SoA)
359-
; ([UnaryVectorized], "erf", [DDeepVectorized], SoA)
360-
; ([UnaryVectorized], "erfc", [DDeepVectorized], SoA)
361-
; ([UnaryVectorized], "exp", [DDeepVectorized], SoA)
362-
; ([UnaryVectorized], "exp2", [DDeepVectorized], SoA)
363-
; ([UnaryVectorized], "expm1", [DDeepVectorized], SoA)
364-
; ([UnaryVectorized], "fabs", [DDeepVectorized], SoA)
365-
; ([UnaryVectorized], "floor", [DDeepVectorized], SoA)
366-
; ([UnaryVectorized], "inv", [DDeepVectorized], SoA)
367-
; ([UnaryVectorized], "inv_cloglog", [DDeepVectorized], SoA)
368-
; ([UnaryVectorized], "inv_erfc", [DDeepVectorized], SoA)
369-
; ([UnaryVectorized], "inv_logit", [DDeepVectorized], SoA)
370-
; ([UnaryVectorized], "inv_Phi", [DDeepVectorized], SoA)
371-
; ([UnaryVectorized], "inv_sqrt", [DDeepVectorized], SoA)
372-
; ([UnaryVectorized], "inv_square", [DDeepVectorized], SoA)
373-
; ([UnaryVectorized], "lambert_w0", [DDeepVectorized], SoA)
374-
; ([UnaryVectorized], "lambert_wm1", [DDeepVectorized], SoA)
375-
; ([UnaryVectorized], "lgamma", [DDeepVectorized], SoA)
376-
; ([UnaryVectorized], "log", [DDeepVectorized], SoA)
377-
; ([UnaryVectorized], "log10", [DDeepVectorized], SoA)
378-
; ([UnaryVectorized], "log1m", [DDeepVectorized], SoA)
379-
; ([UnaryVectorized], "log1m_exp", [DDeepVectorized], SoA)
380-
; ([UnaryVectorized], "log1m_inv_logit", [DDeepVectorized], SoA)
381-
; ([UnaryVectorized], "log1p", [DDeepVectorized], SoA)
382-
; ([UnaryVectorized], "log1p_exp", [DDeepVectorized], SoA)
383-
; ([UnaryVectorized], "log2", [DDeepVectorized], SoA)
384-
; ([UnaryVectorized], "log_inv_logit", [DDeepVectorized], SoA)
385-
; ([UnaryVectorized], "logit", [DDeepVectorized], SoA)
386-
; ([UnaryVectorized], "Phi", [DDeepVectorized], SoA)
387-
; ([UnaryVectorized], "Phi_approx", [DDeepVectorized], SoA)
388-
; ([UnaryVectorized], "round", [DDeepVectorized], SoA)
389-
; ([UnaryVectorized], "sin", [DDeepVectorized], SoA)
390-
; ([UnaryVectorized], "sinh", [DDeepVectorized], SoA)
391-
; ([UnaryVectorized], "sqrt", [DDeepVectorized], SoA)
392-
; ([UnaryVectorized], "square", [DDeepVectorized], SoA)
393-
; ([UnaryVectorized], "step", [DReal], SoA)
394-
; ([UnaryVectorized], "tan", [DDeepVectorized], SoA)
395-
; ([UnaryVectorized], "tanh", [DDeepVectorized], SoA)
396-
(* ; add_nullary ("target") *)
397-
; ([UnaryVectorized], "tgamma", [DDeepVectorized], SoA)
398-
; ([UnaryVectorized], "trunc", [DDeepVectorized], SoA)
399-
; ([UnaryVectorized], "trigamma", [DDeepVectorized], SoA) ]
366+
[ ([basic_vectorized], "acos", [DDeepVectorized], Common.Helpers.SoA)
367+
; ([basic_vectorized], "acosh", [DDeepVectorized], SoA)
368+
; ([basic_vectorized], "asin", [DDeepVectorized], SoA)
369+
; ([basic_vectorized], "asinh", [DDeepVectorized], SoA)
370+
; ([basic_vectorized], "atan", [DDeepVectorized], SoA)
371+
; ([basic_vectorized], "atanh", [DDeepVectorized], SoA)
372+
; ([basic_vectorized], "cbrt", [DDeepVectorized], SoA)
373+
; ([basic_vectorized], "ceil", [DDeepVectorized], SoA)
374+
; ([basic_vectorized], "cos", [DDeepVectorized], SoA)
375+
; ([basic_vectorized], "cosh", [DDeepVectorized], SoA)
376+
; ([basic_vectorized], "digamma", [DDeepVectorized], SoA)
377+
; ([basic_vectorized], "erf", [DDeepVectorized], SoA)
378+
; ([basic_vectorized], "erfc", [DDeepVectorized], SoA)
379+
; ([basic_vectorized], "exp", [DDeepVectorized], SoA)
380+
; ([basic_vectorized], "exp2", [DDeepVectorized], SoA)
381+
; ([basic_vectorized], "expm1", [DDeepVectorized], SoA)
382+
; ([basic_vectorized], "fabs", [DDeepVectorized], SoA)
383+
; ([UnaryVectorized ComplexToReals], "get_imag", [DDeepComplexVectorized], AoS)
384+
; ([UnaryVectorized ComplexToReals], "get_real", [DDeepComplexVectorized], AoS)
385+
; ([UnaryVectorized SameAsArg], "abs", [DDeepVectorized], SoA)
386+
; ([UnaryVectorized ComplexToReals], "abs", [DDeepComplexVectorized], AoS)
387+
; ([basic_vectorized], "floor", [DDeepVectorized], SoA)
388+
; ([basic_vectorized], "inv", [DDeepVectorized], SoA)
389+
; ([basic_vectorized], "inv_cloglog", [DDeepVectorized], SoA)
390+
; ([basic_vectorized], "inv_erfc", [DDeepVectorized], SoA)
391+
; ([basic_vectorized], "inv_logit", [DDeepVectorized], SoA)
392+
; ([basic_vectorized], "inv_Phi", [DDeepVectorized], SoA)
393+
; ([basic_vectorized], "inv_sqrt", [DDeepVectorized], SoA)
394+
; ([basic_vectorized], "inv_square", [DDeepVectorized], SoA)
395+
; ([basic_vectorized], "lambert_w0", [DDeepVectorized], SoA)
396+
; ([basic_vectorized], "lambert_wm1", [DDeepVectorized], SoA)
397+
; ([basic_vectorized], "lgamma", [DDeepVectorized], SoA)
398+
; ([basic_vectorized], "log", [DDeepVectorized], SoA)
399+
; ([basic_vectorized], "log10", [DDeepVectorized], SoA)
400+
; ([basic_vectorized], "log1m", [DDeepVectorized], SoA)
401+
; ([basic_vectorized], "log1m_exp", [DDeepVectorized], SoA)
402+
; ([basic_vectorized], "log1m_inv_logit", [DDeepVectorized], SoA)
403+
; ([basic_vectorized], "log1p", [DDeepVectorized], SoA)
404+
; ([basic_vectorized], "log1p_exp", [DDeepVectorized], SoA)
405+
; ([basic_vectorized], "log2", [DDeepVectorized], SoA)
406+
; ([basic_vectorized], "log_inv_logit", [DDeepVectorized], SoA)
407+
; ([basic_vectorized], "logit", [DDeepVectorized], SoA)
408+
; ([basic_vectorized], "Phi", [DDeepVectorized], SoA)
409+
; ([basic_vectorized], "Phi_approx", [DDeepVectorized], SoA)
410+
; ([basic_vectorized], "round", [DDeepVectorized], SoA)
411+
; ([basic_vectorized], "sin", [DDeepVectorized], SoA)
412+
; ([basic_vectorized], "sinh", [DDeepVectorized], SoA)
413+
; ([basic_vectorized], "sqrt", [DDeepVectorized], SoA)
414+
; ([basic_vectorized], "square", [DDeepVectorized], SoA)
415+
; ([basic_vectorized], "step", [DReal], SoA)
416+
; ([basic_vectorized], "tan", [DDeepVectorized], SoA)
417+
; ([basic_vectorized], "tanh", [DDeepVectorized], SoA)
418+
; ([basic_vectorized], "tgamma", [DDeepVectorized], SoA)
419+
; ([basic_vectorized], "trunc", [DDeepVectorized], SoA)
420+
; ([basic_vectorized], "trigamma", [DDeepVectorized], SoA) ]
400421

401422
let all_declarative_sigs = distributions @ math_sigs
402423

@@ -535,18 +556,6 @@ let pretty_print_math_lib_operator_sigs op =
535556
else operator_to_stan_math_fns op |> List.map ~f:pretty_print_math_sigs
536557

537558
(* -- Some helper definitions to populate stan_math_signatures -- *)
538-
let bare_types =
539-
[ UnsizedType.UInt; UReal; UComplex; UVector; URowVector; UMatrix
540-
; UComplexVector; UComplexRowVector; UComplexMatrix ]
541-
542-
let vector_types = [UnsizedType.UReal; UArray UReal; UVector; URowVector]
543-
let primitive_types = [UnsizedType.UInt; UReal]
544-
545-
let complex_types =
546-
[UnsizedType.UComplex; UComplexVector; UComplexRowVector; UComplexMatrix]
547-
548-
let all_vector_types =
549-
[UnsizedType.UReal; UArray UReal; UVector; URowVector; UInt; UArray UInt]
550559

551560
let add_qualified (name, rt, argts, supports_soa) =
552561
Hashtbl.add_multi stan_math_signatures ~key:name
@@ -863,9 +872,6 @@ let for_vector_types s = List.iter ~f:s vector_types
863872
let () =
864873
List.iter declarative_fnsigs ~f:(fun (key, rt, args, mem_pattern) ->
865874
Hashtbl.add_multi stan_math_signatures ~key ~data:(rt, args, mem_pattern) ) ;
866-
add_unqualified ("abs", ReturnType UInt, [UInt], SoA) ;
867-
add_unqualified ("abs", ReturnType UReal, [UReal], SoA) ;
868-
add_unqualified ("abs", ReturnType UReal, [UComplex], AoS) ;
869875
add_unqualified ("acos", ReturnType UComplex, [UComplex], AoS) ;
870876
add_unqualified ("acosh", ReturnType UComplex, [UComplex], AoS) ;
871877
List.iter
@@ -1292,28 +1298,6 @@ let () =
12921298
, ReturnType UReal
12931299
, [UMatrix; UMatrix; UMatrix; UVector; UMatrix; UVector; UMatrix]
12941300
, AoS ) ;
1295-
List.iter
1296-
~f:(fun i ->
1297-
List.iter
1298-
~f:(fun t ->
1299-
add_unqualified
1300-
( "get_imag"
1301-
, ReturnType (bare_array_type (complex_to_real t, i))
1302-
, [bare_array_type (t, i)]
1303-
, AoS ) )
1304-
complex_types )
1305-
(List.range 0 8) ;
1306-
List.iter
1307-
~f:(fun i ->
1308-
List.iter
1309-
~f:(fun t ->
1310-
add_unqualified
1311-
( "get_real"
1312-
, ReturnType (bare_array_type (complex_to_real t, i))
1313-
, [bare_array_type (t, i)]
1314-
, AoS ) )
1315-
complex_types )
1316-
(List.range 0 8) ;
13171301
add_unqualified
13181302
("gp_dot_prod_cov", ReturnType UMatrix, [UArray UReal; UReal], AoS) ;
13191303
add_unqualified

src/middle/Stan_math_signatures.mli

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,16 @@ val pretty_print_all_math_sigs : unit Fmt.t
2828
val pretty_print_all_math_distributions : unit Fmt.t
2929

3030
type dimensionality
31-
32-
type fkind = Lpmf | Lpdf | Log | Rng | Cdf | Ccdf | UnaryVectorized
31+
type return_behavior
32+
33+
type fkind = private
34+
| Lpmf
35+
| Lpdf
36+
| Log
37+
| Rng
38+
| Cdf
39+
| Ccdf
40+
| UnaryVectorized of return_behavior
3341
[@@deriving show {with_path= false}]
3442

3543
val distributions :

test/integration/cli-args/canonicalize/canonical.expected

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ transformed data {
179179
int a = -12;
180180
real b = 1.5;
181181
int c = abs(a);
182-
real d = fabs(b);
182+
real d = abs(b);
183183
array[0] int x_i;
184184
array[0] real x_r;
185185
matrix[N, N] K = gp_exp_quad_cov(x_quad, 1.0, 1.0);
@@ -287,8 +287,8 @@ model {
287287
*/
288288
parameters {
289289
real<lower=-1, upper=1> x_raw;
290-
real<lower=-(1 - sqrt(1 - square(1 - fabs(x_raw)))),
291-
upper=(1 - sqrt(1 - square(1 - fabs(x_raw))))> y_raw;
290+
real<lower=-(1 - sqrt(1 - square(1 - abs(x_raw)))),
291+
upper=(1 - sqrt(1 - square(1 - abs(x_raw))))> y_raw;
292292
}
293293
transformed parameters {
294294
real<lower=-1, upper=1> x;
@@ -297,6 +297,6 @@ transformed parameters {
297297
y = ((y_raw > 0) ? 1 : -1) - y_raw;
298298
}
299299
model {
300-
target += log1m(sqrt(1 - square(1 - fabs(x_raw))));
300+
target += log1m(sqrt(1 - square(1 - abs(x_raw))));
301301
}
302302

test/integration/cli-args/canonicalize/deprecations-only.expected

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ transformed data {
159159
int a = -12;
160160
real b = 1.5;
161161
int c = abs(a);
162-
real d = fabs(b);
162+
real d = abs(b);
163163
array[0] int x_i;
164164
array[0] real x_r;
165165
matrix[N, N] K = gp_exp_quad_cov(x_quad, 1.0, 1.0);
@@ -266,8 +266,8 @@ model {
266266
*/
267267
parameters {
268268
real<lower=-1, upper=1> x_raw;
269-
real<lower=-(1 - sqrt(1 - square(1 - fabs(x_raw)))),
270-
upper=(1 - sqrt(1 - square(1 - fabs(x_raw))))> y_raw;
269+
real<lower=-(1 - sqrt(1 - square(1 - abs(x_raw)))),
270+
upper=(1 - sqrt(1 - square(1 - abs(x_raw))))> y_raw;
271271
}
272272
transformed parameters {
273273
real<lower=-1, upper=1> x;
@@ -276,6 +276,6 @@ transformed parameters {
276276
y = ((y_raw > 0) ? 1 : -1) - y_raw;
277277
}
278278
model {
279-
target += log1m(sqrt(1 - square(1 - fabs(x_raw))));
279+
target += log1m(sqrt(1 - square(1 - abs(x_raw))));
280280
}
281281

0 commit comments

Comments
 (0)