Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 181 additions & 29 deletions src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
@@ -1,43 +1,49 @@
"""
BayesianNetwork
module BayesianNetworkModule

A structure representing a Bayesian Network.
"""
struct BayesianNetwork{V,T,F}
using Graphs
using Distributions

###############################################################################
# 1) BayesianNetwork definition (mutable + node_types)
###############################################################################

mutable struct BayesianNetwork{V,T,F}
graph::SimpleDiGraph{T}
"names of the variables in the network"
names::Vector{V}
"mapping from variable names to ids"
names_to_ids::Dict{V,T}
"values of each variable in the network"
values::Dict{V,Any} # TODO: make it a NamedTuple for better performance in the future
"distributions of the stochastic variables"
distributions::Vector{Distribution}
"deterministic functions of the deterministic variables"
deterministic_functions::Vector{F}
"ids of the stochastic variables"
values::Dict{V,Any}
distributions::Vector{Any} # Distribution or function returning a Distribution
deterministic_functions::Vector{F} # (unused here)
stochastic_ids::Vector{T}
"ids of the deterministic variables"
deterministic_ids::Vector{T}
is_stochastic::BitVector
is_observed::BitVector
node_types::Vector{Symbol} # e.g. :discrete or :continuous
end

"""
Construct an empty BayesianNetwork with symbol names.
"""
function BayesianNetwork{V}() where {V}
return BayesianNetwork(
SimpleDiGraph{Int}(), # by default, vertex ids are integers
SimpleDiGraph{Int}(),
V[],
Dict{V,Int}(),
Dict{V,Any}(),
Distribution[],
Any[],
Any[],
Int[],
Int[],
BitVector(),
BitVector(),
Symbol[],
)
end

###############################################################################
# 2) Graph Helpers
###############################################################################

"""
condition(bn::BayesianNetwork{V}, values::Dict{V,Any}) where {V}

Expand Down Expand Up @@ -108,13 +114,17 @@ function decondition!(bn::BayesianNetwork{V}, deconditioning_variables::Vector{V
end

"""
add_stochastic_vertex!(bn::BayesianNetwork{V}, name::V, dist::Distribution, is_observed::Bool) where {V}

Adds a stochastic vertex with name `name` and distribution `dist` to the Bayesian Network. Returns the id of the added vertex
if successful, 0 otherwise.
Add a stochastic vertex to the BayesianNetwork.
- `dist` can be a `Distribution` or a function returning a `Distribution`.
- `node_type` can be `:discrete` or `:continuous`.
- `is_observed` defaults to `false`.
"""
function add_stochastic_vertex!(
bn::BayesianNetwork{V,T}, name::V, dist::Distribution, is_observed::Bool
bn::BayesianNetwork{V,T},
name::V,
dist::Any,
node_type::Symbol=:continuous;
is_observed::Bool=false,
)::T where {V,T}
Graphs.add_vertex!(bn.graph) || return 0
id = nv(bn.graph)
Expand All @@ -124,14 +134,13 @@ function add_stochastic_vertex!(
push!(bn.names, name)
bn.names_to_ids[name] = id
push!(bn.stochastic_ids, id)
push!(bn.node_types, node_type)
return id
end

"""
add_deterministic_vertex!(bn::BayesianNetwork{V}, name::V, f::F) where {V,F}

Adds a deterministic vertex with name `name` and deterministic function `f` to the Bayesian Network. Returns the id of the added vertex
if successful, 0 otherwise.
Add a deterministic vertex to the BayesianNetwork.
- `f` is a function that defines how this node is computed from its parents.
"""
function add_deterministic_vertex!(bn::BayesianNetwork{V,T}, name::V, f::F)::T where {T,V,F}
Graphs.add_vertex!(bn.graph) || return 0
Expand All @@ -142,13 +151,12 @@ function add_deterministic_vertex!(bn::BayesianNetwork{V,T}, name::V, f::F)::T w
push!(bn.names, name)
bn.names_to_ids[name] = id
push!(bn.deterministic_ids, id)
push!(bn.node_types, :deterministic)
return id
end

"""
add_edge!(bn::BayesianNetwork{V}, from::V, to::V) where {V}

Adds an edge between two vertices in the Bayesian Network. Returns true if successful, false otherwise.
Add a directed edge `from -> to` in the BayesianNetwork's graph.
"""
function add_edge!(bn::BayesianNetwork{V,T}, from::V, to::V)::Bool where {T,V}
from_id = bn.names_to_ids[from]
Expand Down Expand Up @@ -329,3 +337,147 @@ function is_conditionally_independent(
) where {V}
return is_conditionally_independent(bn, [X], [Y], Z)
end

###############################################################################
# 3) Parent/Distribution Helpers
###############################################################################

function parent_ids(bn::BayesianNetwork, node_id::Int)
return inneighbors(bn.graph, node_id)
end

function parent_values(bn::BayesianNetwork, node_id::Int)
pids = parent_ids(bn, node_id)
sort!(pids)
vals = Any[]
for pid in pids
varname = bn.names[pid]
if !haskey(bn.values, varname)
error("Missing value for parent $varname of node id=$node_id")
end
push!(vals, bn.values[varname])
end
return vals
end

function get_distribution(bn::BayesianNetwork, node_id::Int)::Distribution
stored = bn.distributions[node_id]
if stored isa Distribution
return stored
elseif stored isa Function
pvals = parent_values(bn, node_id) # gather parents' assigned values
return stored(pvals...)
else
error("Node $node_id has invalid distribution entry.")
end
end

function is_discrete_node(bn::BayesianNetwork, node_id::Int)
return bn.node_types[node_id] == :discrete
end

###############################################################################
# 4) Logpdf Computation
###############################################################################

"""
Compute the sum of log probabilities for all **stochastic** nodes
using the current values assigned in `bn.values`.
If any distribution or parent's value is missing or invalid, returns -Inf.
"""
function compute_full_logpdf(bn::BayesianNetwork)
logp = 0.0
for sid in bn.stochastic_ids
varname = bn.names[sid]
if haskey(bn.values, varname)
# ensure parents assigned
for pid in parent_ids(bn, sid)
if !haskey(bn.values, bn.names[pid])
return -Inf
end
end
dist = get_distribution(bn, sid)
val = bn.values[varname]
lpdf = logpdf(dist, val)
if isinf(lpdf)
return -Inf
end
logp += lpdf
end
end
return logp
end

###############################################################################
# 5) Naive Summation of Discrete Configurations
###############################################################################

"""
Naive recursion:
Enumerate all discrete node values for unobserved discrete nodes.
Returns a *probability sum*, i.e. sum over exp(logpdf).
"""
function sum_discrete_configurations(
bn::BayesianNetwork, discrete_ids::Vector{Int}, idx::Int
)::Float64
if idx > length(discrete_ids)
return exp(compute_full_logpdf(bn))
else
node_id = discrete_ids[idx]
dist = get_distribution(bn, node_id)
total_prob = 0.0
for val in support(dist)
bn.values[bn.names[node_id]] = val
total_prob +=
sum_discrete_configurations(bn, discrete_ids, idx + 1) * pdf(dist, val)
end
delete!(bn.values, bn.names[node_id])
return total_prob
end
end

###############################################################################
# 6) create_log_posterior (Naive Only)
###############################################################################

"""
Creates a log_posterior function that merges unobserved values + sums out
unobserved discrete nodes (naive recursion).
Returns log(prob_sum).
"""
function create_log_posterior(bn::BayesianNetwork)
function log_posterior(unobserved_values::Dict{Symbol,Float64})
old_values = copy(bn.values)
try
# Merge unobserved
for (k, v) in unobserved_values
bn.values[k] = v
end

# Identify unobserved discrete IDs
unobs_discrete_ids = Int[]
for sid in bn.stochastic_ids
if !bn.is_observed[sid]
varname = bn.names[sid]
if !haskey(bn.values, varname) && is_discrete_node(bn, sid)
push!(unobs_discrete_ids, sid)
end
end
end

if isempty(unobs_discrete_ids)
# no discrete marginalization => direct logpdf
return compute_full_logpdf(bn)
else
# naive recursion
prob_sum = sum_discrete_configurations(bn, unobs_discrete_ids, 1)
return log(prob_sum)
end
finally
bn.values = old_values
end
end
return log_posterior
end

end # module
1 change: 0 additions & 1 deletion test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ using JuliaBUGS.ProbabilisticGraphicalModels:
add_edge!(bn, :B, :A)
add_edge!(bn, :B, :C)

@test !is_conditionally_independent(bn, :A, :C, Symbol[])
@test is_conditionally_independent(bn, :A, :C, [:B])
end

Expand Down
Loading