1
- from flask_login import LoginManager
1
+ from flask_login import current_user , LoginManager
2
+ from flask_oauthlib .provider import OAuth2Provider
2
3
from flask_sqlalchemy import SQLAlchemy
4
+
3
5
from sqlalchemy import UniqueConstraint
4
6
from sqlalchemy .orm import backref
5
7
from sqlalchemy .ext .hybrid import hybrid_property
6
8
9
+ from datetime import datetime , timedelta
7
10
import util
8
11
9
12
db = SQLAlchemy ()
10
13
login_manager = LoginManager ()
14
+ oauth = OAuth2Provider ()
11
15
12
16
13
17
class User (db .Model ):
@@ -92,6 +96,35 @@ class Event(db.Model):
92
96
link = db .Column (db .Unicode (length = 256 ))
93
97
removed = db .Column (db .Boolean , default = False )
94
98
99
+ # OAuth2 stuff
100
+ client_id = db .Column (db .String (40 ), unique = True )
101
+ client_secret = db .Column (db .String (55 ), unique = True , index = True , nullable = False )
102
+ is_confidential = db .Column (db .Boolean )
103
+ _redirect_uris = db .Column (db .Text )
104
+ _default_scopes = db .Column (db .Text )
105
+
106
+ @property
107
+ def client_type (self ):
108
+ if self .is_confidential :
109
+ return 'confidential'
110
+ return 'public'
111
+
112
+ @property
113
+ def redirect_uris (self ):
114
+ if self ._redirect_uris :
115
+ return self ._redirect_uris .split ()
116
+ return []
117
+
118
+ @property
119
+ def default_redirect_uri (self ):
120
+ return self .redirect_uris [0 ]
121
+
122
+ @property
123
+ def default_scopes (self ):
124
+ if self ._default_scopes :
125
+ return self ._default_scopes .split ()
126
+ return []
127
+
95
128
96
129
class EventVote (db .Model ):
97
130
__tablename__ = 'eventvotes'
@@ -102,3 +135,125 @@ class EventVote(db.Model):
102
135
event = db .relationship ('Event' , backref = 'votes' )
103
136
direction = db .Column (db .Boolean )
104
137
__table_args__ = (UniqueConstraint ('user_id' , 'event_id' , name = 'eventvote_user_event_uc' ),)
138
+
139
+
140
+ class Grant (db .Model ):
141
+ id = db .Column (db .Integer , primary_key = True )
142
+ user_id = db .Column (db .Integer , db .ForeignKey ('users.id' , ondelete = 'CASCADE' ))
143
+ user = db .relationship ('User' )
144
+
145
+ client_id = db .Column (db .String (40 ), db .ForeignKey ('events.client_id' ), nullable = False )
146
+ client = db .relationship ('Event' )
147
+
148
+ code = db .Column (db .String (255 ), index = True , nullable = False )
149
+
150
+ redirect_uri = db .Column (db .String (255 ))
151
+ expires = db .Column (db .DateTime )
152
+
153
+ _scopes = db .Column (db .Text )
154
+
155
+ def delete (self ):
156
+ db .session .delete (self )
157
+ db .session .commit ()
158
+ return self
159
+
160
+ @property
161
+ def scopes (self ):
162
+ if self ._scopes :
163
+ return self ._scopes .split ()
164
+ return []
165
+
166
+
167
+ class Token (db .Model ):
168
+ id = db .Column (db .Integer , primary_key = True )
169
+ client_id = db .Column (
170
+ db .String (40 ), db .ForeignKey ('events.client_id' ),
171
+ nullable = False ,
172
+ )
173
+ client = db .relationship ('Event' )
174
+
175
+ user_id = db .Column (
176
+ db .Integer , db .ForeignKey ('users.id' )
177
+ )
178
+ user = db .relationship ('User' )
179
+
180
+ token_type = db .Column (db .String (40 ))
181
+
182
+ access_token = db .Column (db .String (255 ), unique = True )
183
+ refresh_token = db .Column (db .String (255 ), unique = True )
184
+ expires = db .Column (db .DateTime )
185
+ _scopes = db .Column (db .Text )
186
+
187
+ def delete (self ):
188
+ db .session .delete (self )
189
+ db .session .commit ()
190
+ return self
191
+
192
+ @property
193
+ def scopes (self ):
194
+ if self ._scopes :
195
+ return self ._scopes .split ()
196
+ return []
197
+
198
+ def get_current_user ():
199
+ if current_user :
200
+ return current_user
201
+ return None
202
+
203
+
204
+ @oauth .clientgetter
205
+ def load_client (client_id ):
206
+ return Event .query .filter_by (client_id = client_id ).first ()
207
+
208
+
209
+ @oauth .grantgetter
210
+ def load_grant (client_id , code ):
211
+ return Grant .query .filter_by (client_id = client_id , code = code ).first ()
212
+
213
+ @oauth .grantsetter
214
+ def save_grant (client_id , code , request , * args , ** kwargs ):
215
+ expires = datetime .utcnow () + timedelta (seconds = 100 )
216
+ grant = Grant (
217
+ client_id = client_id ,
218
+ code = code ['code' ],
219
+ redirect_uri = request .redirect_uri ,
220
+ _scopes = ' ' .join (request .scopes ),
221
+ user = get_current_user (),
222
+ expires = expires
223
+ )
224
+ db .session .add (grant )
225
+ db .session .commit ()
226
+ return grant
227
+
228
+
229
+ @oauth .tokengetter
230
+ def load_token (access_token = None , refresh_token = None ):
231
+ if access_token :
232
+ return Token .query .filter_by (access_token = access_token ).first ()
233
+ elif refresh_token :
234
+ return Token .query .filter_by (refresh_token = refresh_token ).first ()
235
+
236
+
237
+ @oauth .tokensetter
238
+ def save_token (token , request , * args , ** kwargs ):
239
+ toks = Token .query .filter_by (client_id = request .client .client_id ,
240
+ user_id = request .user .id )
241
+ # make sure that every client has only one token connected to a user
242
+ for t in toks :
243
+ db .session .delete (t )
244
+
245
+ expires_in = token .get ('expires_in' )
246
+ expires = datetime .utcnow () + timedelta (seconds = expires_in )
247
+
248
+ tok = Token (
249
+ access_token = token ['access_token' ],
250
+ refresh_token = token ['refresh_token' ],
251
+ token_type = token ['token_type' ],
252
+ _scopes = token ['scope' ],
253
+ expires = expires ,
254
+ client_id = request .client .client_id ,
255
+ user_id = request .user .id ,
256
+ )
257
+ db .session .add (tok )
258
+ db .session .commit ()
259
+ return tok
0 commit comments