Skip to content

Commit

Permalink
Added convenience functions for state valuations (#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
volkm authored Dec 29, 2024
2 parents ad46d5e + d9a278e commit c82fee7
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/storage/prism.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ void define_prism(py::module& m) {
program.def_property_readonly("constants", &Program::getConstants, "Get Program Constants")
.def_property_readonly("global_boolean_variables", &Program::getGlobalBooleanVariables, "Retrieves the global boolean variables of the program")
.def_property_readonly("global_integer_variables", &Program::getGlobalIntegerVariables, "Retrieves the global integer variables of the program")
.def_property_readonly("variables", &Program::getAllExpressionVariables, "Retrieves all expression variables (including constants) of the program")
.def("get_variables", &Program::getAllExpressionVariables, py::arg("include_constants") = true, "Get all expression variables (and constants) used by the program")
.def_property_readonly("nr_modules", &storm::prism::Program::getNumberOfModules, "Number of modules")
.def_property_readonly("modules", &storm::prism::Program::getModules, "Modules in the program")
.def_property_readonly("model_type", &storm::prism::Program::getModelType, "Model type")
Expand Down Expand Up @@ -71,7 +73,6 @@ void define_prism(py::module& m) {
return program.toJani(properties, allVariablesGlobal, suffix);
}, "Transform to Jani program", py::arg("properties"), py::arg("all_variables_global") = true, py::arg("suffix") = "")
.def("__str__", &streamToString<storm::prism::Program>)
.def_property_readonly("variables", &Program::getAllExpressionVariables, "Get all Expression Variables used by the program")
.def("get_label_expression", [](storm::prism::Program const& program, std::string const& label){
return program.getLabelExpression(label);
}, "Get the expression of the given label.", py::arg("label"))
Expand Down
3 changes: 3 additions & 0 deletions src/storage/valuation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ void define_statevaluation(py::module& m) {
.def("get_boolean_value", &storm::storage::sparse::StateValuations::getBooleanValue, py::arg("state"), py::arg("variable"))
.def("get_integer_value", &storm::storage::sparse::StateValuations::getIntegerValue, py::arg("state"), py::arg("variable"))
.def("get_rational_value", &storm::storage::sparse::StateValuations::getRationalValue, py::arg("state"), py::arg("variable"))
.def("get_boolean_values_states", &storm::storage::sparse::StateValuations::getBooleanValues, py::arg("variable"), "Get the value of the Boolean variable of all states. The i'th entry represents the value of state i.")
.def("get_integer_values_states", &storm::storage::sparse::StateValuations::getIntegerValues, py::arg("variable"), "Get the value of the integer variable of all states. The i'th entry represents the value of state i.")
.def("get_rational_values_states", &storm::storage::sparse::StateValuations::getRationalValues, py::arg("variable"), "Get the value of the rational variable of all states. The i'th entry represents the value of state i.")
.def("get_string", &storm::storage::sparse::StateValuations::toString, py::arg("state"), py::arg("pretty")=true, py::arg("selected_variables")=boost::none)
.def("get_json", &toJson, py::arg("state"), py::arg("selected_variables")=boost::none)
.def("get_nr_of_states", &storm::storage::sparse::StateValuations::getNumberOfStates);
Expand Down
16 changes: 12 additions & 4 deletions tests/storage/test_prism.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,15 @@ def test_prism_to_jani_repetitive(self):
assert len(new_properties) == len(orig_properties)

def test_prism_variables(selfs):
program = stormpy.parse_prism_program(get_example_path("dtmc", "die.pm"))
module = program.modules[0]
assert len(module.integer_variables) == 2
assert len(module.boolean_variables) == 0
program = stormpy.parse_prism_program(get_example_path("mdp", "leader3.nm"))
assert program.nr_modules == 3
assert program.model_type == stormpy.PrismModelType.MDP
assert not program.has_undefined_constants
assert len(program.constants) == 1
assert len(program.global_boolean_variables) == 0
assert len(program.global_integer_variables) == 0
assert len(program.get_variables()) == 16
assert len(program.get_variables(False)) == 15
for module in program.modules:
assert len(module.integer_variables) == 5
assert len(module.boolean_variables) == 0
22 changes: 22 additions & 0 deletions tests/storage/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,28 @@ def test_states_dtmc(self):
state = states[5]
assert state.id == 5

def test_state_valuations(self):
program = stormpy.parse_prism_program(get_example_path("dtmc", "die.pm"))
options = stormpy.BuilderOptions()
options.set_build_state_valuations()
model = stormpy.build_sparse_model_with_options(program, options)
assert model.has_state_valuations()
for var in program.get_variables():
if var.name == "s":
var_s = var
elif var.name == "d":
var_d = var
else:
assert False
# Values of s should be 0, ..., 6 and then 6 times 7
vals_s = model.state_valuations.get_integer_values_states(var_s)
comp_s = [i for i in range(0, 7)] + [7] * 6
assert vals_s == comp_s
# Values of d should be 7 times 0 and then 1, ..., 6
vals_d = model.state_valuations.get_integer_values_states(var_d)
comp_d = [0] * 7 + [i for i in range(1, 7)]
assert vals_d == comp_d

def test_states_mdp(self):
model = stormpy.build_sparse_model_from_explicit(get_example_path("mdp", "two_dice.tra"), get_example_path("mdp", "two_dice.lab"))
i = 0
Expand Down
4 changes: 2 additions & 2 deletions tests/storage/test_state_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ def _load_program(filename):
program = program.substitute_constants()

expression_parser = stormpy.ExpressionParser(program.expression_manager)
expression_parser.set_identifier_mapping({var.name: var.get_expression() for var in program.variables})
expression_parser.set_identifier_mapping({var.name: var.get_expression() for var in program.get_variables()})
return program, expression_parser


def _find_variable(program, name):
for var in program.variables:
for var in program.get_variables():
if var.name is name:
return var
return None
Expand Down

0 comments on commit c82fee7

Please sign in to comment.