Skip to content

Commit

Permalink
Merge pull request #986 from mimiframework/fix-perf-rv
Browse files Browse the repository at this point in the history
Fix a performance problem with adding RVs
  • Loading branch information
davidanthoff authored Mar 8, 2024
2 parents 1bdbc78 + 333ac69 commit 9fcc462
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
14 changes: 11 additions & 3 deletions src/mcs/defmcs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,20 @@ end
#

function _update_nt_type!(sim_def::SimulationDef{T}) where T <: AbstractSimulationData
names = (keys(sim_def.rvdict)...,)
types = [eltype(fld) for fld in values(sim_def.rvdict)]
sim_def.nt_type = NamedTuple{names, Tuple{types...}}
sim_def.nt_type = nothing
nothing
end

function _get_nt_type(sim_def::SimulationDef{T}) where T <: AbstractSimulationData
if sim_def.nt_type===nothing
names = (keys(sim_def.rvdict)...,)
types = [eltype(fld) for fld in values(sim_def.rvdict)]
sim_def.nt_type = NamedTuple{names, Tuple{types...}}
end

return sim_def.nt_type
end

"""
delete_RV!(sim_def::SimulationDef, name::Symbol)
Expand Down
2 changes: 1 addition & 1 deletion src/mcs/mcs_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ mutable struct SimulationDef{T}

names = (keys(self.rvdict)...,)
types = [eltype(fld) for fld in values(self.rvdict)]
self.nt_type = NamedTuple{names, Tuple{types...}}
self.nt_type = nothing

self.data = data
self.payload = nothing
Expand Down
8 changes: 4 additions & 4 deletions src/mcs/montecarlo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function Base.show(io::IO, sim_def::SimulationDef{T}) where T <: AbstractSimulat

print_nonempty("translist", sim_def.translist)
print_nonempty("savelist", sim_def.savelist)
println(" nt_type: $(sim_def.nt_type)")
println(" nt_type: $(_get_nt_type(sim_def))")

Base.show(io, sim_def.data) # note: data::T
end
Expand Down Expand Up @@ -199,7 +199,7 @@ function get_trial(sim_inst::SimulationInstance, trialnum::Int)
sim_def = sim_inst.sim_def

vals = [rand(rv.dist) for rv in values(sim_def.rvdict)]
sim_inst.current_data = sim_def.nt_type((vals...,))
sim_inst.current_data = _get_nt_type(sim_def)((vals...,))
sim_inst.current_trial = trialnum

return sim_inst.current_data
Expand Down Expand Up @@ -827,9 +827,9 @@ end
IteratorInterfaceExtensions.isiterable(sim_inst::SimulationInstance{T}) where T <: AbstractSimulationData = true
TableTraits.isiterabletable(sim_inst::SimulationInstance{T}) where T <: AbstractSimulationData = true

IteratorInterfaceExtensions.getiterator(sim_inst::SimulationInstance{T}) where T = SimIterator{sim_inst.sim_def.nt_type, T}(sim_inst)
IteratorInterfaceExtensions.getiterator(sim_inst::SimulationInstance{T}) where T = SimIterator{_get_nt_type(sim_inst.sim_def), T}(sim_inst)

column_names(sim_def::SimulationDef{T}) where T <: AbstractSimulationData = fieldnames(sim_def.nt_type)
column_names(sim_def::SimulationDef{T}) where T <: AbstractSimulationData = fieldnames(_get_nt_type(sim_def))
column_types(sim_def::SimulationDef{T}) where T <: AbstractSimulationData = [eltype(fld) for fld in values(sim_def.rvdict)]

column_names(sim_inst::SimulationInstance{T}) where T <: AbstractSimulationData = column_names(sim_inst.sim_def)
Expand Down

0 comments on commit 9fcc462

Please sign in to comment.