Skip to content

Commit

Permalink
Store a state alongside the solver (#225)
Browse files Browse the repository at this point in the history
* storing a pair of solver and state to reuse the solver

* undoing accidental changes

* setting micm commit id

* removing references to singular stuff in fortran

* Update micm.hpp
  • Loading branch information
K20shores authored Oct 3, 2024
1 parent e1d6f9d commit b5ada9d
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 114 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.21)

# must be on the same line so that pyproject.toml can correctly identify the version
project(musica-distribution VERSION 0.8.0)
project(musica-distribution VERSION 0.8.1)

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH};${PROJECT_SOURCE_DIR}/cmake)
set(CMAKE_USER_MAKE_RULES_OVERRIDE ${CMAKE_MODULE_PATH}/SetDefaults.cmake)
Expand Down
2 changes: 1 addition & 1 deletion cmake/dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ endif()
if (MUSICA_ENABLE_MICM AND MUSICA_BUILD_C_CXX_INTERFACE)

set_git_default(MICM_GIT_REPOSITORY https://github.com/NCAR/micm.git)
set_git_default(MICM_GIT_TAG v3.6.0)
set_git_default(MICM_GIT_TAG b3c462a)

FetchContent_Declare(micm
GIT_REPOSITORY ${MICM_GIT_REPOSITORY}
Expand Down
14 changes: 0 additions & 14 deletions fortran/micm.F90
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ module musica_micm
integer(c_int64_t) :: rejected_ = 0_c_int64_t
integer(c_int64_t) :: decompositions_ = 0_c_int64_t
integer(c_int64_t) :: solves_ = 0_c_int64_t
integer(c_int64_t) :: singular_ = 0_c_int64_t
real(c_double) :: final_time_ = 0._c_double
end type solver_stats_t_c

Expand Down Expand Up @@ -162,7 +161,6 @@ end function get_user_defined_reaction_rates_ordering_c
integer(int64) :: rejected_
integer(int64) :: decompositions_
integer(int64) :: solves_
integer(int64) :: singular_
real :: final_time_
contains
procedure :: function_calls => solver_stats_t_function_calls
Expand All @@ -172,7 +170,6 @@ end function get_user_defined_reaction_rates_ordering_c
procedure :: rejected => solver_stats_t_rejected
procedure :: decompositions => solver_stats_t_decompositions
procedure :: solves => solver_stats_t_solves
procedure :: singular => solver_stats_t_singular
procedure :: final_time => solver_stats_t_final_time
end type solver_stats_t

Expand Down Expand Up @@ -286,7 +283,6 @@ function solver_stats_t_constructor( c_solver_stats ) result( new_solver_stats )
new_solver_stats%rejected_ = c_solver_stats%rejected_
new_solver_stats%decompositions_ = c_solver_stats%decompositions_
new_solver_stats%solves_ = c_solver_stats%solves_
new_solver_stats%singular_ = c_solver_stats%singular_
new_solver_stats%final_time_ = real( c_solver_stats%final_time_ )

end function solver_stats_t_constructor
Expand Down Expand Up @@ -361,16 +357,6 @@ function solver_stats_t_solves( this ) result( solves )

end function solver_stats_t_solves

!> Get the number of times a singular matrix is detected
function solver_stats_t_singular( this ) result( singular )
use iso_fortran_env, only: int64
class(solver_stats_t), intent(in) :: this
integer(int64) :: singular

singular = this%function_calls_

end function solver_stats_t_singular

!> Get the final time the solver iterated to
function solver_stats_t_final_time( this ) result( final_time )
class(solver_stats_t), intent(in) :: this
Expand Down
1 change: 0 additions & 1 deletion fortran/test/fetch_content_integration/test_micm_api.F90
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ subroutine test_api()
write(*,*) "[test micm fort api] Rejected: ", solver_stats%rejected()
write(*,*) "[test micm fort api] Decompositions: ", solver_stats%decompositions()
write(*,*) "[test micm fort api] Solves: ", solver_stats%solves()
write(*,*) "[test micm fort api] Singular: ", solver_stats%singular()
write(*,*) "[test micm fort api] Final time: ", solver_stats%final_time()

string_value = micm%get_species_property_string( "O3", "__long name", error )
Expand Down
70 changes: 8 additions & 62 deletions include/musica/micm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <micm/util/sparse_matrix_vector_ordering.hpp>
#include <micm/util/vector_matrix.hpp>

#include <utility>
#include <chrono>
#include <cstddef>
#include <map>
#include <memory>
Expand Down Expand Up @@ -61,8 +63,6 @@ namespace musica
int64_t decompositions_{};
/// @brief The number of linear solves
int64_t solves_{};
/// @brief The number of times a singular matrix is detected.
int64_t singular_{};
/// @brief The final time the solver iterated to
double final_time_{};
/// @brief The final state the solver was in
Expand All @@ -75,7 +75,6 @@ namespace musica
rejected_(0),
decompositions_(0),
solves_(0),
singular_(0),
final_time_(0.0)
{
}
Expand All @@ -88,7 +87,6 @@ namespace musica
int64_t rejected,
int64_t decompositions,
int64_t solves,
int64_t singular,
double final_time)
: function_calls_(func_calls),
jacobian_updates_(jacobian),
Expand All @@ -97,7 +95,6 @@ namespace musica
rejected_(rejected),
decompositions_(decompositions),
solves_(solves),
singular_(singular),
final_time_(final_time)
{
}
Expand Down Expand Up @@ -192,7 +189,7 @@ namespace musica
void CreateBackwardEulerStandardOrder(const std::string &config_path, Error *error);

/// @brief Solve the system
/// @param solver Pointer to solver
/// @param solver_state_pair A pair containing a pointer to a solver and a state for that solver (temporary fix)
/// @param time_step Time [s] to advance the state by
/// @param temperature Temperature [grid cell] (K)
/// @param pressure Pressure [grid cell] (Pa)
Expand All @@ -201,7 +198,7 @@ namespace musica
/// @param custom_rate_parameters Array of custom rate parameters [grid cell][parameter] (various units)
/// @param error Error struct to indicate success or failure
void Solve(
auto &solver,
auto &solver_state_pair,
double time_step,
double *temperature,
double *pressure,
Expand All @@ -226,21 +223,6 @@ namespace musica
num_grid_cells_ = num_grid_cells;
}

/// @brief Get the ordering of species
/// @param solver Pointer to solver
/// @param error Error struct to indicate success or failure
/// @return Map of species names to their indices
// std::map<std::string, std::size_t> GetSpeciesOrdering(auto &solver, Error *error);
template<class T>
std::map<std::string, std::size_t> GetSpeciesOrdering(T &solver, Error *error);

/// @brief Get the ordering of user-defined reaction rates
/// @param solver Pointer to solver
/// @param error Error struct to indicate success or failure
/// @return Map of reaction rate names to their indices
template<class T>
std::map<std::string, std::size_t> GetUserDefinedReactionRatesOrdering(T &solver, Error *error);

/// @brief Get a property for a chemical species
/// @param species_name Name of the species
/// @param property_name Name of the property
Expand All @@ -259,7 +241,7 @@ namespace musica
template SolverType<micm::ProcessSet, micm::LinearSolver<SparseMatrixVector, micm::LuDecomposition>>;
using Rosenbrock = micm::Solver<RosenbrockVectorType, micm::State<DenseMatrixVector, SparseMatrixVector>>;
using VectorState = micm::State<DenseMatrixVector, SparseMatrixVector>;
std::unique_ptr<Rosenbrock> rosenbrock_;
std::pair<std::unique_ptr<Rosenbrock>, VectorState> rosenbrock_;

/// @brief Standard-ordered Rosenbrock solver type
using DenseMatrixStandard = micm::Matrix<double>;
Expand All @@ -268,20 +250,20 @@ namespace musica
template SolverType<micm::ProcessSet, micm::LinearSolver<SparseMatrixStandard, micm::LuDecomposition>>;
using RosenbrockStandard = micm::Solver<RosenbrockStandardType, micm::State<DenseMatrixStandard, SparseMatrixStandard>>;
using StandardState = micm::State<DenseMatrixStandard, SparseMatrixStandard>;
std::unique_ptr<RosenbrockStandard> rosenbrock_standard_;
std::pair<std::unique_ptr<RosenbrockStandard>, StandardState> rosenbrock_standard_;

/// @brief Vector-ordered Backward Euler
using BackwardEulerVectorType = typename micm::BackwardEulerSolverParameters::
template SolverType<micm::ProcessSet, micm::LinearSolver<SparseMatrixVector, micm::LuDecomposition>>;
using BackwardEuler = micm::Solver<BackwardEulerVectorType, micm::State<DenseMatrixVector, SparseMatrixVector>>;
std::unique_ptr<BackwardEuler> backward_euler_;
std::pair<std::unique_ptr<BackwardEuler>, VectorState> backward_euler_;

/// @brief Standard-ordered Backward Euler
using BackwardEulerStandardType = typename micm::BackwardEulerSolverParameters::
template SolverType<micm::ProcessSet, micm::LinearSolver<SparseMatrixStandard, micm::LuDecomposition>>;
using BackwardEulerStandard =
micm::Solver<BackwardEulerStandardType, micm::State<DenseMatrixStandard, SparseMatrixStandard>>;
std::unique_ptr<BackwardEulerStandard> backward_euler_standard_;
std::pair<std::unique_ptr<BackwardEulerStandard>, StandardState> backward_euler_standard_;

/// @brief Returns the number of grid cells
/// @return Number of grid cells
Expand All @@ -295,42 +277,6 @@ namespace musica
std::unique_ptr<micm::SolverParameters> solver_parameters_;
};

template<class T>
inline std::map<std::string, std::size_t> MICM::GetSpeciesOrdering(T &solver, Error *error)
{
try
{
micm::State state = solver->GetState();
DeleteError(error);
*error = NoError();
return state.variable_map_;
}
catch (const std::system_error &e)
{
DeleteError(error);
*error = ToError(e);
return std::map<std::string, std::size_t>();
}
}

template<class T>
inline std::map<std::string, std::size_t> MICM::GetUserDefinedReactionRatesOrdering(T &solver, Error *error)
{
try
{
micm::State state = solver->GetState();
DeleteError(error);
*error = NoError();
return state.custom_rate_parameter_map_;
}
catch (const std::system_error &e)
{
DeleteError(error);
*error = ToError(e);
return std::map<std::string, std::size_t>();
}
}

template<class T>
inline T MICM::GetSpeciesProperty(const std::string &species_name, const std::string &property_name, Error *error)
{
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ cmake.args = [
"-DMUSICA_ENABLE_TUVX=OFF",
"-DMUSICA_BUILD_FORTRAN_INTERFACE=OFF",
"-DMUSICA_ENABLE_TESTS=OFF",
"-DCMAKE_BUILD_TYPE=Release"
]

[project.urls]
Expand Down
36 changes: 18 additions & 18 deletions python/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@ namespace py = pybind11;
// Wraps micm.cpp
PYBIND11_MODULE(musica, m)
{
py::class_<musica::MICM>(m, "micm")
.def(py::init<>())
.def("__del__", [](musica::MICM &micm) {});
py::class_<musica::MICM>(m, "micm").def(py::init<>()).def("__del__", [](musica::MICM &micm) {});

py::enum_<musica::MICMSolver>(m, "micmsolver")
.value("rosenbrock", musica::MICMSolver::Rosenbrock)
.value("rosenbrock_standard_order", musica::MICMSolver::RosenbrockStandardOrder);
.value("rosenbrock", musica::MICMSolver::Rosenbrock)
.value("rosenbrock_standard_order", musica::MICMSolver::RosenbrockStandardOrder);

m.def(
"create_solver",
Expand Down Expand Up @@ -50,7 +48,7 @@ PYBIND11_MODULE(musica, m)
{
temperature_cpp.push_back(temperature.cast<double>());
}
else if(py::isinstance<py::list>(temperature))
else if (py::isinstance<py::list>(temperature))
{
py::list temperature_list = temperature.cast<py::list>();
temperature_cpp.reserve(len(temperature_list));
Expand All @@ -61,15 +59,16 @@ PYBIND11_MODULE(musica, m)
}
else
{
throw std::runtime_error("Temperature must be a list or a double. Got " + std::string(py::str(temperature.get_type()).cast<std::string>()));
throw std::runtime_error(
"Temperature must be a list or a double. Got " +
std::string(py::str(temperature.get_type()).cast<std::string>()));
}

std::vector<double> pressure_cpp;
if (py::isinstance<py::float_>(pressure))
{
pressure_cpp.push_back(pressure.cast<double>());
}
else if(py::isinstance<py::list>(pressure))
else if (py::isinstance<py::list>(pressure))
{
py::list pressure_list = pressure.cast<py::list>();
pressure_cpp.reserve(len(pressure_list));
Expand All @@ -80,14 +79,15 @@ PYBIND11_MODULE(musica, m)
}
else
{
throw std::runtime_error("Pressure must be a list or a double. Got " + std::string(py::str(pressure.get_type()).cast<std::string>()));
throw std::runtime_error(
"Pressure must be a list or a double. Got " + std::string(py::str(pressure.get_type()).cast<std::string>()));
}
std::vector<double> air_density_cpp;
if (py::isinstance<py::float_>(air_density))
{
air_density_cpp.push_back(air_density.cast<double>());
}
else if(py::isinstance<py::list>(air_density))
else if (py::isinstance<py::list>(air_density))
{
py::list air_density_list = air_density.cast<py::list>();
air_density_cpp.reserve(len(air_density_list));
Expand All @@ -98,16 +98,16 @@ PYBIND11_MODULE(musica, m)
}
else
{
throw std::runtime_error("Air density must be a list or a double. Got " + std::string(py::str(air_density.get_type()).cast<std::string>()));
throw std::runtime_error(
"Air density must be a list or a double. Got " +
std::string(py::str(air_density.get_type()).cast<std::string>()));
}

std::vector<double> concentrations_cpp;
concentrations_cpp.reserve(len(concentrations));
for (auto item : concentrations)
{
concentrations_cpp.push_back(item.cast<double>());
}

std::vector<double> custom_rate_parameters_cpp;
if (!custom_rate_parameters.is_none())
{
Expand Down Expand Up @@ -156,11 +156,11 @@ PYBIND11_MODULE(musica, m)

if (micm->solver_type_ == musica::MICMSolver::Rosenbrock)
{
map = micm->GetSpeciesOrdering(micm->rosenbrock_, &error);
map = micm->rosenbrock_.second.variable_map_;
}
else if (micm->solver_type_ == musica::MICMSolver::RosenbrockStandardOrder)
{
map = micm->GetSpeciesOrdering(micm->rosenbrock_standard_, &error);
map = micm->rosenbrock_standard_.second.variable_map_;
}

return map;
Expand All @@ -176,11 +176,11 @@ PYBIND11_MODULE(musica, m)

if (micm->solver_type_ == musica::MICMSolver::Rosenbrock)
{
map = micm->GetUserDefinedReactionRatesOrdering(micm->rosenbrock_, &error);
map = micm->rosenbrock_.second.custom_rate_parameter_map_;
}
else if (micm->solver_type_ == musica::MICMSolver::RosenbrockStandardOrder)
{
map = micm->GetUserDefinedReactionRatesOrdering(micm->rosenbrock_standard_, &error);
map = micm->rosenbrock_standard_.second.custom_rate_parameter_map_;
}

return map;
Expand Down
Loading

0 comments on commit b5ada9d

Please sign in to comment.