Skip to content

Commit c9f26e3

Browse files
authored
fix: sqlalchemy session timeout and InvalidRequestError
1 parent 218b193 commit c9f26e3

File tree

2 files changed

+94
-86
lines changed

2 files changed

+94
-86
lines changed

casbin_sqlalchemy_adapter/adapter.py

Lines changed: 65 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from contextlib import contextmanager
2+
13
from casbin import persist
24
from sqlalchemy import Column, Integer, String
3-
from sqlalchemy import create_engine, and_, or_
5+
from sqlalchemy import create_engine, or_
46
from sqlalchemy.ext.declarative import declarative_base
57
from sqlalchemy.orm import sessionmaker
68

@@ -53,31 +55,44 @@ def __init__(self, engine, db_class=None, filtered=False):
5355
if db_class is None:
5456
db_class = CasbinRule
5557
self._db_class = db_class
56-
session = sessionmaker(bind=self._engine)
57-
self._session = session()
58+
self.session_local = sessionmaker(bind=self._engine)
5859

5960
Base.metadata.create_all(self._engine)
6061
self._filtered = filtered
6162

63+
@contextmanager
64+
def _session_scope(self):
65+
"""Provide a transactional scope around a series of operations."""
66+
session = self.session_local()
67+
try:
68+
yield session
69+
session.commit()
70+
except Exception as e:
71+
session.rollback()
72+
raise e
73+
finally:
74+
session.close()
75+
6276
def load_policy(self, model):
6377
"""loads all policy rules from the storage."""
64-
lines = self._session.query(self._db_class).all()
65-
for line in lines:
66-
persist.load_policy_line(str(line), model)
67-
self._commit()
78+
with self._session_scope() as session:
79+
lines = session.query(self._db_class).all()
80+
for line in lines:
81+
persist.load_policy_line(str(line), model)
6882

6983
def is_filtered(self):
7084
return self._filtered
7185

7286
def load_filtered_policy(self, model, filter) -> None:
7387
"""loads all policy rules from the storage."""
74-
query = self._session.query(self._db_class)
75-
filters = self.filter_query(query, filter)
76-
filters = filters.all()
88+
with self._session_scope() as session:
89+
query = session.query(self._db_class)
90+
filters = self.filter_query(query, filter)
91+
filters = filters.all()
7792

78-
for line in filters:
79-
persist.load_policy_line(str(line), model)
80-
self._filtered = True
93+
for line in filters:
94+
persist.load_policy_line(str(line), model)
95+
self._filtered = True
8196

8297
def filter_query(self, querydb, filter):
8398
if len(filter.ptype) > 0:
@@ -97,76 +112,68 @@ def filter_query(self, querydb, filter):
97112
return querydb.order_by(CasbinRule.id)
98113

99114
def _save_policy_line(self, ptype, rule):
100-
line = self._db_class(ptype=ptype)
101-
for i, v in enumerate(rule):
102-
setattr(line, "v{}".format(i), v)
103-
self._session.add(line)
104-
105-
def _commit(self):
106-
self._session.commit()
115+
with self._session_scope() as session:
116+
line = self._db_class(ptype=ptype)
117+
for i, v in enumerate(rule):
118+
setattr(line, "v{}".format(i), v)
119+
session.add(line)
107120

108121
def save_policy(self, model):
109122
"""saves all policy rules to the storage."""
110-
query = self._session.query(self._db_class)
111-
query.delete()
112-
for sec in ["p", "g"]:
113-
if sec not in model.model.keys():
114-
continue
115-
for ptype, ast in model.model[sec].items():
116-
for rule in ast.policy:
117-
self._save_policy_line(ptype, rule)
118-
self._commit()
123+
with self._session_scope() as session:
124+
query = session.query(self._db_class)
125+
query.delete()
126+
for sec in ["p", "g"]:
127+
if sec not in model.model.keys():
128+
continue
129+
for ptype, ast in model.model[sec].items():
130+
for rule in ast.policy:
131+
self._save_policy_line(ptype, rule)
119132
return True
120133

121134
def add_policy(self, sec, ptype, rule):
122135
"""adds a policy rule to the storage."""
123136
self._save_policy_line(ptype, rule)
124-
self._commit()
125137

126138
def add_policies(self, sec, ptype, rules):
127139
"""adds a policy rules to the storage."""
128140
for rule in rules:
129141
self._save_policy_line(ptype, rule)
130-
self._commit()
131142

132143
def remove_policy(self, sec, ptype, rule):
133144
"""removes a policy rule from the storage."""
134-
query = self._session.query(self._db_class)
135-
query = query.filter(self._db_class.ptype == ptype)
136-
for i, v in enumerate(rule):
137-
query = query.filter(getattr(self._db_class, "v{}".format(i)) == v)
138-
r = query.delete()
139-
self._commit()
145+
with self._session_scope() as session:
146+
query = session.query(self._db_class)
147+
query = query.filter(self._db_class.ptype == ptype)
148+
for i, v in enumerate(rule):
149+
query = query.filter(getattr(self._db_class, "v{}".format(i)) == v)
150+
r = query.delete()
140151

141152
return True if r > 0 else False
142153

143154
def remove_policies(self, sec, ptype, rules):
144155
"""removes a policy rules from the storage."""
145-
query = self._session.query(self._db_class)
146-
query = query.filter(self._db_class.ptype == ptype)
147-
for rule in rules:
148-
query = query.filter(or_(getattr(self._db_class, "v{}".format(i)) == v for i, v in enumerate(rule)))
149-
query.delete()
150-
self._commit()
151-
156+
with self._session_scope() as session:
157+
query = session.query(self._db_class)
158+
query = query.filter(self._db_class.ptype == ptype)
159+
for rule in rules:
160+
query = query.filter(or_(getattr(self._db_class, "v{}".format(i)) == v for i, v in enumerate(rule)))
161+
query.delete()
152162

153163
def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
154164
"""removes policy rules that match the filter from the storage.
155165
This is part of the Auto-Save feature.
156166
"""
157-
query = self._session.query(self._db_class)
158-
query = query.filter(self._db_class.ptype == ptype)
159-
if not (0 <= field_index <= 5):
160-
return False
161-
if not (1 <= field_index + len(field_values) <= 6):
162-
return False
163-
for i, v in enumerate(field_values):
164-
if v != '':
165-
query = query.filter(getattr(self._db_class, "v{}".format(field_index + i)) == v)
166-
r = query.delete()
167-
self._commit()
167+
with self._session_scope() as session:
168+
query = session.query(self._db_class)
169+
query = query.filter(self._db_class.ptype == ptype)
170+
if not (0 <= field_index <= 5):
171+
return False
172+
if not (1 <= field_index + len(field_values) <= 6):
173+
return False
174+
for i, v in enumerate(field_values):
175+
if v != '':
176+
query = query.filter(getattr(self._db_class, "v{}".format(field_index + i)) == v)
177+
r = query.delete()
168178

169179
return True if r > 0 else False
170-
171-
def __del__(self):
172-
self._session.close()

tests/test_adapter.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import casbin
99
import os
1010

11+
1112
def get_fixture(path):
1213
dir_path = os.path.split(os.path.realpath(__file__))[0] + "/"
1314
return os.path.abspath(dir_path + path)
@@ -140,7 +141,7 @@ def test_str(self):
140141
self.assertEqual(str(rule), 'p, data2_admin, data2, read')
141142
rule = CasbinRule(ptype='p', v0='data2_admin', v1='data2', v2='write')
142143
self.assertEqual(str(rule), 'p, data2_admin, data2, write')
143-
rule = CasbinRule(ptype='g', v0='alice', v1 = 'data2_admin')
144+
rule = CasbinRule(ptype='g', v0='alice', v1='data2_admin')
144145
self.assertEqual(str(rule), 'g, alice, data2_admin')
145146

146147
def test_repr(self):
@@ -158,13 +159,13 @@ def test_repr(self):
158159
s.close()
159160

160161
def test_filtered_policy(self):
161-
e= get_enforcer()
162+
e = get_enforcer()
162163
filter = Filter()
163164

164165
filter.ptype = ['p']
165166
e.load_filtered_policy(filter)
166167
self.assertTrue(e.enforce('alice', 'data1', 'read'))
167-
self.assertFalse(e.enforce('alice','data1','write'))
168+
self.assertFalse(e.enforce('alice', 'data1', 'write'))
168169
self.assertFalse(e.enforce('alice', 'data2', 'read'))
169170
self.assertFalse(e.enforce('alice', 'data2', 'write'))
170171
self.assertFalse(e.enforce('bob', 'data1', 'read'))
@@ -176,105 +177,105 @@ def test_filtered_policy(self):
176177
filter.v0 = ['alice']
177178
e.load_filtered_policy(filter)
178179
self.assertTrue(e.enforce('alice', 'data1', 'read'))
179-
self.assertFalse(e.enforce('alice','data1','write'))
180+
self.assertFalse(e.enforce('alice', 'data1', 'write'))
180181
self.assertFalse(e.enforce('alice', 'data2', 'read'))
181182
self.assertFalse(e.enforce('alice', 'data2', 'write'))
182183
self.assertFalse(e.enforce('bob', 'data1', 'read'))
183184
self.assertFalse(e.enforce('bob', 'data1', 'write'))
184185
self.assertFalse(e.enforce('bob', 'data2', 'read'))
185186
self.assertFalse(e.enforce('bob', 'data2', 'write'))
186-
self.assertFalse(e.enforce('data2_admin', 'data2','read'))
187-
self.assertFalse(e.enforce('data2_admin', 'data2','write'))
187+
self.assertFalse(e.enforce('data2_admin', 'data2', 'read'))
188+
self.assertFalse(e.enforce('data2_admin', 'data2', 'write'))
188189

189190
filter.v0 = ['bob']
190191
e.load_filtered_policy(filter)
191192
self.assertFalse(e.enforce('alice', 'data1', 'read'))
192-
self.assertFalse(e.enforce('alice','data1','write'))
193+
self.assertFalse(e.enforce('alice', 'data1', 'write'))
193194
self.assertFalse(e.enforce('alice', 'data2', 'read'))
194195
self.assertFalse(e.enforce('alice', 'data2', 'write'))
195196
self.assertFalse(e.enforce('bob', 'data1', 'read'))
196197
self.assertFalse(e.enforce('bob', 'data1', 'write'))
197198
self.assertFalse(e.enforce('bob', 'data2', 'read'))
198199
self.assertTrue(e.enforce('bob', 'data2', 'write'))
199-
self.assertFalse(e.enforce('data2_admin', 'data2','read'))
200-
self.assertFalse(e.enforce('data2_admin', 'data2','write'))
200+
self.assertFalse(e.enforce('data2_admin', 'data2', 'read'))
201+
self.assertFalse(e.enforce('data2_admin', 'data2', 'write'))
201202

202203
filter.v0 = ['data2_admin']
203204
e.load_filtered_policy(filter)
204-
self.assertTrue(e.enforce('data2_admin', 'data2','read'))
205-
self.assertTrue(e.enforce('data2_admin', 'data2','read'))
205+
self.assertTrue(e.enforce('data2_admin', 'data2', 'read'))
206+
self.assertTrue(e.enforce('data2_admin', 'data2', 'read'))
206207
self.assertFalse(e.enforce('alice', 'data1', 'read'))
207-
self.assertFalse(e.enforce('alice','data1','write'))
208+
self.assertFalse(e.enforce('alice', 'data1', 'write'))
208209
self.assertFalse(e.enforce('alice', 'data2', 'read'))
209210
self.assertFalse(e.enforce('alice', 'data2', 'write'))
210211
self.assertFalse(e.enforce('bob', 'data1', 'read'))
211212
self.assertFalse(e.enforce('bob', 'data1', 'write'))
212213
self.assertFalse(e.enforce('bob', 'data2', 'read'))
213214
self.assertFalse(e.enforce('bob', 'data2', 'write'))
214215

215-
filter.v0 = ['alice','bob']
216+
filter.v0 = ['alice', 'bob']
216217
e.load_filtered_policy(filter)
217218
self.assertTrue(e.enforce('alice', 'data1', 'read'))
218-
self.assertFalse(e.enforce('alice','data1','write'))
219+
self.assertFalse(e.enforce('alice', 'data1', 'write'))
219220
self.assertFalse(e.enforce('alice', 'data2', 'read'))
220221
self.assertFalse(e.enforce('alice', 'data2', 'write'))
221222
self.assertFalse(e.enforce('bob', 'data1', 'read'))
222223
self.assertFalse(e.enforce('bob', 'data1', 'write'))
223224
self.assertFalse(e.enforce('bob', 'data2', 'read'))
224225
self.assertTrue(e.enforce('bob', 'data2', 'write'))
225-
self.assertFalse(e.enforce('data2_admin', 'data2','read'))
226-
self.assertFalse(e.enforce('data2_admin', 'data2','write'))
226+
self.assertFalse(e.enforce('data2_admin', 'data2', 'read'))
227+
self.assertFalse(e.enforce('data2_admin', 'data2', 'write'))
227228

228229
filter.v0 = []
229230
filter.v1 = ['data1']
230231
e.load_filtered_policy(filter)
231232
self.assertTrue(e.enforce('alice', 'data1', 'read'))
232-
self.assertFalse(e.enforce('alice','data1','write'))
233+
self.assertFalse(e.enforce('alice', 'data1', 'write'))
233234
self.assertFalse(e.enforce('alice', 'data2', 'read'))
234235
self.assertFalse(e.enforce('alice', 'data2', 'write'))
235236
self.assertFalse(e.enforce('bob', 'data1', 'read'))
236237
self.assertFalse(e.enforce('bob', 'data1', 'write'))
237238
self.assertFalse(e.enforce('bob', 'data2', 'read'))
238239
self.assertFalse(e.enforce('bob', 'data2', 'write'))
239-
self.assertFalse(e.enforce('data2_admin', 'data2','read'))
240-
self.assertFalse(e.enforce('data2_admin', 'data2','write'))
240+
self.assertFalse(e.enforce('data2_admin', 'data2', 'read'))
241+
self.assertFalse(e.enforce('data2_admin', 'data2', 'write'))
241242

242243
filter.v1 = ['data2']
243244
e.load_filtered_policy(filter)
244245
self.assertFalse(e.enforce('alice', 'data1', 'read'))
245-
self.assertFalse(e.enforce('alice','data1','write'))
246+
self.assertFalse(e.enforce('alice', 'data1', 'write'))
246247
self.assertFalse(e.enforce('alice', 'data2', 'read'))
247248
self.assertFalse(e.enforce('alice', 'data2', 'write'))
248249
self.assertFalse(e.enforce('bob', 'data1', 'read'))
249250
self.assertFalse(e.enforce('bob', 'data1', 'write'))
250251
self.assertFalse(e.enforce('bob', 'data2', 'read'))
251252
self.assertTrue(e.enforce('bob', 'data2', 'write'))
252-
self.assertTrue(e.enforce('data2_admin', 'data2','read'))
253-
self.assertTrue(e.enforce('data2_admin', 'data2','write'))
253+
self.assertTrue(e.enforce('data2_admin', 'data2', 'read'))
254+
self.assertTrue(e.enforce('data2_admin', 'data2', 'write'))
254255

255256
filter.v1 = []
256257
filter.v2 = ['read']
257258
e.load_filtered_policy(filter)
258259
self.assertTrue(e.enforce('alice', 'data1', 'read'))
259-
self.assertFalse(e.enforce('alice','data1','write'))
260+
self.assertFalse(e.enforce('alice', 'data1', 'write'))
260261
self.assertFalse(e.enforce('alice', 'data2', 'read'))
261262
self.assertFalse(e.enforce('alice', 'data2', 'write'))
262263
self.assertFalse(e.enforce('bob', 'data1', 'read'))
263264
self.assertFalse(e.enforce('bob', 'data1', 'write'))
264265
self.assertFalse(e.enforce('bob', 'data2', 'read'))
265266
self.assertFalse(e.enforce('bob', 'data2', 'write'))
266-
self.assertTrue(e.enforce('data2_admin', 'data2','read'))
267-
self.assertFalse(e.enforce('data2_admin', 'data2','write'))
267+
self.assertTrue(e.enforce('data2_admin', 'data2', 'read'))
268+
self.assertFalse(e.enforce('data2_admin', 'data2', 'write'))
268269

269270
filter.v2 = ['write']
270271
e.load_filtered_policy(filter)
271272
self.assertFalse(e.enforce('alice', 'data1', 'read'))
272-
self.assertFalse(e.enforce('alice','data1','write'))
273+
self.assertFalse(e.enforce('alice', 'data1', 'write'))
273274
self.assertFalse(e.enforce('alice', 'data2', 'read'))
274275
self.assertFalse(e.enforce('alice', 'data2', 'write'))
275276
self.assertFalse(e.enforce('bob', 'data1', 'read'))
276277
self.assertFalse(e.enforce('bob', 'data1', 'write'))
277278
self.assertFalse(e.enforce('bob', 'data2', 'read'))
278279
self.assertTrue(e.enforce('bob', 'data2', 'write'))
279-
self.assertFalse(e.enforce('data2_admin', 'data2','read'))
280-
self.assertTrue(e.enforce('data2_admin', 'data2','write'))
280+
self.assertFalse(e.enforce('data2_admin', 'data2', 'read'))
281+
self.assertTrue(e.enforce('data2_admin', 'data2', 'write'))

0 commit comments

Comments
 (0)