From 6d22f7995f95850d1872a512a73c761b00b3ce86 Mon Sep 17 00:00:00 2001 From: sea gin Date: Thu, 17 Oct 2024 13:52:09 +0700 Subject: [PATCH] added middleware and methods to handle session --- cmd/server/routes.go | 2 +- go.mod | 5 +- go.sum | 15 +-- internal/database/sessions.sql.go | 25 +++-- internal/sql/queries/sessions.sql | 4 +- pkg/authenticator/main.go | 26 ++++- pkg/authenticator/provider.go | 5 +- pkg/authenticator/providers/google/google.go | 11 +- pkg/data/session.go | 12 +++ pkg/handlers/auth.go | 102 +++++++++++++++++++ pkg/handlers/home.go | 3 +- pkg/store/main.go | 1 - 12 files changed, 179 insertions(+), 32 deletions(-) diff --git a/cmd/server/routes.go b/cmd/server/routes.go index 0b64702..6f4b22f 100644 --- a/cmd/server/routes.go +++ b/cmd/server/routes.go @@ -7,7 +7,7 @@ import ( ) func registerRoutes(r chi.Router, h *handlers.HttpHandler) { - r.Get("/", h.HomePage) + r.Get("/", h.Authorize(h.HomePage)) r.Get("/login/{provider}", h.Login) r.Get("/auth/{provider}/callback", h.AuthCallback) diff --git a/go.mod b/go.mod index 4862f5b..4db094d 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.23.1 require ( github.com/a-h/templ v0.2.778 - github.com/coreos/go-oidc v2.2.1+incompatible + github.com/coreos/go-oidc/v3 v3.11.0 github.com/go-chi/chi v1.5.5 github.com/go-chi/chi/v5 v5.1.0 github.com/go-chi/cors v1.2.1 @@ -22,14 +22,13 @@ require ( cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/go-jose/go-jose/v4 v4.0.2 // indirect github.com/gorilla/securecookie v1.1.2 // indirect github.com/libsql/sqlite-antlr4-parser v0.0.0-20240327125255-dbf53b6cbf06 // indirect github.com/mfridman/interpolate v0.0.2 // indirect - github.com/pquerna/cachecontrol v0.2.0 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.27.0 // indirect golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect golang.org/x/sync v0.8.0 // indirect - gopkg.in/square/go-jose.v2 v2.6.0 // indirect ) diff --git a/go.sum b/go.sum index b46867a..567b97a 100644 --- a/go.sum +++ b/go.sum @@ -6,9 +6,8 @@ github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8 github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/coreos/go-oidc v2.2.1+incompatible h1:mh48q/BqXqgjVHpy2ZY7WnWAbenxRjsz9N1i1YxjHAk= -github.com/coreos/go-oidc v2.2.1+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/coreos/go-oidc/v3 v3.11.0 h1:Ia3MxdwpSw702YW0xgfmP1GVCMA9aEFWu12XUZ3/OtI= +github.com/coreos/go-oidc/v3 v3.11.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= @@ -21,6 +20,8 @@ github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-chi/httprate v0.14.1 h1:EKZHYEZ58Cg6hWcYzoZILsv7ppb46Wt4uQ738IRtpZs= github.com/go-chi/httprate v0.14.1/go.mod h1:TUepLXaz/pCjmCtf/obgOQJ2Sz6rC8fSf5cAt5cnTt0= +github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk= +github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= @@ -49,16 +50,12 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/pquerna/cachecontrol v0.2.0 h1:vBXSNuE5MYP9IJ5kjsdo8uq+w41jSPgvba2DEnkRx9k= -github.com/pquerna/cachecontrol v0.2.0/go.mod h1:NrUG3Z7Rdu85UNR3vm7SOsl1nFIeSiQnrHV5K9mBcUI= github.com/pressly/goose/v3 v3.22.1 h1:2zICEfr1O3yTP9BRZMGPj7qFxQ+ik6yeo+z1LMuioLc= github.com/pressly/goose/v3 v3.22.1/go.mod h1:xtMpbstWyCpyH+0cxLTMCENWBG+0CSxvTsXhW95d5eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE= github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tursodatabase/go-libsql v0.0.0-20240916111504-922dfa87e1e6 h1:bFxO2fsY5mHZRrVvhmrAo/O8Agi9HDAIMmmOClZMrkQ= @@ -75,10 +72,6 @@ golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= -gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= diff --git a/internal/database/sessions.sql.go b/internal/database/sessions.sql.go index 2786b8b..f93abe2 100644 --- a/internal/database/sessions.sql.go +++ b/internal/database/sessions.sql.go @@ -71,16 +71,25 @@ func (q *Queries) DeleteSession(ctx context.Context, email string) error { } const getSession = `-- name: GetSession :one -SELECT refresh_token FROM sessions +SELECT id, user_id, email, refresh_token, access_token, provider, created_at, updated_at FROM sessions WHERE email=? LIMIT 1 ` -func (q *Queries) GetSession(ctx context.Context, email string) (string, error) { +func (q *Queries) GetSession(ctx context.Context, email string) (Session, error) { row := q.db.QueryRowContext(ctx, getSession, email) - var refresh_token string - err := row.Scan(&refresh_token) - return refresh_token, err + var i Session + err := row.Scan( + &i.ID, + &i.UserID, + &i.Email, + &i.RefreshToken, + &i.AccessToken, + &i.Provider, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err } const updateSession = `-- name: UpdateSession :exec @@ -88,16 +97,16 @@ UPDATE sessions SET refresh_token=?, access_token=? -WHERE email=? +WHERE id=? ` type UpdateSessionParams struct { RefreshToken string AccessToken string - Email string + ID string } func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) error { - _, err := q.db.ExecContext(ctx, updateSession, arg.RefreshToken, arg.AccessToken, arg.Email) + _, err := q.db.ExecContext(ctx, updateSession, arg.RefreshToken, arg.AccessToken, arg.ID) return err } diff --git a/internal/sql/queries/sessions.sql b/internal/sql/queries/sessions.sql index 0193c84..1f43f99 100644 --- a/internal/sql/queries/sessions.sql +++ b/internal/sql/queries/sessions.sql @@ -12,7 +12,7 @@ DO UPDATE SET RETURNING *; -- name: GetSession :one -SELECT refresh_token FROM sessions +SELECT * FROM sessions WHERE email=? LIMIT 1; @@ -21,7 +21,7 @@ UPDATE sessions SET refresh_token=?, access_token=? -WHERE email=?; +WHERE id=?; -- name: DeleteSession :exec DELETE FROM sessions diff --git a/pkg/authenticator/main.go b/pkg/authenticator/main.go index 09e905e..1b4da1d 100644 --- a/pkg/authenticator/main.go +++ b/pkg/authenticator/main.go @@ -1,6 +1,7 @@ package authenticator import ( + "context" "errors" "fmt" "net/http" @@ -65,7 +66,7 @@ func (a *Authenticator) Authenticate(r *http.Request, opts ...oauth2.AuthCodeOpt return nil, user, errors.New("token recieved is invalid") } - idToken, err := provider.VerifyIssuer(r.Context(), token) + idToken, err := provider.VerifyIdToken(r.Context(), token) if err != nil { return nil, user, err } @@ -77,3 +78,26 @@ func (a *Authenticator) Authenticate(r *http.Request, opts ...oauth2.AuthCodeOpt return token, user, nil } + +func (a *Authenticator) VerifyIdToken(ctx context.Context, providerName string, token *oauth2.Token) (data.SessionUser, error) { + provider, ok := a.providers[providerName] + if !ok { + return data.SessionUser{}, fmt.Errorf("Provider:'%s' is not a registered provider", providerName) + } + + idToken, err := provider.VerifyIdToken(ctx, token) + if err != nil { + return data.SessionUser{}, err + } + + return provider.GetUserInfo(idToken) +} + +func (a *Authenticator) RefreshToken(ctx context.Context, providerName, refreshToken string) (*oauth2.Token, error) { + provider, ok := a.providers[providerName] + if !ok { + return nil, fmt.Errorf("Provider:'%s' is not a registered provider", providerName) + } + + return provider.RefreshToken(ctx, refreshToken) +} diff --git a/pkg/authenticator/provider.go b/pkg/authenticator/provider.go index b116194..29ec417 100644 --- a/pkg/authenticator/provider.go +++ b/pkg/authenticator/provider.go @@ -5,7 +5,7 @@ import ( "shave/pkg/data" - "github.com/coreos/go-oidc" + "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" ) @@ -13,6 +13,7 @@ type Provider interface { GetName() string GetAuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string ExchangeCode(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) - VerifyIssuer(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) + VerifyIdToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) GetUserInfo(idToken *oidc.IDToken) (data.SessionUser, error) + RefreshToken(ctx context.Context, refreshToken string) (*oauth2.Token, error) } diff --git a/pkg/authenticator/providers/google/google.go b/pkg/authenticator/providers/google/google.go index b2e920c..f8c4633 100644 --- a/pkg/authenticator/providers/google/google.go +++ b/pkg/authenticator/providers/google/google.go @@ -8,7 +8,7 @@ import ( "shave/pkg/data" - "github.com/coreos/go-oidc" + "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" "golang.org/x/oauth2/google" ) @@ -61,7 +61,7 @@ func (p *Provider) ExchangeCode(ctx context.Context, code string, opts ...oauth2 return p.config.Exchange(ctx, code, opts...) } -func (p *Provider) VerifyIssuer(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { +func (p *Provider) VerifyIdToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { rawIDToken, ok := token.Extra("id_token").(string) if !ok { return nil, errors.New("no id_token field in oauth2 token") @@ -115,3 +115,10 @@ func (p *Provider) GetUserInfo(idToken *oidc.IDToken) (data.SessionUser, error) return user, nil } + +func (p *Provider) RefreshToken(ctx context.Context, refreshToken string) (*oauth2.Token, error) { + token := &oauth2.Token{RefreshToken: refreshToken} + ts := p.config.TokenSource(ctx, token) + + return ts.Token() +} diff --git a/pkg/data/session.go b/pkg/data/session.go index 4035367..b16a126 100644 --- a/pkg/data/session.go +++ b/pkg/data/session.go @@ -71,6 +71,18 @@ func (su SessionUser) Valid(ctx context.Context) Problems { return problems } +func (su SessionUser) IsSessionEqual(cmp SessionUser) bool { + if su.Sub != cmp.Sub { + return false + } + + if su.Email != cmp.Email { + return false + } + + return true +} + type SessionVerifier struct { Verifier string State uuid.UUID diff --git a/pkg/handlers/auth.go b/pkg/handlers/auth.go index 5c42e20..37fc44a 100644 --- a/pkg/handlers/auth.go +++ b/pkg/handlers/auth.go @@ -2,6 +2,7 @@ package handlers import ( "database/sql" + "errors" "fmt" "log/slog" "net/http" @@ -13,11 +14,111 @@ import ( "shave/views/home" "shave/views/unauthorized" + "github.com/coreos/go-oidc/v3/oidc" "github.com/go-chi/chi/v5" "github.com/google/uuid" "golang.org/x/oauth2" ) +type authedHandler func(w http.ResponseWriter, r *http.Request, sessionUser data.SessionUser) + +func (h *HttpHandler) CheckAuthoziation(w http.ResponseWriter, r *http.Request) (data.SessionUser, error) { + var user data.SessionUser + + session, err := h.store.GetSession(r) + if err != nil { + return user, err + } + + user, err = h.store.GetSessionUser(r) + if err != nil { + return user, err + } + + // TODO: this does not work without metadata in the token + // save session id and check saved access token instead?? + idTokenUserInfo, err := h.authenticator.VerifyIdToken(r.Context(), session.Provider, &oauth2.Token{AccessToken: session.AccessToken, Expiry: session.Expiry}) + if err != nil { + if _, ok := err.(*oidc.TokenExpiredError); ok { + return h.refreshToken(w, r, user, session) + } + + return data.SessionUser{}, err + } + + if !user.IsSessionEqual(idTokenUserInfo) { + return data.SessionUser{}, errors.New("session info does not match id token info") + } + + return user, nil +} + +func (h *HttpHandler) Authorize(next authedHandler) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sessionUser, err := h.CheckAuthoziation(w, r) + if err != nil { + slog.Error("Session or User data is malformed or non existent", "AUTHORIZE_ERROR", err) + + if r.URL.Path == "/" { + h.HomePage(w, r, data.SessionUser{}) + } else { + http.Redirect(w, r, "/", http.StatusSeeOther) + } + return + } + + if r.URL.Path == "/" { + renderComponent(w, r, home.SessionedHome(sessionUser)) + return + } + + next(w, r, sessionUser) + }) +} + +func (h *HttpHandler) refreshToken(w http.ResponseWriter, r *http.Request, u data.SessionUser, s data.Session) (data.SessionUser, error) { + var user data.SessionUser + + savedSession, err := h.dbQueries.GetSession(r.Context(), u.Email) + if err != nil { + return user, err + } + + if savedSession.UserID != u.UserId.String() { + return user, fmt.Errorf("user ID from database=%s, does not match session=%s", savedSession.UserID, u.UserId.String()) + } + + if savedSession.Provider != s.Provider { + return user, fmt.Errorf("provider from database=%s does not match session=%s", savedSession.Provider, s.Provider) + } + + token, err := h.authenticator.RefreshToken(r.Context(), s.Provider, savedSession.RefreshToken) + if err != nil { + return user, err + } + + updateSessionParams := database.UpdateSessionParams{ + RefreshToken: token.RefreshToken, + AccessToken: token.AccessToken, + ID: savedSession.ID, + } + + err = h.dbQueries.UpdateSession(r.Context(), updateSessionParams) + if err != nil { + return user, err + } + + s.AccessToken = token.AccessToken + s.Expiry = token.Expiry + + err = h.store.SaveSession(w, r, s) + if err != nil { + return user, err + } + + return u, nil +} + func (h *HttpHandler) Login(w http.ResponseWriter, r *http.Request) { sessionVerifier, err := h.store.SaveSessionVerfier(w, r) if err != nil { @@ -139,5 +240,6 @@ func (h *HttpHandler) AuthCallback(w http.ResponseWriter, r *http.Request) { return } + w.Header().Set("HX-Push-Url", "/") renderComponent(w, r, home.SessionedHome(sessionUser)) } diff --git a/pkg/handlers/home.go b/pkg/handlers/home.go index 36c2c86..20b9ed0 100644 --- a/pkg/handlers/home.go +++ b/pkg/handlers/home.go @@ -3,9 +3,10 @@ package handlers import ( "net/http" + "shave/pkg/data" "shave/views/home" ) -func (h *HttpHandler) HomePage(w http.ResponseWriter, r *http.Request) { +func (h *HttpHandler) HomePage(w http.ResponseWriter, r *http.Request, _ data.SessionUser) { renderComponent(w, r, home.Index()) } diff --git a/pkg/store/main.go b/pkg/store/main.go index 2db061a..9eedfd9 100644 --- a/pkg/store/main.go +++ b/pkg/store/main.go @@ -30,7 +30,6 @@ func New() (*Store, error) { cookieStore := sessions.NewCookieStore([]byte(secret)) cookieStore.Options.Path = "/" - cookieStore.Options.Secure = true cookieStore.Options.HttpOnly = true cookieStore.Options.SameSite = http.SameSiteLaxMode