@@ -13,7 +13,8 @@ use num_traits::Float;
13
13
use crate :: { Distribution , Exp1 , Gamma , Open01 , StandardNormal } ;
14
14
use rand:: Rng ;
15
15
use core:: fmt;
16
- use alloc:: { boxed:: Box , vec, vec:: Vec } ;
16
+ #[ cfg( feature = "serde_with" ) ]
17
+ use serde_with:: serde_as;
17
18
18
19
/// The Dirichlet distribution `Dirichlet(alpha)`.
19
20
///
@@ -27,22 +28,23 @@ use alloc::{boxed::Box, vec, vec::Vec};
27
28
/// use rand::prelude::*;
28
29
/// use rand_distr::Dirichlet;
29
30
///
30
- /// let dirichlet = Dirichlet::new(& [1.0, 2.0, 3.0]).unwrap();
31
+ /// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
31
32
/// let samples = dirichlet.sample(&mut rand::thread_rng());
32
33
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
33
34
/// ```
34
35
#[ cfg_attr( doc_cfg, doc( cfg( feature = "alloc" ) ) ) ]
36
+ #[ cfg_attr( feature = "serde_with" , serde_as) ]
35
37
#[ derive( Clone , Debug , PartialEq ) ]
36
- #[ cfg_attr( feature = "serde1" , derive( serde:: Serialize , serde:: Deserialize ) ) ]
37
- pub struct Dirichlet < F >
38
+ pub struct Dirichlet < F , const N : usize >
38
39
where
39
40
F : Float ,
40
41
StandardNormal : Distribution < F > ,
41
42
Exp1 : Distribution < F > ,
42
43
Open01 : Distribution < F > ,
43
44
{
44
45
/// Concentration parameters (alpha)
45
- alpha : Box < [ F ] > ,
46
+ #[ cfg_attr( feature = "serde_with" , serde_as( as = "[_; N]" ) ) ]
47
+ alpha : [ F ; N ] ,
46
48
}
47
49
48
50
/// Error type returned from `Dirchlet::new`.
@@ -72,7 +74,7 @@ impl fmt::Display for Error {
72
74
#[ cfg_attr( doc_cfg, doc( cfg( feature = "std" ) ) ) ]
73
75
impl std:: error:: Error for Error { }
74
76
75
- impl < F > Dirichlet < F >
77
+ impl < F , const N : usize > Dirichlet < F , N >
76
78
where
77
79
F : Float ,
78
80
StandardNormal : Distribution < F > ,
83
85
///
84
86
/// Requires `alpha.len() >= 2`.
85
87
#[ inline]
86
- pub fn new ( alpha : & [ F ] ) -> Result < Dirichlet < F > , Error > {
87
- if alpha . len ( ) < 2 {
88
+ pub fn new ( alpha : [ F ; N ] ) -> Result < Dirichlet < F , N > , Error > {
89
+ if N < 2 {
88
90
return Err ( Error :: AlphaTooShort ) ;
89
91
}
90
92
for & ai in alpha. iter ( ) {
@@ -93,36 +95,19 @@ where
93
95
}
94
96
}
95
97
96
- Ok ( Dirichlet { alpha : alpha. to_vec ( ) . into_boxed_slice ( ) } )
97
- }
98
-
99
- /// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`.
100
- ///
101
- /// Requires `size >= 2`.
102
- #[ inline]
103
- pub fn new_with_size ( alpha : F , size : usize ) -> Result < Dirichlet < F > , Error > {
104
- if !( alpha > F :: zero ( ) ) {
105
- return Err ( Error :: AlphaTooSmall ) ;
106
- }
107
- if size < 2 {
108
- return Err ( Error :: SizeTooSmall ) ;
109
- }
110
- Ok ( Dirichlet {
111
- alpha : vec ! [ alpha; size] . into_boxed_slice ( ) ,
112
- } )
98
+ Ok ( Dirichlet { alpha } )
113
99
}
114
100
}
115
101
116
- impl < F > Distribution < Vec < F > > for Dirichlet < F >
102
+ impl < F , const N : usize > Distribution < [ F ; N ] > for Dirichlet < F , N >
117
103
where
118
104
F : Float ,
119
105
StandardNormal : Distribution < F > ,
120
106
Exp1 : Distribution < F > ,
121
107
Open01 : Distribution < F > ,
122
108
{
123
- fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> Vec < F > {
124
- let n = self . alpha . len ( ) ;
125
- let mut samples = vec ! [ F :: zero( ) ; n] ;
109
+ fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> [ F ; N ] {
110
+ let mut samples = [ F :: zero ( ) ; N ] ;
126
111
let mut sum = F :: zero ( ) ;
127
112
128
113
for ( s, & a) in samples. iter_mut ( ) . zip ( self . alpha . iter ( ) ) {
@@ -140,27 +125,12 @@ where
140
125
141
126
#[ cfg( test) ]
142
127
mod test {
128
+ use alloc:: vec:: Vec ;
143
129
use super :: * ;
144
130
145
131
#[ test]
146
132
fn test_dirichlet ( ) {
147
- let d = Dirichlet :: new ( & [ 1.0 , 2.0 , 3.0 ] ) . unwrap ( ) ;
148
- let mut rng = crate :: test:: rng ( 221 ) ;
149
- let samples = d. sample ( & mut rng) ;
150
- let _: Vec < f64 > = samples
151
- . into_iter ( )
152
- . map ( |x| {
153
- assert ! ( x > 0.0 ) ;
154
- x
155
- } )
156
- . collect ( ) ;
157
- }
158
-
159
- #[ test]
160
- fn test_dirichlet_with_param ( ) {
161
- let alpha = 0.5f64 ;
162
- let size = 2 ;
163
- let d = Dirichlet :: new_with_size ( alpha, size) . unwrap ( ) ;
133
+ let d = Dirichlet :: new ( [ 1.0 , 2.0 , 3.0 ] ) . unwrap ( ) ;
164
134
let mut rng = crate :: test:: rng ( 221 ) ;
165
135
let samples = d. sample ( & mut rng) ;
166
136
let _: Vec < f64 > = samples
@@ -175,17 +145,17 @@ mod test {
175
145
#[ test]
176
146
#[ should_panic]
177
147
fn test_dirichlet_invalid_length ( ) {
178
- Dirichlet :: new_with_size ( 0.5f64 , 1 ) . unwrap ( ) ;
148
+ Dirichlet :: new ( [ 0.5 ] ) . unwrap ( ) ;
179
149
}
180
150
181
151
#[ test]
182
152
#[ should_panic]
183
153
fn test_dirichlet_invalid_alpha ( ) {
184
- Dirichlet :: new_with_size ( 0.0f64 , 2 ) . unwrap ( ) ;
154
+ Dirichlet :: new ( [ 0.1 , 0.0 , 0.3 ] ) . unwrap ( ) ;
185
155
}
186
156
187
157
#[ test]
188
158
fn dirichlet_distributions_can_be_compared ( ) {
189
- assert_eq ! ( Dirichlet :: new( & [ 1.0 , 2.0 ] ) , Dirichlet :: new( & [ 1.0 , 2.0 ] ) ) ;
159
+ assert_eq ! ( Dirichlet :: new( [ 1.0 , 2.0 ] ) , Dirichlet :: new( [ 1.0 , 2.0 ] ) ) ;
190
160
}
191
161
}
0 commit comments