Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/analysis_and_optimization/Partial_evaluator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) =
| ( "log"
, [ { pattern=
FunApp
( StanLib ("fabs", FnPlain, mem1)
( StanLib (("fabs" | "abs"), FnPlain, mem1)
, [ { pattern=
FunApp
( StanLib ("determinant", FnPlain, mem2)
Expand Down
6 changes: 0 additions & 6 deletions src/frontend/Canonicalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,6 @@ let rec replace_deprecated_expr
let expr =
match expr with
| GetLP -> GetTarget
| FunApp (StanLib FnPlain, {name= "abs"; id_loc}, [e])
when Middle.UnsizedType.is_real_type e.emeta.type_ ->
FunApp
( StanLib FnPlain
, {name= "fabs"; id_loc}
, [replace_deprecated_expr deprecated_userdefined e] )
| FunApp (StanLib FnPlain, {name= "if_else"; _}, [c; t; e]) ->
Paren
(replace_deprecated_expr deprecated_userdefined
Expand Down
13 changes: 3 additions & 10 deletions src/frontend/Deprecation_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ let deprecated_functions =
String.Map.of_alist_exn
[ ("multiply_log", ("lmultiply", "2.32.0"))
; ("binomial_coefficient_log", ("lchoose", "2.32.0"))
; ("cov_exp_quad", ("gp_exp_quad_cov", "2.32.0")) ]
; ("cov_exp_quad", ("gp_exp_quad_cov", "2.32.0"))
; ("fabs", ("abs", "2.33.0")) ]

let deprecated_odes =
String.Map.of_alist_exn
Expand All @@ -26,7 +27,7 @@ let deprecated_distributions =
| Lpmf -> Some (name ^ "_log", name ^ "_lpmf")
| Cdf -> Some (name ^ "_cdf_log", name ^ "_lcdf")
| Ccdf -> Some (name ^ "_ccdf_log", name ^ "_lccdf")
| Rng | Log | UnaryVectorized -> None ) ) ) )
| Rng | Log | UnaryVectorized _ -> None ) ) ) )

let stan_lib_deprecations =
Map.merge_skewed deprecated_distributions deprecated_functions
Expand Down Expand Up @@ -100,14 +101,6 @@ let rec collect_deprecated_expr (acc : (Location_span.t * string) list)
({expr; emeta} : (typed_expr_meta, fun_kind) expr_with) :
(Location_span.t * string) list =
match expr with
| FunApp (StanLib FnPlain, {name= "abs"; _}, [e])
when Middle.UnsizedType.is_real_type e.emeta.type_ ->
collect_deprecated_expr
( acc
@ [ ( emeta.loc
, "Use of the `abs` function with real-valued arguments is \
deprecated; use function `fabs` instead." ) ] )
e
| FunApp (StanLib FnPlain, {name= "if_else"; _}, l) ->
acc
@ [ ( emeta.loc
Expand Down
180 changes: 82 additions & 98 deletions src/middle/Stan_math_signatures.ml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ type dimensionality =
let rec bare_array_type (t, i) =
match i with 0 -> t | j -> UnsizedType.UArray (bare_array_type (t, j - 1))

let bare_types =
[ UnsizedType.UInt; UReal; UComplex; UVector; URowVector; UMatrix
; UComplexVector; UComplexRowVector; UComplexMatrix ]

let vector_types = [UnsizedType.UReal; UArray UReal; UVector; URowVector]
let primitive_types = [UnsizedType.UInt; UReal]

let complex_types =
[UnsizedType.UComplex; UComplexVector; UComplexRowVector; UComplexMatrix]

let all_vector_types =
[UnsizedType.UReal; UArray UReal; UVector; URowVector; UInt; UArray UInt]

let rec expand_arg = function
| DInt -> [UnsizedType.UInt]
| DReal -> [UReal]
Expand All @@ -57,21 +70,21 @@ let rec expand_arg = function
concat_map all_base ~f:(fun a ->
map (range 0 8) ~f:(fun i -> bare_array_type (a, i)) ))
| DDeepComplexVectorized ->
let all_base =
[UnsizedType.UComplex; UComplexRowVector; UComplexVector; UComplexMatrix]
in
List.(
concat_map all_base ~f:(fun a ->
concat_map complex_types ~f:(fun a ->
map (range 0 8) ~f:(fun i -> bare_array_type (a, i)) ))

type return_behavior = SameAsArg | IntsToReals | ComplexToReals
[@@deriving show {with_path= false}]

type fkind =
| Lpmf
| Lpdf
| Log [@printer fun fmt _ -> fprintf fmt "log (deprecated)"]
| Log [@printer fun fmt _ -> fprintf fmt "Log (deprecated)"]
| Rng
| Cdf
| Ccdf
| UnaryVectorized
| UnaryVectorized of return_behavior
[@@deriving show {with_path= false}]

type fun_arg = UnsizedType.autodifftype * UnsizedType.t
Expand Down Expand Up @@ -199,7 +212,7 @@ let mk_declarative_sig (fnkinds, name, args, mem_pattern) =
| Rng -> ["_rng"]
| Cdf -> ["_cdf"; "_cdf_log"; "_lcdf"]
| Ccdf -> ["_ccdf_log"; "_lccdf"]
| UnaryVectorized -> [""] in
| UnaryVectorized _ -> [""] in
let add_ints = function DVReal -> DIntAndReals | x -> x in
let all_expanded args = all_combinations (List.map ~f:expand_arg args) in
let promoted_dim = function
Expand All @@ -208,7 +221,11 @@ let mk_declarative_sig (fnkinds, name, args, mem_pattern) =
| _ -> UReal in
let find_rt rt args = function
| Rng -> UnsizedType.ReturnType (rng_return_type rt args)
| UnaryVectorized -> ReturnType (ints_to_real (List.hd_exn args))
| UnaryVectorized SameAsArg -> ReturnType (List.hd_exn args)
| UnaryVectorized IntsToReals ->
ReturnType (ints_to_real (List.hd_exn args))
| UnaryVectorized ComplexToReals ->
ReturnType (complex_to_real (List.hd_exn args))
| _ -> ReturnType UReal in
let create_from_fk_args fk arglists =
List.concat_map arglists ~f:(fun args ->
Expand All @@ -222,7 +239,6 @@ let mk_declarative_sig (fnkinds, name, args, mem_pattern) =
let name = name ^ "_rng" in
List.map (all_expanded args) ~f:(fun args ->
(name, find_rt rt args Rng, args, mem_pattern) )
| UnaryVectorized -> create_from_fk_args UnaryVectorized (all_expanded args)
| fk -> create_from_fk_args fk (all_expanded args) in
List.concat_map fnkinds ~f:add_fnkind
|> List.filter ~f:(fun (n, _, _, _) -> not (Set.mem missing_math_functions n))
Expand Down Expand Up @@ -341,59 +357,64 @@ let distributions =
; ([Lpdf], "wishart_cholesky", [DMatrix; DReal; DMatrix], SoA)
; ([Lpdf; Log], "wishart", [DMatrix; DReal; DMatrix], SoA) ]

let basic_vectorized = UnaryVectorized IntsToReals

let math_sigs =
[ ([UnaryVectorized], "acos", [DDeepVectorized], Common.Helpers.SoA)
; ([UnaryVectorized], "acosh", [DDeepVectorized], SoA)
; ([UnaryVectorized], "asin", [DDeepVectorized], SoA)
; ([UnaryVectorized], "asinh", [DDeepVectorized], SoA)
; ([UnaryVectorized], "atan", [DDeepVectorized], SoA)
; ([UnaryVectorized], "atanh", [DDeepVectorized], SoA)
; ([UnaryVectorized], "cbrt", [DDeepVectorized], SoA)
; ([UnaryVectorized], "ceil", [DDeepVectorized], SoA)
; ([UnaryVectorized], "cos", [DDeepVectorized], SoA)
; ([UnaryVectorized], "cosh", [DDeepVectorized], SoA)
; ([UnaryVectorized], "digamma", [DDeepVectorized], SoA)
; ([UnaryVectorized], "erf", [DDeepVectorized], SoA)
; ([UnaryVectorized], "erfc", [DDeepVectorized], SoA)
; ([UnaryVectorized], "exp", [DDeepVectorized], SoA)
; ([UnaryVectorized], "exp2", [DDeepVectorized], SoA)
; ([UnaryVectorized], "expm1", [DDeepVectorized], SoA)
; ([UnaryVectorized], "fabs", [DDeepVectorized], SoA)
; ([UnaryVectorized], "floor", [DDeepVectorized], SoA)
; ([UnaryVectorized], "inv", [DDeepVectorized], SoA)
; ([UnaryVectorized], "inv_cloglog", [DDeepVectorized], SoA)
; ([UnaryVectorized], "inv_erfc", [DDeepVectorized], SoA)
; ([UnaryVectorized], "inv_logit", [DDeepVectorized], SoA)
; ([UnaryVectorized], "inv_Phi", [DDeepVectorized], SoA)
; ([UnaryVectorized], "inv_sqrt", [DDeepVectorized], SoA)
; ([UnaryVectorized], "inv_square", [DDeepVectorized], SoA)
; ([UnaryVectorized], "lambert_w0", [DDeepVectorized], SoA)
; ([UnaryVectorized], "lambert_wm1", [DDeepVectorized], SoA)
; ([UnaryVectorized], "lgamma", [DDeepVectorized], SoA)
; ([UnaryVectorized], "log", [DDeepVectorized], SoA)
; ([UnaryVectorized], "log10", [DDeepVectorized], SoA)
; ([UnaryVectorized], "log1m", [DDeepVectorized], SoA)
; ([UnaryVectorized], "log1m_exp", [DDeepVectorized], SoA)
; ([UnaryVectorized], "log1m_inv_logit", [DDeepVectorized], SoA)
; ([UnaryVectorized], "log1p", [DDeepVectorized], SoA)
; ([UnaryVectorized], "log1p_exp", [DDeepVectorized], SoA)
; ([UnaryVectorized], "log2", [DDeepVectorized], SoA)
; ([UnaryVectorized], "log_inv_logit", [DDeepVectorized], SoA)
; ([UnaryVectorized], "logit", [DDeepVectorized], SoA)
; ([UnaryVectorized], "Phi", [DDeepVectorized], SoA)
; ([UnaryVectorized], "Phi_approx", [DDeepVectorized], SoA)
; ([UnaryVectorized], "round", [DDeepVectorized], SoA)
; ([UnaryVectorized], "sin", [DDeepVectorized], SoA)
; ([UnaryVectorized], "sinh", [DDeepVectorized], SoA)
; ([UnaryVectorized], "sqrt", [DDeepVectorized], SoA)
; ([UnaryVectorized], "square", [DDeepVectorized], SoA)
; ([UnaryVectorized], "step", [DReal], SoA)
; ([UnaryVectorized], "tan", [DDeepVectorized], SoA)
; ([UnaryVectorized], "tanh", [DDeepVectorized], SoA)
(* ; add_nullary ("target") *)
; ([UnaryVectorized], "tgamma", [DDeepVectorized], SoA)
; ([UnaryVectorized], "trunc", [DDeepVectorized], SoA)
; ([UnaryVectorized], "trigamma", [DDeepVectorized], SoA) ]
[ ([basic_vectorized], "acos", [DDeepVectorized], Common.Helpers.SoA)
; ([basic_vectorized], "acosh", [DDeepVectorized], SoA)
; ([basic_vectorized], "asin", [DDeepVectorized], SoA)
; ([basic_vectorized], "asinh", [DDeepVectorized], SoA)
; ([basic_vectorized], "atan", [DDeepVectorized], SoA)
; ([basic_vectorized], "atanh", [DDeepVectorized], SoA)
; ([basic_vectorized], "cbrt", [DDeepVectorized], SoA)
; ([basic_vectorized], "ceil", [DDeepVectorized], SoA)
; ([basic_vectorized], "cos", [DDeepVectorized], SoA)
; ([basic_vectorized], "cosh", [DDeepVectorized], SoA)
; ([basic_vectorized], "digamma", [DDeepVectorized], SoA)
; ([basic_vectorized], "erf", [DDeepVectorized], SoA)
; ([basic_vectorized], "erfc", [DDeepVectorized], SoA)
; ([basic_vectorized], "exp", [DDeepVectorized], SoA)
; ([basic_vectorized], "exp2", [DDeepVectorized], SoA)
; ([basic_vectorized], "expm1", [DDeepVectorized], SoA)
; ([basic_vectorized], "fabs", [DDeepVectorized], SoA)
; ([UnaryVectorized ComplexToReals], "get_imag", [DDeepComplexVectorized], AoS)
; ([UnaryVectorized ComplexToReals], "get_real", [DDeepComplexVectorized], AoS)
; ([UnaryVectorized SameAsArg], "abs", [DDeepVectorized], SoA)
; ([UnaryVectorized ComplexToReals], "abs", [DDeepComplexVectorized], AoS)
; ([basic_vectorized], "floor", [DDeepVectorized], SoA)
; ([basic_vectorized], "inv", [DDeepVectorized], SoA)
; ([basic_vectorized], "inv_cloglog", [DDeepVectorized], SoA)
; ([basic_vectorized], "inv_erfc", [DDeepVectorized], SoA)
; ([basic_vectorized], "inv_logit", [DDeepVectorized], SoA)
; ([basic_vectorized], "inv_Phi", [DDeepVectorized], SoA)
; ([basic_vectorized], "inv_sqrt", [DDeepVectorized], SoA)
; ([basic_vectorized], "inv_square", [DDeepVectorized], SoA)
; ([basic_vectorized], "lambert_w0", [DDeepVectorized], SoA)
; ([basic_vectorized], "lambert_wm1", [DDeepVectorized], SoA)
; ([basic_vectorized], "lgamma", [DDeepVectorized], SoA)
; ([basic_vectorized], "log", [DDeepVectorized], SoA)
; ([basic_vectorized], "log10", [DDeepVectorized], SoA)
; ([basic_vectorized], "log1m", [DDeepVectorized], SoA)
; ([basic_vectorized], "log1m_exp", [DDeepVectorized], SoA)
; ([basic_vectorized], "log1m_inv_logit", [DDeepVectorized], SoA)
; ([basic_vectorized], "log1p", [DDeepVectorized], SoA)
; ([basic_vectorized], "log1p_exp", [DDeepVectorized], SoA)
; ([basic_vectorized], "log2", [DDeepVectorized], SoA)
; ([basic_vectorized], "log_inv_logit", [DDeepVectorized], SoA)
; ([basic_vectorized], "logit", [DDeepVectorized], SoA)
; ([basic_vectorized], "Phi", [DDeepVectorized], SoA)
; ([basic_vectorized], "Phi_approx", [DDeepVectorized], SoA)
; ([basic_vectorized], "round", [DDeepVectorized], SoA)
; ([basic_vectorized], "sin", [DDeepVectorized], SoA)
; ([basic_vectorized], "sinh", [DDeepVectorized], SoA)
; ([basic_vectorized], "sqrt", [DDeepVectorized], SoA)
; ([basic_vectorized], "square", [DDeepVectorized], SoA)
; ([basic_vectorized], "step", [DReal], SoA)
; ([basic_vectorized], "tan", [DDeepVectorized], SoA)
; ([basic_vectorized], "tanh", [DDeepVectorized], SoA)
; ([basic_vectorized], "tgamma", [DDeepVectorized], SoA)
; ([basic_vectorized], "trunc", [DDeepVectorized], SoA)
; ([basic_vectorized], "trigamma", [DDeepVectorized], SoA) ]

let all_declarative_sigs = distributions @ math_sigs

Expand Down Expand Up @@ -532,18 +553,6 @@ let pretty_print_math_lib_operator_sigs op =
else operator_to_stan_math_fns op |> List.map ~f:pretty_print_math_sigs

(* -- Some helper definitions to populate stan_math_signatures -- *)
let bare_types =
[ UnsizedType.UInt; UReal; UComplex; UVector; URowVector; UMatrix
; UComplexVector; UComplexRowVector; UComplexMatrix ]

let vector_types = [UnsizedType.UReal; UArray UReal; UVector; URowVector]
let primitive_types = [UnsizedType.UInt; UReal]

let complex_types =
[UnsizedType.UComplex; UComplexVector; UComplexRowVector; UComplexMatrix]

let all_vector_types =
[UnsizedType.UReal; UArray UReal; UVector; URowVector; UInt; UArray UInt]

let add_qualified (name, rt, argts, supports_soa) =
Hashtbl.add_multi stan_math_signatures ~key:name
Expand Down Expand Up @@ -860,9 +869,6 @@ let for_vector_types s = List.iter ~f:s vector_types
let () =
List.iter declarative_fnsigs ~f:(fun (key, rt, args, mem_pattern) ->
Hashtbl.add_multi stan_math_signatures ~key ~data:(rt, args, mem_pattern) ) ;
add_unqualified ("abs", ReturnType UInt, [UInt], SoA) ;
add_unqualified ("abs", ReturnType UReal, [UReal], SoA) ;
add_unqualified ("abs", ReturnType UReal, [UComplex], AoS) ;
add_unqualified ("acos", ReturnType UComplex, [UComplex], AoS) ;
add_unqualified ("acosh", ReturnType UComplex, [UComplex], AoS) ;
List.iter
Expand Down Expand Up @@ -1289,28 +1295,6 @@ let () =
, ReturnType UReal
, [UMatrix; UMatrix; UMatrix; UVector; UMatrix; UVector; UMatrix]
, AoS ) ;
List.iter
~f:(fun i ->
List.iter
~f:(fun t ->
add_unqualified
( "get_imag"
, ReturnType (bare_array_type (complex_to_real t, i))
, [bare_array_type (t, i)]
, AoS ) )
complex_types )
(List.range 0 8) ;
List.iter
~f:(fun i ->
List.iter
~f:(fun t ->
add_unqualified
( "get_real"
, ReturnType (bare_array_type (complex_to_real t, i))
, [bare_array_type (t, i)]
, AoS ) )
complex_types )
(List.range 0 8) ;
add_unqualified
("gp_dot_prod_cov", ReturnType UMatrix, [UArray UReal; UReal], AoS) ;
add_unqualified
Expand Down
12 changes: 10 additions & 2 deletions src/middle/Stan_math_signatures.mli
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,16 @@ val pretty_print_all_math_sigs : unit Fmt.t
val pretty_print_all_math_distributions : unit Fmt.t

type dimensionality

type fkind = Lpmf | Lpdf | Log | Rng | Cdf | Ccdf | UnaryVectorized
type return_behavior

type fkind = private
| Lpmf
| Lpdf
| Log
| Rng
| Cdf
| Ccdf
| UnaryVectorized of return_behavior
[@@deriving show {with_path= false}]

val distributions :
Expand Down
8 changes: 4 additions & 4 deletions test/integration/cli-args/canonicalize/canonical.expected
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ transformed data {
int a = -12;
real b = 1.5;
int c = abs(a);
real d = fabs(b);
real d = abs(b);
array[0] int x_i;
array[0] real x_r;
matrix[N, N] K = gp_exp_quad_cov(x_quad, 1.0, 1.0);
Expand Down Expand Up @@ -287,8 +287,8 @@ model {
*/
parameters {
real<lower=-1, upper=1> x_raw;
real<lower=-(1 - sqrt(1 - square(1 - fabs(x_raw)))),
upper=(1 - sqrt(1 - square(1 - fabs(x_raw))))> y_raw;
real<lower=-(1 - sqrt(1 - square(1 - abs(x_raw)))),
upper=(1 - sqrt(1 - square(1 - abs(x_raw))))> y_raw;
}
transformed parameters {
real<lower=-1, upper=1> x;
Expand All @@ -297,6 +297,6 @@ transformed parameters {
y = ((y_raw > 0) ? 1 : -1) - y_raw;
}
model {
target += log1m(sqrt(1 - square(1 - fabs(x_raw))));
target += log1m(sqrt(1 - square(1 - abs(x_raw))));
}

Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ transformed data {
int a = -12;
real b = 1.5;
int c = abs(a);
real d = fabs(b);
real d = abs(b);
array[0] int x_i;
array[0] real x_r;
matrix[N, N] K = gp_exp_quad_cov(x_quad, 1.0, 1.0);
Expand Down Expand Up @@ -266,8 +266,8 @@ model {
*/
parameters {
real<lower=-1, upper=1> x_raw;
real<lower=-(1 - sqrt(1 - square(1 - fabs(x_raw)))),
upper=(1 - sqrt(1 - square(1 - fabs(x_raw))))> y_raw;
real<lower=-(1 - sqrt(1 - square(1 - abs(x_raw)))),
upper=(1 - sqrt(1 - square(1 - abs(x_raw))))> y_raw;
}
transformed parameters {
real<lower=-1, upper=1> x;
Expand All @@ -276,6 +276,6 @@ transformed parameters {
y = ((y_raw > 0) ? 1 : -1) - y_raw;
}
model {
target += log1m(sqrt(1 - square(1 - fabs(x_raw))));
target += log1m(sqrt(1 - square(1 - abs(x_raw))));
}

Loading