-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdiscreteprobability.go
153 lines (126 loc) · 4.02 KB
/
discreteprobability.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
// Written in 2019 by Peter Li
// Package discreteprobability generates random values according to their corresponding weights.
// Example usage:
//
//	values := []int{1, 2, 3}
//	weights := []float64{0.25, 0.5, 0.25}
//	generator, err := discreteprobability.New(values, weights)
//	if err != nil {
//		panic(err) // handle the error
//	}
//	num := generator.RandomInt()
//
// Here num is 2 with 50% probability, and 1 or 3 with 25% probability each.
package discreteprobability
import (
	"errors"
	"math"
	"math/rand"
	"reflect"
	"sort"
	"time"
)
// Sentinel errors returned by this package; compare them with errors.Is or ==.
var (
	// ErrType is returned by the Random*Safe methods when the drawn value
	// does not have the requested type.
	ErrType = errors.New("failed to assert type")

	// ErrNotSlice is returned by New when the values argument is not a slice.
	ErrNotSlice = errors.New("value is not a slice")

	// ErrLength is returned by New when values and weights have different lengths.
	ErrLength = errors.New("length of values and weights not match")

	// ErrWeightSum is returned by New when the weights do not sum to 1.
	ErrWeightSum = errors.New("sum of weights is not 1")
)

// seed is the default random seed, read once from the clock at package
// initialization. New seeds every Generator with it; SetSeed replaces the
// source of an individual Generator.
var seed = time.Now().UnixNano()
// Generator draws random elements from a fixed set of values, each with the
// probability given by its weight. Create one with New.
//
// New reorders values by ascending weight and then replaces weights with
// their running (cumulative) sum, which is what the draw searches.
type Generator struct {
	values  []reflect.Value // candidate values, kept aligned with weights
	weights []float64       // per-value weights; cumulative once New has run
	size    int             // number of candidate values
	source  rand.Source     // from rand.NewSource; math/rand documents it as not safe for concurrent use
}

// Len reports how many values the Generator holds. It is part of
// sort.Interface, which New uses to order the values by weight.
func (g *Generator) Len() int {
	n := len(g.values)
	return n
}

// Swap exchanges entries i and j of both values and weights, so every value
// stays paired with its weight. It is part of sort.Interface.
//
// Calling Swap on a Generator returned by New scrambles its cumulative
// weights; it is only meant to be used while New sorts.
func (g *Generator) Swap(i, j int) {
	vals, ws := g.values, g.weights
	vals[i], vals[j] = vals[j], vals[i]
	ws[i], ws[j] = ws[j], ws[i]
}

// Less orders entries by ascending weight. It is part of sort.Interface.
func (g *Generator) Less(i, j int) bool {
	return g.weights[j] > g.weights[i]
}
// New returns a Generator that draws the elements of v with the probabilities
// given by w: element v[i] is drawn with probability w[i].
//
// v must be a non-nil slice; otherwise New returns ErrNotSlice. If len(w)
// differs from the slice length it returns ErrLength. If the weights do not
// sum to 1 within a tolerance of 1e-4 (including an empty slice, or a NaN
// weight) it returns ErrWeightSum. Weights are expected to be non-negative;
// negative weights are not rejected but do not describe a valid distribution.
//
// w is copied, so the caller's slice is left unmodified. Every Generator starts
// from the same package-level seed until SetSeed is called on it.
func New(v interface{}, w []float64) (*Generator, error) {
	// reflect.TypeOf(nil) returns nil, and calling Kind on it would panic.
	if v == nil || reflect.TypeOf(v).Kind() != reflect.Slice {
		return nil, ErrNotSlice
	}
	val := reflect.ValueOf(v)
	n := val.Len()
	if n != len(w) {
		return nil, ErrLength
	}
	values := make([]reflect.Value, n)
	for i := range values {
		values[i] = val.Index(i)
	}
	// Sorting and the cumulative pass below rewrite weights in place, so work
	// on a private copy to keep the caller's w intact.
	weights := make([]float64, n)
	copy(weights, w)

	g := &Generator{
		values:  values,
		weights: weights,
		size:    n,
		source:  rand.NewSource(seed),
	}
	// Order by ascending weight (values move with their weights), then turn
	// the weights into a running sum that random can binary-search.
	sort.Sort(g)
	sum := 0.0
	for i, weight := range g.weights {
		sum += weight
		g.weights[i] = sum
	}
	// Reject totals on both sides of 1: a total below 1 would let a draw fall
	// past the last cumulative weight. The negated form also rejects NaN.
	if !(math.Abs(sum-1) <= 1e-4) {
		return nil, ErrWeightSum
	}
	// Pin the final cumulative weight to exactly 1 so every draw in [0, 1)
	// maps to an element despite rounding in the sum.
	g.weights[n-1] = 1
	return g, nil
}
// SetSeed replaces the Generator's random source with a fresh one seeded by s,
// instead of the package's time-based default, so that subsequent draws
// follow a fixed, repeatable sequence.
func (g *Generator) SetSeed(s int64) {
	src := rand.NewSource(s)
	g.source = src
}
// random draws one element according to the cumulative weights built by New.
// It picks a uniform f in [0, 1) and returns the first value whose cumulative
// weight is >= f.
func (g *Generator) random() reflect.Value {
	// Keep only the top 53 of the source's 63 random bits so the quotient is
	// exact and strictly below 1. Dividing the full Int63 by 2^63 can round
	// up to exactly 1.0.
	f := float64(g.source.Int63()>>10) / (1 << 53)
	i := sort.Search(g.size, func(i int) bool {
		return g.weights[i] >= f
	})
	// Rounding in the cumulative sums can leave the last weight just below f.
	// Fall back to the last element rather than indexing past the end.
	if i == g.size && i > 0 {
		i--
	}
	return g.values[i]
}
// RandomInt draws a value and returns it as an int through reflect.Value.Int,
// without a type assertion. It panics unless the values given to New have a
// signed integer kind (int, int8, int16, int32 or int64). Use RandomIntSafe
// to get ErrType instead of a panic.
func (g *Generator) RandomInt() int {
	v := g.random()
	n := v.Int()
	return int(n)
}
// RandomFloat64 draws a value and returns it as a float64 through
// reflect.Value.Float, without a type assertion. It panics unless the values
// given to New have a floating-point kind (float32 or float64). Use
// RandomFloat64Safe to get ErrType instead of a panic.
func (g *Generator) RandomFloat64() float64 {
	v := g.random()
	x := v.Float()
	return x
}
// RandomString draws a value and returns it as a string through
// reflect.Value.String, without a type assertion. The values given to New
// should be strings. If they are not, this does not panic: reflect.Value.String
// returns a placeholder such as "<int Value>" instead. Use RandomStringSafe
// to detect a type mismatch.
func (g *Generator) RandomString() string {
	v := g.random()
	s := v.String()
	return s
}
// RandomIntSafe draws a value and returns it if its dynamic type is exactly
// int. On a type mismatch it returns 0 and ErrType instead of panicking.
func (g *Generator) RandomIntSafe() (int, error) {
	if n, ok := g.random().Interface().(int); ok {
		return n, nil
	}
	return 0, ErrType
}
// RandomStringSafe draws a value and returns it if its dynamic type is exactly
// string. On a type mismatch it returns "" and ErrType.
func (g *Generator) RandomStringSafe() (string, error) {
	if s, ok := g.random().Interface().(string); ok {
		return s, nil
	}
	return "", ErrType
}
// RandomFloat64Safe draws a value and returns it if its dynamic type is
// exactly float64. On a type mismatch it returns 0 and ErrType.
func (g *Generator) RandomFloat64Safe() (float64, error) {
	if x, ok := g.random().Interface().(float64); ok {
		return x, nil
	}
	return 0, ErrType
}