Skip to content

Evaluation Function #275

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 32 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
cb5c3f4
translating BUGSgraph to bayesnet definition
naseweisssss Jan 25, 2025
6d8bedb
test commit
naseweisssss Jan 26, 2025
df0fcbb
i still cant solve the error
naseweisssss Jan 27, 2025
466dc9f
Merge branch 'master' of github.com:TuringLang/JuliaBUGS.jl into ryli…
naseweisssss Jan 30, 2025
f79ddd6
Merge branch 'master' of github.com:TuringLang/JuliaBUGS.jl into ryli…
naseweisssss Jan 31, 2025
2e3ac73
translation funciton with test
naseweisssss Feb 2, 2025
15fa390
Update src/graphs.jl
naseweisssss Feb 2, 2025
84ec191
add handlement for determinisitic
naseweisssss Feb 2, 2025
2239db7
lintiing
naseweisssss Feb 2, 2025
9780e3f
minor updates
sunxd3 Feb 3, 2025
d1da301
remove redundant definiteion for NodeInfo
naseweisssss Feb 4, 2025
69430a5
Functions
naseweisssss Feb 4, 2025
561a282
remove print
naseweisssss Feb 4, 2025
15f519c
delete temp
naseweisssss Feb 4, 2025
c34d131
remove import
naseweisssss Feb 5, 2025
84105c9
change parameter
naseweisssss Feb 5, 2025
bb2ab5d
remove varname
naseweisssss Feb 6, 2025
cd0d2d9
Update test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
naseweisssss Feb 6, 2025
58a45e9
leave it to be undef
naseweisssss Feb 6, 2025
1745a35
add distribution and deteministic function as function
naseweisssss Feb 7, 2025
b8d1a9d
adding values to be updated
naseweisssss Feb 9, 2025
380ba0d
adding inference.jl
naseweisssss Feb 10, 2025
dc33d21
Merge branch 'rylin/eval_bayesnet' into rylin/evaluation_function_8-2…
naseweisssss Feb 10, 2025
22abb77
for the inference implementation
naseweisssss Feb 10, 2025
c1ea68e
test passing version
naseweisssss Feb 10, 2025
8b4a1fd
commit
naseweisssss Feb 11, 2025
b1af4f9
Merge branch 'master' of github.com:TuringLang/JuliaBUGS.jl into ryli…
naseweisssss Feb 11, 2025
ae45a0d
checkpoint commit
naseweisssss Feb 11, 2025
9d56049
commit for more test case
naseweisssss Feb 12, 2025
360d06c
linting to get rid of juliaformatter commetnsx
naseweisssss Feb 13, 2025
759c850
Merge branch 'master' of github.com:TuringLang/JuliaBUGS.jl into ryli…
naseweisssss Feb 18, 2025
38d9dd9
just some verifying statment
naseweisssss Feb 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using AbstractPPL
include("bayesian_network.jl")
include("conditioning.jl")
include("functions.jl")
include("inference.jl")

export BayesianNetwork,
condition,
Expand All @@ -22,5 +23,6 @@ export BayesianNetwork,
add_deterministic_vertex!,
add_stochastic_vertex!,
add_vertex!,
translate_BUGSGraph_to_BayesianNetwork
translate_BUGSGraph_to_BayesianNetwork,
evaluate
end
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,8 @@ function add_edge!(bn::BayesianNetwork{V,T}, from::V, to::V)::Bool where {T,V}
to_id = bn.names_to_ids[to]
return Graphs.add_edge!(bn.graph, from_id, to_id)
end

function evaluate(bn::BayesianNetwork)
log_posterior = create_log_posterior(bn)
return log_posterior
end
181 changes: 181 additions & 0 deletions src/experimental/ProbabilisticGraphicalModels/inference.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
###############################################################################
# 4) Parent Helpers
###############################################################################

function parent_ids(bn::BayesianNetwork, node_id::Int)
# For a node_id, get all incoming edges.
return inneighbors(bn.graph, node_id)
end

function parent_values(bn::BayesianNetwork, node_id::Int)
# Retrieve the (already assigned) parent values in ascending ID order.
pids = parent_ids(bn, node_id)
sort!(pids)
vals = Any[]
for pid in pids
varname = bn.names[pid]
if !haskey(bn.values, varname)
error("Missing value for parent $varname of node id=$node_id")
end
push!(vals, bn.values[varname])
end
return vals
end

"""
Returns the Distribution object for a node, calling its stored function if needed.
"""
function get_distribution(bn::BayesianNetwork, node_id::Int)::Distribution
stored = bn.distributions[node_id]
if stored isa Distribution
return stored
elseif stored isa Function
pvals = parent_values(bn, node_id)
return stored(pvals...)
else
error(
"Node $node_id has invalid distribution entry (neither Distribution nor Function).",
)
end
end

"""
Check if a node is discrete by referencing the stored `node_types`.
"""
function is_discrete_node(bn::BayesianNetwork, node_id::Int)
return bn.node_types[node_id] == :discrete
end

###############################################################################
# 5) Summation & Log PDF Calculation
###############################################################################

function compute_full_logpdf(bn::BayesianNetwork)
println("DEBUG: compute_full_logpdf(bn) called.")
logp = 0.0
for sid in bn.stochastic_ids
varname = bn.names[sid]
# Only evaluate if this node has a value assigned
if haskey(bn.values, varname)
println("DEBUG: Node = $varname, value = $(bn.values[varname])")
# Ensure parents are assigned
for pid in parent_ids(bn, sid)
pvar = bn.names[pid]
if !haskey(bn.values, pvar)
println("DEBUG: Missing parent $pvar => returning -Inf")
return -Inf
end
end
dist = get_distribution(bn, sid)
val = bn.values[varname]
println("DEBUG: get_distribution($sid) => $dist, node value = $val")
lpdf = logpdf(dist, val)
println("DEBUG: logpdf($dist, $val) => $lpdf")
if isinf(lpdf)
println("DEBUG: logpdf is -Inf => returning -Inf")
return -Inf
end
logp += lpdf
end
end
println("DEBUG: final logp = $logp")
return logp
end

"""
Naive enumeration over all unobserved discrete nodes in `discrete_ids`.
Multiply pdf(...) for each assignment, summing up to get total probability.
"""
function sum_discrete_configurations(
bn::BayesianNetwork, discrete_ids::Vector{Int}, idx::Int
)
println("DEBUG: sum_discrete_configurations idx=$idx, discrete_ids=$discrete_ids")
if idx > length(discrete_ids)
local val = exp(compute_full_logpdf(bn))
println("DEBUG: base case => returning $val")
return val
else
node_id = discrete_ids[idx]
dist = get_distribution(bn, node_id)
println("DEBUG: Summation for node_id=$node_id => distribution=$dist")
total_prob = 0.0
for val in support(dist)
println("DEBUG: Trying val=$val for node $(bn.names[node_id])")
bn.values[bn.names[node_id]] = val
subval = sum_discrete_configurations(bn, discrete_ids, idx + 1)
pdf_val = pdf(dist, val)
println(
"DEBUG: subval=$subval, pdf_val=$pdf_val => partial = $(subval * pdf_val)",
)
total_prob += subval * pdf_val
end
delete!(bn.values, bn.names[node_id])
println("DEBUG: sum_discrete_configurations => total_prob=$total_prob")
return total_prob
end
end

###############################################################################
# 6) Create a log_posterior function
###############################################################################

function create_log_posterior(bn::BayesianNetwork)
println("DEBUG: create_log_posterior called for BN with $(nv(bn.graph)) nodes.")
function log_posterior(unobserved_values::Dict{Symbol,Float64})
println("DEBUG: log_posterior called with unobserved_values=$unobserved_values")
old_values = copy(bn.values)
try
# Merge the unobserved values into bn.values
for (k, v) in unobserved_values
bn.values[k] = v
println("DEBUG: Setting bn.values[$k] = $v")
end

# Identify unobserved, discrete nodes => must sum out
unobs_discrete_ids = Int[]
for sid in bn.stochastic_ids
if !bn.is_observed[sid]
varname = bn.names[sid]
if !haskey(bn.values, varname) && is_discrete_node(bn, sid)
push!(unobs_discrete_ids, sid)
end
end
end
println("DEBUG: unobs_discrete_ids = $unobs_discrete_ids")

# Optionally check for observed but incompatible values
for (varname, value) in bn.values
node_id = bn.names_to_ids[varname]
if bn.is_observed[node_id]
observed_dist = get_distribution(bn, node_id)
incompatible = (pdf(observed_dist, value) == 0.0)
println(
"DEBUG: Observed $varname=$value => dist=$observed_dist => pdf= $(pdf(observed_dist, value))",
)
if incompatible
println(
"DEBUG: Observed value is incompatible => returning -Inf"
)
return -Inf
end
end
end

if isempty(unobs_discrete_ids)
println("DEBUG: No discrete marginalization => direct logpdf")
lp = compute_full_logpdf(bn)
println("DEBUG: compute_full_logpdf => $lp")
return lp
else
println("DEBUG: Summing out discrete IDs => $unobs_discrete_ids")
prob_sum = sum_discrete_configurations(bn, unobs_discrete_ids, 1)
println("DEBUG: sum_discrete_configurations => $prob_sum")
return log(prob_sum)
end
finally
empty!(bn.values)
merge!(bn.values, old_values)
end
end
return log_posterior
end
Loading
Loading