11use std:: sync:: { Arc , OnceLock } ;
22
3- use axum:: { extract:: State , http:: StatusCode , routing:: post, Json , Router } ;
3+ use axum:: {
4+ extract:: State ,
5+ http:: { HeaderMap , StatusCode } ,
6+ routing:: post,
7+ Json , Router ,
8+ } ;
49use jsonwebtoken:: { get_current_timestamp, DecodingKey , EncodingKey , Validation } ;
510use log:: { info, warn} ;
611use ring:: rand:: { SecureRandom , SystemRandom } ;
712use serde:: { Deserialize , Serialize } ;
13+ use serde_repr:: { Deserialize_repr , Serialize_repr } ;
814
915use crate :: { util, AppState } ;
1016
1117#[ derive( Deserialize , Clone ) ]
1218pub struct AuthConfig {
1319 route : String ,
20+ refresh_subroute : String ,
1421 secret_path : String ,
15- valid_secs : u64 ,
22+ valid_secs_refresh : u64 ,
23+ valid_secs_session : u64 ,
1624}
1725
1826#[ derive( Deserialize ) ]
@@ -21,11 +29,19 @@ pub struct AuthRequest {
2129 password : String ,
2230}
2331
32+ #[ repr( u8 ) ]
33+ #[ derive( Deserialize_repr , Serialize_repr , PartialEq , Eq ) ]
34+ pub enum TokenKind {
35+ Refresh = 0 ,
36+ Session = 1 ,
37+ }
38+
2439#[ derive( Deserialize , Serialize ) ]
2540pub struct Claims {
26- sub : String , // account id as a string
27- crt : u64 , // creation timestamp in UTC
28- exp : u64 , // expiration timestamp in UTC
41+ sub : String , // account id as a string
42+ crt : u64 , // creation timestamp in UTC
43+ exp : u64 , // expiration timestamp in UTC
44+ kind : TokenKind , // kind of token
2945}
3046
3147static SECRET_KEY : OnceLock < Vec < u8 > > = OnceLock :: new ( ) ;
@@ -55,37 +71,48 @@ pub fn register(
5571 rng : & SystemRandom ,
5672) -> Router < Arc < AppState > > {
5773 let route = & config. route ;
74+ let refresh_route = util:: get_subroute ( route, & config. refresh_subroute ) ;
5875 info ! ( "Registering auth route @ {}" , route) ;
76+ info ! ( "\t Refresh route @ {}" , refresh_route) ;
5977 check_secret ( & config. secret_path , rng) ;
60- routes. route ( route, post ( do_auth) )
78+ routes
79+ . route ( route, post ( do_auth) )
80+ . route ( & refresh_route, post ( do_refresh) )
6181}
6282
63- fn gen_jwt ( account_id : i64 , valid_secs : u64 ) -> Result < String , String > {
83+ fn gen_jwt ( auth_config : & AuthConfig , account_id : i64 , kind : TokenKind ) -> Result < String , String > {
6484 let secret = SECRET_KEY . get ( ) . unwrap ( ) ;
6585 let key = EncodingKey :: from_secret ( secret) ;
86+
87+ let valid_secs = match kind {
88+ TokenKind :: Refresh => auth_config. valid_secs_refresh ,
89+ TokenKind :: Session => auth_config. valid_secs_session ,
90+ } ;
91+
6692 let crt = get_current_timestamp ( ) ;
6793 let exp = crt + valid_secs;
6894 let claims = Claims {
6995 sub : account_id. to_string ( ) ,
7096 crt,
7197 exp,
98+ kind,
7299 } ;
100+
73101 jsonwebtoken:: encode ( & jsonwebtoken:: Header :: default ( ) , & claims, & key)
74102 . map_err ( |e| format ! ( "JWT error: {}" , e) )
75103}
76104
77105fn get_validator ( account_id : Option < i64 > ) -> Validation {
78106 let mut validation = Validation :: default ( ) ;
79107 // required claims
80- validation. required_spec_claims . insert ( "crt" . to_string ( ) ) ;
81- validation. required_spec_claims . insert ( "exp" . to_string ( ) ) ;
82108 validation. required_spec_claims . insert ( "sub" . to_string ( ) ) ;
109+ validation. required_spec_claims . insert ( "exp" . to_string ( ) ) ;
83110 // ensure account ID matches if passed in
84111 validation. sub = account_id. map ( |id| id. to_string ( ) ) ;
85112 validation
86113}
87114
88- pub fn validate_jwt ( jwt : & str ) -> Result < i64 , String > {
115+ pub fn validate_jwt ( jwt : & str , kind : TokenKind ) -> Result < i64 , String > {
89116 let Some ( secret) = SECRET_KEY . get ( ) else {
90117 return Err ( "Auth module not initialized" . to_string ( ) ) ;
91118 } ;
@@ -102,6 +129,10 @@ pub fn validate_jwt(jwt: &str) -> Result<i64, String> {
102129 return Err ( "Expired JWT" . to_string ( ) ) ;
103130 }
104131
132+ if token. claims . kind != kind {
133+ return Err ( "Bad token kind" . to_string ( ) ) ;
134+ }
135+
105136 match token. claims . sub . parse ( ) {
106137 Ok ( id) => Ok ( id) ,
107138 Err ( e) => Err ( format ! ( "Bad account ID: {}" , e) ) ,
@@ -118,8 +149,11 @@ async fn do_auth(
118149 warn ! ( "Auth error: {}" , e) ;
119150 ( StatusCode :: UNAUTHORIZED , "Invalid credentials" . to_string ( ) )
120151 } ) ?;
121- let valid_secs = app. config . auth . as_ref ( ) . unwrap ( ) . valid_secs ;
122- match gen_jwt ( account_id, valid_secs) {
152+ match gen_jwt (
153+ app. config . auth . as_ref ( ) . unwrap ( ) ,
154+ account_id,
155+ TokenKind :: Refresh ,
156+ ) {
123157 Ok ( jwt) => Ok ( jwt) ,
124158 Err ( e) => {
125159 warn ! ( "Auth error: {}" , e) ;
@@ -130,3 +164,30 @@ async fn do_auth(
130164 }
131165 }
132166}
167+
168+ async fn do_refresh (
169+ State ( app) : State < Arc < AppState > > ,
170+ headers : HeaderMap ,
171+ ) -> Result < String , ( StatusCode , String ) > {
172+ assert ! ( app. is_tls) ;
173+ let db = app. db . lock ( ) . await ;
174+ // TODO validate the refresh token against the last password reset timestamp
175+ let account_id = match util:: validate_authed_request ( & headers, TokenKind :: Refresh ) {
176+ Ok ( id) => id,
177+ Err ( e) => return Err ( ( StatusCode :: UNAUTHORIZED , e) ) ,
178+ } ;
179+ match gen_jwt (
180+ app. config . auth . as_ref ( ) . unwrap ( ) ,
181+ account_id,
182+ TokenKind :: Session ,
183+ ) {
184+ Ok ( jwt) => Ok ( jwt) ,
185+ Err ( e) => {
186+ warn ! ( "Refresh error: {}" , e) ;
187+ Err ( (
188+ StatusCode :: INTERNAL_SERVER_ERROR ,
189+ "Server error" . to_string ( ) ,
190+ ) )
191+ }
192+ }
193+ }
0 commit comments