@@ -4,11 +4,15 @@ import (
4
4
"context"
5
5
"crypto/rand"
6
6
"encoding/base64"
7
+ "fmt"
7
8
"log"
8
9
"net/http"
10
+ "net/url"
9
11
"os"
10
12
"time"
11
13
14
+ db "openai-api-proxy/db"
15
+
12
16
"github.com/coreos/go-oidc/v3/oidc"
13
17
"golang.org/x/oauth2"
14
18
)
@@ -25,6 +29,11 @@ func Init(mux *http.ServeMux) (a *Auth) {
25
29
// handle error
26
30
log .Println ("Cannot Initiate Provider: " , err )
27
31
}
32
+ var claims * ProviderClaims
33
+ err = provider .Claims (& claims )
34
+ if err != nil {
35
+ log .Println ("Cannot extract provider Claims for LogoutURL" , err )
36
+ }
28
37
29
38
clientId , ok := os .LookupEnv ("CLIENT_ID" )
30
39
if ! ok {
@@ -50,33 +59,65 @@ func Init(mux *http.ServeMux) (a *Auth) {
50
59
Endpoint : provider .Endpoint (),
51
60
52
61
// "openid" is a required scope for OpenID Connect flows.
53
- Scopes : []string {oidc .ScopeOpenID , "profile" , "email" },
62
+ Scopes : []string {oidc .ScopeOpenID , "profile" , "email" , "roles" },
54
63
}
55
64
56
65
var verifier = provider .Verifier (& oidc.Config {ClientID : clientId })
57
66
a = & Auth {
58
67
oauth2Config : oauth2Config ,
59
68
ctx : context .Background (),
60
69
verifier : verifier ,
70
+ provider : provider ,
71
+ claims : claims ,
61
72
}
62
73
mux .HandleFunc ("/callback/" , a .CallbackHandler )
63
74
mux .HandleFunc ("/login/" , a .LoginHandler )
64
75
return a
65
76
}
66
77
78
+ type ProviderClaims struct {
79
+ EndSessionURL string `json:"end_session_endpoint"`
80
+ }
81
+
67
82
type Claims struct {
68
83
Email string `json:"email"`
84
+ Name string `json:"name"`
69
85
Verified bool `json:"email_verified"`
70
86
Sub string `json:"sub"`
71
- Groups []string `json:"groups "`
87
+ Roles []string `json:"roles "`
72
88
}
73
89
74
90
type Auth struct {
75
91
oauth2Config * oauth2.Config
76
92
ctx context.Context
77
93
verifier * oidc.IDTokenVerifier
94
+ provider * oidc.Provider
95
+ claims * ProviderClaims
78
96
}
79
97
98
+ func (a * Auth ) LogoutHandler (w http.ResponseWriter , r * http.Request ) {
99
+ state , _ := r .Cookie ("oauthstate" )
100
+ http .SetCookie (w , & http.Cookie {
101
+ Name : "session_token" ,
102
+ Path : "/" ,
103
+ Value : "" ,
104
+ Expires : time .Unix (0 , 0 ),
105
+ })
106
+ if a .claims .EndSessionURL == "" {
107
+ http .Error (w , "Logout not implemented by Identity Provider" , http .StatusNotImplemented )
108
+ return
109
+ }
110
+ logoutURL , err := url .Parse (a .claims .EndSessionURL )
111
+ if err != nil {
112
+ log .Println ("Error parsing URL: " , err )
113
+ }
114
+ query := logoutURL .Query ()
115
+ query .Set ("state" , state .Value )
116
+ query .Set ("post_logout_redirect_uri" , a .oauth2Config .RedirectURL ) // Not implemented by our IDP of Testing https://github.com/zitadel/zitadel/issues/6615
117
+ logoutURL .RawQuery = query .Encode ()
118
+
119
+ http .Redirect (w , r , logoutURL .String (), http .StatusTemporaryRedirect )
120
+ }
80
121
func (a * Auth ) LoginHandler (w http.ResponseWriter , r * http.Request ) {
81
122
82
123
// Create oauthState cookie
@@ -127,16 +168,55 @@ func (a *Auth) GetClaims(r *http.Request) (*Claims, error) {
127
168
var err error
128
169
// Parse and verify ID Token payload.
129
170
rawIDToken , err := r .Cookie ("session_token" )
171
+ if err != nil {
172
+ return nil , err
173
+ }
130
174
idToken , err := a .verifier .Verify (a .ctx , rawIDToken .Value )
175
+ if err != nil {
176
+ return nil , err
177
+ }
131
178
claims := & Claims {}
132
179
// Extract custom claims
133
- if err = idToken .Claims (claims ); err != nil {
180
+
181
+ if err = idToken .Claims (& claims ); err != nil {
134
182
log .Println ("Error Extracting User-Claim" , err )
135
183
return nil , err
136
184
// handle error
137
185
}
138
186
return claims , nil
139
187
}
188
+
189
+ type NotAdmin struct {
190
+ uid string
191
+ }
192
+
193
+ func (n * NotAdmin ) Error () string {
194
+ return fmt .Sprintf ("User %s has no Admin Role" , n .uid )
195
+ }
196
+
197
+ func (a * Auth ) ValidateAdminSession (w http.ResponseWriter , r * http.Request ) (bool , error ) {
198
+ authenticated := a .ValidateSessionToken (w , r )
199
+ if ! authenticated {
200
+ log .Println ("Not Authenticated" )
201
+ return false , nil
202
+ }
203
+ claims , err := a .GetClaims (r )
204
+ if err != nil {
205
+ log .Println ("Claims not found" )
206
+ return false , err
207
+ }
208
+ db := db .NewDB ()
209
+ defer db .Close ()
210
+ user , err := db .GetUser (claims .Sub )
211
+ if err != nil {
212
+ return false , err
213
+ }
214
+ if user .IsAdmin {
215
+ return true , nil
216
+ }
217
+ return false , & NotAdmin {uid : claims .Sub }
218
+ }
219
+
140
220
func (a * Auth ) ValidateSessionToken (w http.ResponseWriter , r * http.Request ) bool {
141
221
// Parse and verify ID Token payload.
142
222
rawIDToken , err := r .Cookie ("session_token" )
0 commit comments