@@ -37,6 +37,19 @@ type dimensionality =
37
37
let rec bare_array_type (t , i ) =
38
38
match i with 0 -> t | j -> UnsizedType. UArray (bare_array_type (t, j - 1 ))
39
39
40
+ let bare_types =
41
+ [ UnsizedType. UInt ; UReal ; UComplex ; UVector ; URowVector ; UMatrix
42
+ ; UComplexVector ; UComplexRowVector ; UComplexMatrix ]
43
+
44
+ let vector_types = [UnsizedType. UReal ; UArray UReal ; UVector ; URowVector ]
45
+ let primitive_types = [UnsizedType. UInt ; UReal ]
46
+
47
+ let complex_types =
48
+ [UnsizedType. UComplex ; UComplexVector ; UComplexRowVector ; UComplexMatrix ]
49
+
50
+ let all_vector_types =
51
+ [UnsizedType. UReal ; UArray UReal ; UVector ; URowVector ; UInt ; UArray UInt ]
52
+
40
53
let rec expand_arg = function
41
54
| DInt -> [UnsizedType. UInt ]
42
55
| DReal -> [UReal ]
@@ -57,21 +70,21 @@ let rec expand_arg = function
57
70
concat_map all_base ~f: (fun a ->
58
71
map (range 0 8 ) ~f: (fun i -> bare_array_type (a, i)) ))
59
72
| DDeepComplexVectorized ->
60
- let all_base =
61
- [UnsizedType. UComplex ; UComplexRowVector ; UComplexVector ; UComplexMatrix ]
62
- in
63
73
List. (
64
- concat_map all_base ~f: (fun a ->
74
+ concat_map complex_types ~f: (fun a ->
65
75
map (range 0 8 ) ~f: (fun i -> bare_array_type (a, i)) ))
66
76
77
+ type return_behavior = SameAsArg | IntsToReals | ComplexToReals
78
+ [@@ deriving show {with_path= false }]
79
+
67
80
type fkind =
68
81
| Lpmf
69
82
| Lpdf
70
- | Log [@ printer fun fmt _ -> fprintf fmt " log (deprecated)" ]
83
+ | Log [@ printer fun fmt _ -> fprintf fmt " Log (deprecated)" ]
71
84
| Rng
72
85
| Cdf
73
86
| Ccdf
74
- | UnaryVectorized
87
+ | UnaryVectorized of return_behavior
75
88
[@@ deriving show {with_path= false }]
76
89
77
90
type fun_arg = UnsizedType .autodifftype * UnsizedType .t
@@ -199,7 +212,7 @@ let mk_declarative_sig (fnkinds, name, args, mem_pattern) =
199
212
| Rng -> [" _rng" ]
200
213
| Cdf -> [" _cdf" ; " _cdf_log" ; " _lcdf" ]
201
214
| Ccdf -> [" _ccdf_log" ; " _lccdf" ]
202
- | UnaryVectorized -> [" " ] in
215
+ | UnaryVectorized _ -> [" " ] in
203
216
let add_ints = function DVReal -> DIntAndReals | x -> x in
204
217
let all_expanded args = all_combinations (List. map ~f: expand_arg args) in
205
218
let promoted_dim = function
@@ -208,7 +221,11 @@ let mk_declarative_sig (fnkinds, name, args, mem_pattern) =
208
221
| _ -> UReal in
209
222
let find_rt rt args = function
210
223
| Rng -> UnsizedType. ReturnType (rng_return_type rt args)
211
- | UnaryVectorized -> ReturnType (ints_to_real (List. hd_exn args))
224
+ | UnaryVectorized SameAsArg -> ReturnType (List. hd_exn args)
225
+ | UnaryVectorized IntsToReals ->
226
+ ReturnType (ints_to_real (List. hd_exn args))
227
+ | UnaryVectorized ComplexToReals ->
228
+ ReturnType (complex_to_real (List. hd_exn args))
212
229
| _ -> ReturnType UReal in
213
230
let create_from_fk_args fk arglists =
214
231
List. concat_map arglists ~f: (fun args ->
@@ -222,7 +239,6 @@ let mk_declarative_sig (fnkinds, name, args, mem_pattern) =
222
239
let name = name ^ " _rng" in
223
240
List. map (all_expanded args) ~f: (fun args ->
224
241
(name, find_rt rt args Rng , args, mem_pattern) )
225
- | UnaryVectorized -> create_from_fk_args UnaryVectorized (all_expanded args)
226
242
| fk -> create_from_fk_args fk (all_expanded args) in
227
243
List. concat_map fnkinds ~f: add_fnkind
228
244
|> List. filter ~f: (fun (n , _ , _ , _ ) -> not (Set. mem missing_math_functions n))
@@ -344,59 +360,64 @@ let distributions =
344
360
; ([Lpdf ], " wishart_cholesky" , [DMatrix ; DReal ; DMatrix ], SoA )
345
361
; ([Lpdf ; Log ], " wishart" , [DMatrix ; DReal ; DMatrix ], SoA ) ]
346
362
363
+ let basic_vectorized = UnaryVectorized IntsToReals
364
+
347
365
let math_sigs =
348
- [ ([UnaryVectorized ], " acos" , [DDeepVectorized ], Common.Helpers. SoA )
349
- ; ([UnaryVectorized ], " acosh" , [DDeepVectorized ], SoA )
350
- ; ([UnaryVectorized ], " asin" , [DDeepVectorized ], SoA )
351
- ; ([UnaryVectorized ], " asinh" , [DDeepVectorized ], SoA )
352
- ; ([UnaryVectorized ], " atan" , [DDeepVectorized ], SoA )
353
- ; ([UnaryVectorized ], " atanh" , [DDeepVectorized ], SoA )
354
- ; ([UnaryVectorized ], " cbrt" , [DDeepVectorized ], SoA )
355
- ; ([UnaryVectorized ], " ceil" , [DDeepVectorized ], SoA )
356
- ; ([UnaryVectorized ], " cos" , [DDeepVectorized ], SoA )
357
- ; ([UnaryVectorized ], " cosh" , [DDeepVectorized ], SoA )
358
- ; ([UnaryVectorized ], " digamma" , [DDeepVectorized ], SoA )
359
- ; ([UnaryVectorized ], " erf" , [DDeepVectorized ], SoA )
360
- ; ([UnaryVectorized ], " erfc" , [DDeepVectorized ], SoA )
361
- ; ([UnaryVectorized ], " exp" , [DDeepVectorized ], SoA )
362
- ; ([UnaryVectorized ], " exp2" , [DDeepVectorized ], SoA )
363
- ; ([UnaryVectorized ], " expm1" , [DDeepVectorized ], SoA )
364
- ; ([UnaryVectorized ], " fabs" , [DDeepVectorized ], SoA )
365
- ; ([UnaryVectorized ], " floor" , [DDeepVectorized ], SoA )
366
- ; ([UnaryVectorized ], " inv" , [DDeepVectorized ], SoA )
367
- ; ([UnaryVectorized ], " inv_cloglog" , [DDeepVectorized ], SoA )
368
- ; ([UnaryVectorized ], " inv_erfc" , [DDeepVectorized ], SoA )
369
- ; ([UnaryVectorized ], " inv_logit" , [DDeepVectorized ], SoA )
370
- ; ([UnaryVectorized ], " inv_Phi" , [DDeepVectorized ], SoA )
371
- ; ([UnaryVectorized ], " inv_sqrt" , [DDeepVectorized ], SoA )
372
- ; ([UnaryVectorized ], " inv_square" , [DDeepVectorized ], SoA )
373
- ; ([UnaryVectorized ], " lambert_w0" , [DDeepVectorized ], SoA )
374
- ; ([UnaryVectorized ], " lambert_wm1" , [DDeepVectorized ], SoA )
375
- ; ([UnaryVectorized ], " lgamma" , [DDeepVectorized ], SoA )
376
- ; ([UnaryVectorized ], " log" , [DDeepVectorized ], SoA )
377
- ; ([UnaryVectorized ], " log10" , [DDeepVectorized ], SoA )
378
- ; ([UnaryVectorized ], " log1m" , [DDeepVectorized ], SoA )
379
- ; ([UnaryVectorized ], " log1m_exp" , [DDeepVectorized ], SoA )
380
- ; ([UnaryVectorized ], " log1m_inv_logit" , [DDeepVectorized ], SoA )
381
- ; ([UnaryVectorized ], " log1p" , [DDeepVectorized ], SoA )
382
- ; ([UnaryVectorized ], " log1p_exp" , [DDeepVectorized ], SoA )
383
- ; ([UnaryVectorized ], " log2" , [DDeepVectorized ], SoA )
384
- ; ([UnaryVectorized ], " log_inv_logit" , [DDeepVectorized ], SoA )
385
- ; ([UnaryVectorized ], " logit" , [DDeepVectorized ], SoA )
386
- ; ([UnaryVectorized ], " Phi" , [DDeepVectorized ], SoA )
387
- ; ([UnaryVectorized ], " Phi_approx" , [DDeepVectorized ], SoA )
388
- ; ([UnaryVectorized ], " round" , [DDeepVectorized ], SoA )
389
- ; ([UnaryVectorized ], " sin" , [DDeepVectorized ], SoA )
390
- ; ([UnaryVectorized ], " sinh" , [DDeepVectorized ], SoA )
391
- ; ([UnaryVectorized ], " sqrt" , [DDeepVectorized ], SoA )
392
- ; ([UnaryVectorized ], " square" , [DDeepVectorized ], SoA )
393
- ; ([UnaryVectorized ], " step" , [DReal ], SoA )
394
- ; ([UnaryVectorized ], " tan" , [DDeepVectorized ], SoA )
395
- ; ([UnaryVectorized ], " tanh" , [DDeepVectorized ], SoA )
396
- (* ; add_nullary ("target") *)
397
- ; ([UnaryVectorized ], " tgamma" , [DDeepVectorized ], SoA )
398
- ; ([UnaryVectorized ], " trunc" , [DDeepVectorized ], SoA )
399
- ; ([UnaryVectorized ], " trigamma" , [DDeepVectorized ], SoA ) ]
366
+ [ ([basic_vectorized], " acos" , [DDeepVectorized ], Common.Helpers. SoA )
367
+ ; ([basic_vectorized], " acosh" , [DDeepVectorized ], SoA )
368
+ ; ([basic_vectorized], " asin" , [DDeepVectorized ], SoA )
369
+ ; ([basic_vectorized], " asinh" , [DDeepVectorized ], SoA )
370
+ ; ([basic_vectorized], " atan" , [DDeepVectorized ], SoA )
371
+ ; ([basic_vectorized], " atanh" , [DDeepVectorized ], SoA )
372
+ ; ([basic_vectorized], " cbrt" , [DDeepVectorized ], SoA )
373
+ ; ([basic_vectorized], " ceil" , [DDeepVectorized ], SoA )
374
+ ; ([basic_vectorized], " cos" , [DDeepVectorized ], SoA )
375
+ ; ([basic_vectorized], " cosh" , [DDeepVectorized ], SoA )
376
+ ; ([basic_vectorized], " digamma" , [DDeepVectorized ], SoA )
377
+ ; ([basic_vectorized], " erf" , [DDeepVectorized ], SoA )
378
+ ; ([basic_vectorized], " erfc" , [DDeepVectorized ], SoA )
379
+ ; ([basic_vectorized], " exp" , [DDeepVectorized ], SoA )
380
+ ; ([basic_vectorized], " exp2" , [DDeepVectorized ], SoA )
381
+ ; ([basic_vectorized], " expm1" , [DDeepVectorized ], SoA )
382
+ ; ([basic_vectorized], " fabs" , [DDeepVectorized ], SoA )
383
+ ; ([UnaryVectorized ComplexToReals ], " get_imag" , [DDeepComplexVectorized ], AoS )
384
+ ; ([UnaryVectorized ComplexToReals ], " get_real" , [DDeepComplexVectorized ], AoS )
385
+ ; ([UnaryVectorized SameAsArg ], " abs" , [DDeepVectorized ], SoA )
386
+ ; ([UnaryVectorized ComplexToReals ], " abs" , [DDeepComplexVectorized ], AoS )
387
+ ; ([basic_vectorized], " floor" , [DDeepVectorized ], SoA )
388
+ ; ([basic_vectorized], " inv" , [DDeepVectorized ], SoA )
389
+ ; ([basic_vectorized], " inv_cloglog" , [DDeepVectorized ], SoA )
390
+ ; ([basic_vectorized], " inv_erfc" , [DDeepVectorized ], SoA )
391
+ ; ([basic_vectorized], " inv_logit" , [DDeepVectorized ], SoA )
392
+ ; ([basic_vectorized], " inv_Phi" , [DDeepVectorized ], SoA )
393
+ ; ([basic_vectorized], " inv_sqrt" , [DDeepVectorized ], SoA )
394
+ ; ([basic_vectorized], " inv_square" , [DDeepVectorized ], SoA )
395
+ ; ([basic_vectorized], " lambert_w0" , [DDeepVectorized ], SoA )
396
+ ; ([basic_vectorized], " lambert_wm1" , [DDeepVectorized ], SoA )
397
+ ; ([basic_vectorized], " lgamma" , [DDeepVectorized ], SoA )
398
+ ; ([basic_vectorized], " log" , [DDeepVectorized ], SoA )
399
+ ; ([basic_vectorized], " log10" , [DDeepVectorized ], SoA )
400
+ ; ([basic_vectorized], " log1m" , [DDeepVectorized ], SoA )
401
+ ; ([basic_vectorized], " log1m_exp" , [DDeepVectorized ], SoA )
402
+ ; ([basic_vectorized], " log1m_inv_logit" , [DDeepVectorized ], SoA )
403
+ ; ([basic_vectorized], " log1p" , [DDeepVectorized ], SoA )
404
+ ; ([basic_vectorized], " log1p_exp" , [DDeepVectorized ], SoA )
405
+ ; ([basic_vectorized], " log2" , [DDeepVectorized ], SoA )
406
+ ; ([basic_vectorized], " log_inv_logit" , [DDeepVectorized ], SoA )
407
+ ; ([basic_vectorized], " logit" , [DDeepVectorized ], SoA )
408
+ ; ([basic_vectorized], " Phi" , [DDeepVectorized ], SoA )
409
+ ; ([basic_vectorized], " Phi_approx" , [DDeepVectorized ], SoA )
410
+ ; ([basic_vectorized], " round" , [DDeepVectorized ], SoA )
411
+ ; ([basic_vectorized], " sin" , [DDeepVectorized ], SoA )
412
+ ; ([basic_vectorized], " sinh" , [DDeepVectorized ], SoA )
413
+ ; ([basic_vectorized], " sqrt" , [DDeepVectorized ], SoA )
414
+ ; ([basic_vectorized], " square" , [DDeepVectorized ], SoA )
415
+ ; ([basic_vectorized], " step" , [DReal ], SoA )
416
+ ; ([basic_vectorized], " tan" , [DDeepVectorized ], SoA )
417
+ ; ([basic_vectorized], " tanh" , [DDeepVectorized ], SoA )
418
+ ; ([basic_vectorized], " tgamma" , [DDeepVectorized ], SoA )
419
+ ; ([basic_vectorized], " trunc" , [DDeepVectorized ], SoA )
420
+ ; ([basic_vectorized], " trigamma" , [DDeepVectorized ], SoA ) ]
400
421
401
422
let all_declarative_sigs = distributions @ math_sigs
402
423
@@ -535,18 +556,6 @@ let pretty_print_math_lib_operator_sigs op =
535
556
else operator_to_stan_math_fns op |> List. map ~f: pretty_print_math_sigs
536
557
537
558
(* -- Some helper definitions to populate stan_math_signatures -- *)
538
- let bare_types =
539
- [ UnsizedType. UInt ; UReal ; UComplex ; UVector ; URowVector ; UMatrix
540
- ; UComplexVector ; UComplexRowVector ; UComplexMatrix ]
541
-
542
- let vector_types = [UnsizedType. UReal ; UArray UReal ; UVector ; URowVector ]
543
- let primitive_types = [UnsizedType. UInt ; UReal ]
544
-
545
- let complex_types =
546
- [UnsizedType. UComplex ; UComplexVector ; UComplexRowVector ; UComplexMatrix ]
547
-
548
- let all_vector_types =
549
- [UnsizedType. UReal ; UArray UReal ; UVector ; URowVector ; UInt ; UArray UInt ]
550
559
551
560
let add_qualified (name , rt , argts , supports_soa ) =
552
561
Hashtbl. add_multi stan_math_signatures ~key: name
@@ -863,9 +872,6 @@ let for_vector_types s = List.iter ~f:s vector_types
863
872
let () =
864
873
List. iter declarative_fnsigs ~f: (fun (key , rt , args , mem_pattern ) ->
865
874
Hashtbl. add_multi stan_math_signatures ~key ~data: (rt, args, mem_pattern) ) ;
866
- add_unqualified (" abs" , ReturnType UInt , [UInt ], SoA ) ;
867
- add_unqualified (" abs" , ReturnType UReal , [UReal ], SoA ) ;
868
- add_unqualified (" abs" , ReturnType UReal , [UComplex ], AoS ) ;
869
875
add_unqualified (" acos" , ReturnType UComplex , [UComplex ], AoS ) ;
870
876
add_unqualified (" acosh" , ReturnType UComplex , [UComplex ], AoS ) ;
871
877
List. iter
@@ -1292,28 +1298,6 @@ let () =
1292
1298
, ReturnType UReal
1293
1299
, [UMatrix ; UMatrix ; UMatrix ; UVector ; UMatrix ; UVector ; UMatrix ]
1294
1300
, AoS ) ;
1295
- List. iter
1296
- ~f: (fun i ->
1297
- List. iter
1298
- ~f: (fun t ->
1299
- add_unqualified
1300
- ( " get_imag"
1301
- , ReturnType (bare_array_type (complex_to_real t, i))
1302
- , [bare_array_type (t, i)]
1303
- , AoS ) )
1304
- complex_types )
1305
- (List. range 0 8 ) ;
1306
- List. iter
1307
- ~f: (fun i ->
1308
- List. iter
1309
- ~f: (fun t ->
1310
- add_unqualified
1311
- ( " get_real"
1312
- , ReturnType (bare_array_type (complex_to_real t, i))
1313
- , [bare_array_type (t, i)]
1314
- , AoS ) )
1315
- complex_types )
1316
- (List. range 0 8 ) ;
1317
1301
add_unqualified
1318
1302
(" gp_dot_prod_cov" , ReturnType UMatrix , [UArray UReal ; UReal ], AoS ) ;
1319
1303
add_unqualified
0 commit comments