From 333ac69ee9ad0e2336fde119cac9d2d776f99c5c Mon Sep 17 00:00:00 2001 From: David Anthoff Date: Fri, 8 Mar 2024 13:37:03 -0800 Subject: [PATCH] Fix a performance problem with adding RVs --- src/mcs/defmcs.jl | 14 +++++++++++--- src/mcs/mcs_types.jl | 2 +- src/mcs/montecarlo.jl | 8 ++++---- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/mcs/defmcs.jl b/src/mcs/defmcs.jl index 1d55759c2..b1842f443 100644 --- a/src/mcs/defmcs.jl +++ b/src/mcs/defmcs.jl @@ -200,12 +200,20 @@ end # function _update_nt_type!(sim_def::SimulationDef{T}) where T <: AbstractSimulationData - names = (keys(sim_def.rvdict)...,) - types = [eltype(fld) for fld in values(sim_def.rvdict)] - sim_def.nt_type = NamedTuple{names, Tuple{types...}} + sim_def.nt_type = nothing nothing end +function _get_nt_type(sim_def::SimulationDef{T}) where T <: AbstractSimulationData + if sim_def.nt_type===nothing + names = (keys(sim_def.rvdict)...,) + types = [eltype(fld) for fld in values(sim_def.rvdict)] + sim_def.nt_type = NamedTuple{names, Tuple{types...}} + end + + return sim_def.nt_type +end + """ delete_RV!(sim_def::SimulationDef, name::Symbol) diff --git a/src/mcs/mcs_types.jl b/src/mcs/mcs_types.jl index a57138087..08662b25f 100644 --- a/src/mcs/mcs_types.jl +++ b/src/mcs/mcs_types.jl @@ -162,7 +162,7 @@ mutable struct SimulationDef{T} names = (keys(self.rvdict)...,) types = [eltype(fld) for fld in values(self.rvdict)] - self.nt_type = NamedTuple{names, Tuple{types...}} + self.nt_type = nothing self.data = data self.payload = nothing diff --git a/src/mcs/montecarlo.jl b/src/mcs/montecarlo.jl index 979140967..9de29390e 100644 --- a/src/mcs/montecarlo.jl +++ b/src/mcs/montecarlo.jl @@ -26,7 +26,7 @@ function Base.show(io::IO, sim_def::SimulationDef{T}) where T <: AbstractSimulat print_nonempty("translist", sim_def.translist) print_nonempty("savelist", sim_def.savelist) - println(" nt_type: $(sim_def.nt_type)") + println(" nt_type: $(_get_nt_type(sim_def))") Base.show(io, sim_def.data) # note: data::T end @@ -199,7 +199,7 @@ function get_trial(sim_inst::SimulationInstance, trialnum::Int) sim_def = sim_inst.sim_def vals = [rand(rv.dist) for rv in values(sim_def.rvdict)] - sim_inst.current_data = sim_def.nt_type((vals...,)) + sim_inst.current_data = _get_nt_type(sim_def)((vals...,)) sim_inst.current_trial = trialnum return sim_inst.current_data @@ -827,9 +827,9 @@ end IteratorInterfaceExtensions.isiterable(sim_inst::SimulationInstance{T}) where T <: AbstractSimulationData = true TableTraits.isiterabletable(sim_inst::SimulationInstance{T}) where T <: AbstractSimulationData = true -IteratorInterfaceExtensions.getiterator(sim_inst::SimulationInstance{T}) where T = SimIterator{sim_inst.sim_def.nt_type, T}(sim_inst) +IteratorInterfaceExtensions.getiterator(sim_inst::SimulationInstance{T}) where T = SimIterator{_get_nt_type(sim_inst.sim_def), T}(sim_inst) -column_names(sim_def::SimulationDef{T}) where T <: AbstractSimulationData = fieldnames(sim_def.nt_type) +column_names(sim_def::SimulationDef{T}) where T <: AbstractSimulationData = fieldnames(_get_nt_type(sim_def)) column_types(sim_def::SimulationDef{T}) where T <: AbstractSimulationData = [eltype(fld) for fld in values(sim_def.rvdict)] column_names(sim_inst::SimulationInstance{T}) where T <: AbstractSimulationData = column_names(sim_inst.sim_def)