Skip to content

Commit 9a5d267

Browse files
committed
pgc state object can now contain arbitrary attributes
1 parent efd58f9 commit 9a5d267

File tree

3 files changed

+112
-52
lines changed

3 files changed

+112
-52
lines changed

stormvogel/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ class Action:
205205
labels: The labels of this action. Corresponds to Storm labels.
206206
"""
207207

208-
name: str
208+
name: str # TODO name is stormpy choice label or we don't need a name at all?
209209
labels: frozenset[str]
210210

211211
def __str__(self):

stormvogel/pgc.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,36 @@
33
from typing import Callable
44

55

6+
@dataclass
7+
class Action:
8+
"""pgc action object. Contains a list of labels"""
9+
10+
labels: list[str]
11+
12+
613
@dataclass
714
class State:
8-
f: dict[str, int]
15+
"""pgc state object. Can contain any number of any type of arguments"""
16+
17+
def __init__(self, **kwargs):
18+
for key, value in kwargs.items():
19+
setattr(self, key, value)
20+
21+
def __repr__(self):
22+
return f"State({self.__dict__})"
923

1024
def __hash__(self):
11-
return hash(self.x)
25+
return hash(self.__dict__)
1226

1327
def __eq__(self, other):
1428
if isinstance(other, State):
15-
return self.f == other.f
29+
return self.__dict__ == other.__dict__
1630
return False
1731

1832

19-
@dataclass
20-
class Action:
21-
labels: list[str]
22-
23-
2433
def build_pgc(
2534
delta, # Callable[[State, Action], list[tuple[float, State]]],
26-
initial_state_pgc: State,
35+
initial_state_pgc: State, # TODO rewards function, label function
2736
available_actions: Callable[[State], list[Action]] | None = None,
2837
modeltype: stormvogel.model.ModelType = stormvogel.model.ModelType.MDP,
2938
) -> stormvogel.model.Model:
@@ -46,7 +55,9 @@ def build_pgc(
4655

4756
# we create the model with the given type and initial state
4857
model.new_state(
49-
labels=["init"], features=initial_state_pgc.f, name=str(initial_state_pgc.f)
58+
labels=["init"],
59+
features=initial_state_pgc.__dict__,
60+
name=str(initial_state_pgc.__dict__),
5061
)
5162

5263
# we continue calling delta and adding new states until no states are
@@ -82,14 +93,15 @@ def build_pgc(
8293
if tuple[1] not in states_seen:
8394
states_seen.append(tuple[1])
8495
new_state = model.new_state(
85-
name=str(tuple[1].f), features=tuple[1].f
96+
name=str(tuple[1].__dict__), features=tuple[1].__dict__
8697
)
8798
branch.append((tuple[0], new_state))
8899
states_to_be_visited.append(tuple[1])
89100
else:
90-
# TODO what if there are multiple states with the same label? use names?
101+
# print(tuple[1].__dict__)
102+
# print(model.states)
91103
branch.append(
92-
(tuple[0], model.get_state_by_name(str(tuple[1].f)))
104+
(tuple[0], model.get_state_by_name(str(tuple[1].__dict__)))
93105
)
94106
if branch != []:
95107
transition[stormvogel_action] = stormvogel.model.Branch(branch)
@@ -101,19 +113,20 @@ def build_pgc(
101113
if tuple[1] not in states_seen:
102114
states_seen.append(tuple[1])
103115
new_state = model.new_state(
104-
name=str(tuple[1].f), features=tuple[1].f
116+
name=str(tuple[1].__dict__), features=tuple[1].__dict__
105117
)
106118

107119
branch.append((tuple[0], new_state))
108120
states_to_be_visited.append(tuple[1])
109121
else:
110-
# TODO what if there are multiple states with the same label? use names?
111-
branch.append((tuple[0], model.get_state_by_name(str(tuple[1].f))))
122+
branch.append(
123+
(tuple[0], model.get_state_by_name(str(tuple[1].__dict__)))
124+
)
112125
if branch != []:
113126
transition[stormvogel.model.EmptyAction] = stormvogel.model.Branch(
114127
branch
115128
)
116-
s = model.get_state_by_name(str(state.f))
129+
s = model.get_state_by_name(str(state.__dict__))
117130
assert s is not None
118131
model.add_transitions(
119132
s,

tests/test_pgc.py

Lines changed: 81 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
from stormvogel import pgc
22
import stormvogel.model
33
import math
4-
import stormpy
54
import stormvogel.mapping
65

76

87
def test_pgc_mdp():
98
# we build the model with pgc:
109
N = 2
1110
p = 0.5
12-
initial_state = pgc.State({"x": math.floor(N / 2)})
11+
initial_state = pgc.State(x=math.floor(N / 2))
1312

1413
left = pgc.Action(["left"])
1514
right = pgc.Action(["right"])
@@ -19,23 +18,21 @@ def available_actions(s: pgc.State):
1918

2019
def delta(s: pgc.State, action: pgc.Action):
2120
if action == left:
22-
print(s.f)
23-
print(s.f["x"])
2421
return (
2522
[
26-
(p, pgc.State({"x": s.f["x"] + 1})),
27-
(1 - p, pgc.State({"x": s.f["x"]})),
23+
(p, pgc.State(x=s.x + 1)),
24+
(1 - p, pgc.State(x=s.x)),
2825
]
29-
if s.f["x"] < N
26+
if s.x < N
3027
else []
3128
)
3229
elif action == right:
3330
return (
3431
[
35-
(p, pgc.State({"x": s.f["x"] - 1})),
36-
(1 - p, pgc.State({"x": s.f["x"]})),
32+
(p, pgc.State(x=s.x - 1)),
33+
(1 - p, pgc.State(x=s.x)),
3734
]
38-
if s.f["x"] > 0
35+
if s.x > 0
3936
else []
4037
)
4138

@@ -63,41 +60,38 @@ def delta(s: pgc.State, action: pgc.Action):
6360
model.add_transitions(state2, stormvogel.model.Transition({right: branch2}))
6461
model.add_transitions(state0, stormvogel.model.Transition({left: branch0}))
6562

66-
print(model)
67-
print(pgc_model)
68-
6963
assert model == pgc_model
7064

7165

7266
def test_pgc_dtmc():
7367
# we build the model with pgc:
7468
p = 0.5
75-
initial_state = pgc.State({"s": 0})
69+
initial_state = pgc.State(s=0)
7670

7771
def delta(s: pgc.State):
78-
match s.f["s"]:
72+
match s.s:
7973
case 0:
80-
return [(p, pgc.State({"s": 1})), (1 - p, pgc.State({"s": 2}))]
74+
return [(p, pgc.State(s=1)), (1 - p, pgc.State(s=2))]
8175
case 1:
82-
return [(p, pgc.State({"s": 3})), (1 - p, pgc.State({"s": 4}))]
76+
return [(p, pgc.State(s=3)), (1 - p, pgc.State(s=4))]
8377
case 2:
84-
return [(p, pgc.State({"s": 5})), (1 - p, pgc.State({"s": 6}))]
78+
return [(p, pgc.State(s=5)), (1 - p, pgc.State(s=6))]
8579
case 3:
86-
return [(p, pgc.State({"s": 1})), (1 - p, pgc.State({"s": 7, "d": 1}))]
80+
return [(p, pgc.State(s=1)), (1 - p, pgc.State(s=7, d=1))]
8781
case 4:
8882
return [
89-
(p, pgc.State({"s": 7, "d": 2})),
90-
(1 - p, pgc.State({"s": 7, "d": 3})),
83+
(p, pgc.State(s=7, d=2)),
84+
(1 - p, pgc.State(s=7, d=3)),
9185
]
9286
case 5:
9387
return [
94-
(p, pgc.State({"s": 7, "d": 4})),
95-
(1 - p, pgc.State({"s": 7, "d": 5})),
88+
(p, pgc.State(s=7, d=4)),
89+
(1 - p, pgc.State(s=7, d=5)),
9690
]
9791
case 6:
98-
return [(p, pgc.State({"s": 2})), (1 - p, pgc.State({"s": 7, "d": 6}))]
92+
return [(p, pgc.State(s=2)), (1 - p, pgc.State(s=7, d=6))]
9993
case 7:
100-
return [(1, pgc.State({"s": 7}))]
94+
return [(1, pgc.State(s=7))]
10195

10296
pgc_model = stormvogel.pgc.build_pgc(
10397
delta=delta,
@@ -106,12 +100,65 @@ def delta(s: pgc.State):
106100
)
107101

108102
# we build the model in the regular way:
109-
path = stormpy.examples.files.prism_dtmc_die
110-
prism_program = stormpy.parse_prism_program(path)
111-
formula_str = "P=? [F s=7 & d=2]"
112-
properties = stormpy.parse_properties(formula_str, prism_program)
113-
model = stormpy.build_model(prism_program, properties)
114-
print(dir(model.states[0]))
115-
stormvogel_model = stormvogel.mapping.stormpy_to_stormvogel(model)
116-
print(pgc_model)
117-
print(stormvogel_model)
103+
model = stormvogel.model.new_dtmc()
104+
model.states[0].features = {"s": 0}
105+
model.set_transitions(
106+
model.get_initial_state(),
107+
[
108+
(1 / 2, model.new_state(features={"s": 1})),
109+
(1 / 2, model.new_state(features={"s": 2})),
110+
],
111+
)
112+
model.set_transitions(
113+
model.get_state_by_id(1),
114+
[
115+
(1 / 2, model.new_state(features={"s": 3})),
116+
(1 / 2, model.new_state(features={"s": 4})),
117+
],
118+
)
119+
model.set_transitions(
120+
model.get_state_by_id(2),
121+
[
122+
(1 / 2, model.new_state(features={"s": 5})),
123+
(1 / 2, model.new_state(features={"s": 6})),
124+
],
125+
)
126+
model.set_transitions(
127+
model.get_state_by_id(3),
128+
[
129+
(1 / 2, model.get_state_by_id(1)),
130+
(1 / 2, model.new_state(features={"s": 7, "d": 1})),
131+
],
132+
)
133+
model.set_transitions(
134+
model.get_state_by_id(4),
135+
[
136+
(1 / 2, model.new_state(features={"s": 7, "d": 2})),
137+
(1 / 2, model.new_state(features={"s": 7, "d": 3})),
138+
],
139+
)
140+
model.set_transitions(
141+
model.get_state_by_id(5),
142+
[
143+
(1 / 2, model.new_state(features={"s": 7, "d": 4})),
144+
(1 / 2, model.new_state(features={"s": 7, "d": 5})),
145+
],
146+
)
147+
model.set_transitions(
148+
model.get_state_by_id(6),
149+
[
150+
(1 / 2, model.get_state_by_id(2)),
151+
(1 / 2, model.new_state(features={"s": 7, "d": 6})),
152+
],
153+
)
154+
model.set_transitions(
155+
model.get_state_by_id(7), [(1, model.new_state(features={"s": 7}))]
156+
)
157+
model.set_transitions(model.get_state_by_id(8), [(1, model.get_state_by_id(13))])
158+
model.set_transitions(model.get_state_by_id(9), [(1, model.get_state_by_id(13))])
159+
model.set_transitions(model.get_state_by_id(10), [(1, model.get_state_by_id(13))])
160+
model.set_transitions(model.get_state_by_id(11), [(1, model.get_state_by_id(13))])
161+
model.set_transitions(model.get_state_by_id(12), [(1, model.get_state_by_id(13))])
162+
model.set_transitions(model.get_state_by_id(13), [(1, model.get_state_by_id(13))])
163+
164+
assert pgc_model == model

0 commit comments

Comments
 (0)