5
5
6
6
7
7
class BayesianNet (object ):
8
- def __init__ (self , names , edges , tables = None ):
9
- self .n_nodes = len (names )
10
- if tables is None : tables = [[0 ]] * self .n_nodes
11
- self .nodes = [{'name' : name , 'table' : np .array (table )} for name , table in zip (names , tables )]
12
- self .name2idx = {k : v for v , k in enumerate (names )}
13
- self .graph = np .zeros ((self .n_nodes , self .n_nodes ))
14
- for edge in edges :
15
- self .graph [self .name2idx [edge [1 ]], self .name2idx [edge [0 ]]] = 1
16
- self .binary = np .array ([1 << self .n_nodes - 1 - i for i in range (self .n_nodes )])
17
8
18
- def fit (self , data ):
19
- data_size = len (data )
20
- for i , node in enumerate ( self . nodes ) :
21
- table = []
22
- parents = self .graph [ i ] == 1
23
- marginal = data [:, parents ]
24
- index = np . zeros ( data . shape [ 0 ])
25
- if marginal . shape [ 1 ] > 0 :
26
- index = ( marginal * self . binary [ - marginal . shape [ 1 ]:]). sum ( axis = 1 )
27
- for j in range ( 2 ** parents . sum ()):
28
- table . append ( data [( index == j ), i ]. sum () / ( index == j ). sum ())
29
- node [ 'table' ] = np . array ( table )
9
+ def __init__ (self , names , edges , tables = None ):
10
+ self . n_nodes = len (names )
11
+ if tables is None :
12
+ tables = [[ 0 ]] * self . n_nodes
13
+ self .nodes = [{ 'name' : name , 'table' : np . array (
14
+ table )} for name , table in zip ( names , tables ) ]
15
+ self . name2idx = { k : v for v , k in enumerate ( names )}
16
+ self . graph = np . zeros (( self . n_nodes , self . n_nodes ))
17
+ for edge in edges :
18
+ self . graph [ self . name2idx [ edge [ 1 ]], self . name2idx [ edge [ 0 ]]] = 1
19
+ self . binary = np . array (
20
+ [ 1 << self . n_nodes - 1 - i for i in range ( self . n_nodes )] )
30
21
31
- def joint_p (self , values ):
32
- p = 1
33
- for i in range (self .n_nodes ):
34
- index = 0
35
- parents = self .graph [i ]== 1
36
- if parents .sum () > 0 :
37
- index = np .dot (values [parents ], self .binary [- parents .sum ():])
38
- p *= (1 - values [i ]) + (2 * values [i ] - 1 ) * self .nodes [i ]['table' ][int (index )]
39
- return p
22
+ def fit (self , data ):
23
+ data_size = len (data )
24
+ for i , node in enumerate (self .nodes ):
25
+ table = []
26
+ parents = self .graph [i ] == 1
27
+ marginal = data [:, parents ]
28
+ index = np .zeros (data .shape [0 ])
29
+ if marginal .shape [1 ] > 0 :
30
+ index = (
31
+ marginal * self .binary [- marginal .shape [1 ]:]).sum (axis = 1 )
32
+ for j in range (2 ** parents .sum ()):
33
+ table .append (data [(index == j ), i ].sum () / (index == j ).sum ())
34
+ node ['table' ] = np .array (table )
40
35
41
- def marginal_p (self , condition ):
42
- p = 0
43
- values = - np .ones (self .n_nodes )
44
- for v in condition :
45
- values [self .name2idx [v [1 ]]] = int (v [0 ] != '~' )
46
- mask = np .arange (self .n_nodes )[(values == - 1 )]
47
- n_unkowns = self .n_nodes - len (condition )
48
- for i in range (2 ** n_unkowns ):
49
- values [mask ] = np .array ([int (x ) for x in '{:0{size}b}' .format (i , size = n_unkowns )])
50
- p += self .joint_p (values )
51
- return p
36
+ def joint_p (self , values ):
37
+ p = 1
38
+ for i in range (self .n_nodes ):
39
+ index = 0
40
+ parents = self .graph [i ] == 1
41
+ if parents .sum () > 0 :
42
+ index = np .dot (values [parents ], self .binary [- parents .sum ():])
43
+ p *= (1 - values [i ]) + (2 * values [i ] - 1 ) * \
44
+ self .nodes [i ]['table' ][int (index )]
45
+ return p
46
+
47
+ def marginal_p (self , condition ):
48
+ p = 0
49
+ values = - np .ones (self .n_nodes )
50
+ for v in condition :
51
+ values [self .name2idx [v [1 ]]] = int (v [0 ] != '~' )
52
+ mask = np .arange (self .n_nodes )[(values == - 1 )]
53
+ n_unkowns = self .n_nodes - len (condition )
54
+ for i in range (2 ** n_unkowns ):
55
+ values [mask ] = np .array (
56
+ [int (x ) for x in '{:0{size}b}' .format (i , size = n_unkowns )])
57
+ p += self .joint_p (values )
58
+ return p
59
+
60
+ def query (self , v , condition ):
61
+ p_pos = self .marginal_p ([f'+{ v } ' ] + condition ) / self .marginal_p (condition )
62
+ return [1 - p_pos , p_pos ]
52
63
53
- def query (self , v , condition ):
54
- p_pos = self .marginal_p ([f'+{ v } ' ] + condition ) / self .marginal_p (condition )
55
- return [1 - p_pos , p_pos ]
56
64
57
65
def get_asia_data (url ):
58
- return read_csv (url ).apply (lambda x : x == 'yes' ).astype (int ).values
66
+ return read_csv (url ).apply (lambda x : x == 'yes' ).astype (int ).values
59
67
60
68
61
69
def main ():
62
- names = 'ATSLBEXD'
63
- edges = ['AT' , 'SL' , 'SB' , 'TE' , 'LE' , 'BD' , 'EX' , 'ED' ]
64
- #tables = [[0.01], [0.01, 0.05], [0.5], [0.01, 0.1], [0.3, 0.6], [0, 1, 1, 1], [0.05, 0.98], [0.1, 0.7, 0.8, 0.9]]
65
- bn = BayesianNet (list (names ), edges ) # also can use predefined conditional tables
66
- asia_url = 'http://www.ccd.pitt.edu/wiki/images/ASIA10k.csv'
67
- bn .fit (get_asia_data (asia_url ))
68
- print (bn .nodes )
69
- for condition in [[], ['+A' , '~S' ], ['+A' , '~S' , '~D' , '+X' ]]:
70
- for c in ['T' , 'L' , 'B' , 'E' ]:
71
- print ('p({}|{})={}' .format (c , ',' .join (condition ), bn .query (c , condition )))
70
+ names = 'ATSLBEXD'
71
+ edges = ['AT' , 'SL' , 'SB' , 'TE' , 'LE' , 'BD' , 'EX' , 'ED' ]
72
+ #tables = [[0.01], [0.01, 0.05], [0.5], [0.01, 0.1], [0.3, 0.6], [0, 1, 1, 1], [0.05, 0.98], [0.1, 0.7, 0.8, 0.9]]
73
+ # also can use predefined conditional tables
74
+ bn = BayesianNet (list (names ), edges )
75
+ asia_url = 'http://www.ccd.pitt.edu/wiki/images/ASIA10k.csv'
76
+ bn .fit (get_asia_data (asia_url ))
77
+ print (bn .nodes )
78
+ for condition in [[], ['+A' , '~S' ], ['+A' , '~S' , '~D' , '+X' ]]:
79
+ for c in ['T' , 'L' , 'B' , 'E' ]:
80
+ print ('p({}|{})={}' .format (c , ',' .join (
81
+ condition ), bn .query (c , condition )))
72
82
73
83
74
84
if __name__ == "__main__" :
75
- main ()
85
+ main ()
0 commit comments