Skip to content

Commit

Permalink
added middleware and methods to handle session
Browse files Browse the repository at this point in the history
  • Loading branch information
Cijin committed Oct 17, 2024
1 parent 4d60710 commit 6d22f79
Show file tree
Hide file tree
Showing 12 changed files with 179 additions and 32 deletions.
2 changes: 1 addition & 1 deletion cmd/server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

func registerRoutes(r chi.Router, h *handlers.HttpHandler) {
r.Get("/", h.HomePage)
r.Get("/", h.Authorize(h.HomePage))

r.Get("/login/{provider}", h.Login)
r.Get("/auth/{provider}/callback", h.AuthCallback)
Expand Down
5 changes: 2 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.23.1

require (
github.com/a-h/templ v0.2.778
github.com/coreos/go-oidc v2.2.1+incompatible
github.com/coreos/go-oidc/v3 v3.11.0
github.com/go-chi/chi v1.5.5
github.com/go-chi/chi/v5 v5.1.0
github.com/go-chi/cors v1.2.1
Expand All @@ -22,14 +22,13 @@ require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect
github.com/antlr4-go/antlr/v4 v4.13.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/go-jose/go-jose/v4 v4.0.2 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/libsql/sqlite-antlr4-parser v0.0.0-20240327125255-dbf53b6cbf06 // indirect
github.com/mfridman/interpolate v0.0.2 // indirect
github.com/pquerna/cachecontrol v0.2.0 // indirect
github.com/sethvargo/go-retry v0.3.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.27.0 // indirect
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect
golang.org/x/sync v0.8.0 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
)
15 changes: 4 additions & 11 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8
github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/coreos/go-oidc v2.2.1+incompatible h1:mh48q/BqXqgjVHpy2ZY7WnWAbenxRjsz9N1i1YxjHAk=
github.com/coreos/go-oidc v2.2.1+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/coreos/go-oidc/v3 v3.11.0 h1:Ia3MxdwpSw702YW0xgfmP1GVCMA9aEFWu12XUZ3/OtI=
github.com/coreos/go-oidc/v3 v3.11.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
Expand All @@ -21,6 +20,8 @@ github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4=
github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
github.com/go-chi/httprate v0.14.1 h1:EKZHYEZ58Cg6hWcYzoZILsv7ppb46Wt4uQ738IRtpZs=
github.com/go-chi/httprate v0.14.1/go.mod h1:TUepLXaz/pCjmCtf/obgOQJ2Sz6rC8fSf5cAt5cnTt0=
github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk=
github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
Expand Down Expand Up @@ -49,16 +50,12 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pquerna/cachecontrol v0.2.0 h1:vBXSNuE5MYP9IJ5kjsdo8uq+w41jSPgvba2DEnkRx9k=
github.com/pquerna/cachecontrol v0.2.0/go.mod h1:NrUG3Z7Rdu85UNR3vm7SOsl1nFIeSiQnrHV5K9mBcUI=
github.com/pressly/goose/v3 v3.22.1 h1:2zICEfr1O3yTP9BRZMGPj7qFxQ+ik6yeo+z1LMuioLc=
github.com/pressly/goose/v3 v3.22.1/go.mod h1:xtMpbstWyCpyH+0cxLTMCENWBG+0CSxvTsXhW95d5eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tursodatabase/go-libsql v0.0.0-20240916111504-922dfa87e1e6 h1:bFxO2fsY5mHZRrVvhmrAo/O8Agi9HDAIMmmOClZMrkQ=
Expand All @@ -75,10 +72,6 @@ golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo=
Expand Down
25 changes: 17 additions & 8 deletions internal/database/sessions.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions internal/sql/queries/sessions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ DO UPDATE SET
RETURNING *;

-- name: GetSession :one
SELECT refresh_token FROM sessions
SELECT * FROM sessions
WHERE email=?
LIMIT 1;

Expand All @@ -21,7 +21,7 @@ UPDATE sessions
SET
refresh_token=?,
access_token=?
WHERE email=?;
WHERE id=?;

-- name: DeleteSession :exec
DELETE FROM sessions
Expand Down
26 changes: 25 additions & 1 deletion pkg/authenticator/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package authenticator

import (
"context"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -65,7 +66,7 @@ func (a *Authenticator) Authenticate(r *http.Request, opts ...oauth2.AuthCodeOpt
return nil, user, errors.New("token recieved is invalid")
}

idToken, err := provider.VerifyIssuer(r.Context(), token)
idToken, err := provider.VerifyIdToken(r.Context(), token)
if err != nil {
return nil, user, err
}
Expand All @@ -77,3 +78,26 @@ func (a *Authenticator) Authenticate(r *http.Request, opts ...oauth2.AuthCodeOpt

return token, user, nil
}

func (a *Authenticator) VerifyIdToken(ctx context.Context, providerName string, token *oauth2.Token) (data.SessionUser, error) {
provider, ok := a.providers[providerName]
if !ok {
return data.SessionUser{}, fmt.Errorf("Provider:'%s' is not a registered provider", providerName)
}

idToken, err := provider.VerifyIdToken(ctx, token)
if err != nil {
return data.SessionUser{}, err
}

return provider.GetUserInfo(idToken)
}

func (a *Authenticator) RefreshToken(ctx context.Context, providerName, refreshToken string) (*oauth2.Token, error) {
provider, ok := a.providers[providerName]
if !ok {
return nil, fmt.Errorf("Provider:'%s' is not a registered provider", providerName)
}

return provider.RefreshToken(ctx, refreshToken)
}
5 changes: 3 additions & 2 deletions pkg/authenticator/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ import (

"shave/pkg/data"

"github.com/coreos/go-oidc"
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
)

type Provider interface {
GetName() string
GetAuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
ExchangeCode(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
VerifyIssuer(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error)
VerifyIdToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error)
GetUserInfo(idToken *oidc.IDToken) (data.SessionUser, error)
RefreshToken(ctx context.Context, refreshToken string) (*oauth2.Token, error)
}
11 changes: 9 additions & 2 deletions pkg/authenticator/providers/google/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

"shave/pkg/data"

"github.com/coreos/go-oidc"
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
Expand Down Expand Up @@ -61,7 +61,7 @@ func (p *Provider) ExchangeCode(ctx context.Context, code string, opts ...oauth2
return p.config.Exchange(ctx, code, opts...)
}

func (p *Provider) VerifyIssuer(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
func (p *Provider) VerifyIdToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return nil, errors.New("no id_token field in oauth2 token")
Expand Down Expand Up @@ -115,3 +115,10 @@ func (p *Provider) GetUserInfo(idToken *oidc.IDToken) (data.SessionUser, error)

return user, nil
}

func (p *Provider) RefreshToken(ctx context.Context, refreshToken string) (*oauth2.Token, error) {
token := &oauth2.Token{RefreshToken: refreshToken}
ts := p.config.TokenSource(ctx, token)

return ts.Token()
}
12 changes: 12 additions & 0 deletions pkg/data/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,18 @@ func (su SessionUser) Valid(ctx context.Context) Problems {
return problems
}

func (su SessionUser) IsSessionEqual(cmp SessionUser) bool {
if su.Sub != cmp.Sub {
return false
}

if su.Email != cmp.Email {
return false
}

return true
}

type SessionVerifier struct {
Verifier string
State uuid.UUID
Expand Down
102 changes: 102 additions & 0 deletions pkg/handlers/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package handlers

import (
"database/sql"
"errors"
"fmt"
"log/slog"
"net/http"
Expand All @@ -13,11 +14,111 @@ import (
"shave/views/home"
"shave/views/unauthorized"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"golang.org/x/oauth2"
)

type authedHandler func(w http.ResponseWriter, r *http.Request, sessionUser data.SessionUser)

func (h *HttpHandler) CheckAuthoziation(w http.ResponseWriter, r *http.Request) (data.SessionUser, error) {
var user data.SessionUser

session, err := h.store.GetSession(r)
if err != nil {
return user, err
}

user, err = h.store.GetSessionUser(r)
if err != nil {
return user, err
}

// TODO: this does not work without metadata in the token
// save session id and check saved access token instead??
idTokenUserInfo, err := h.authenticator.VerifyIdToken(r.Context(), session.Provider, &oauth2.Token{AccessToken: session.AccessToken, Expiry: session.Expiry})
if err != nil {
if _, ok := err.(*oidc.TokenExpiredError); ok {
return h.refreshToken(w, r, user, session)
}

return data.SessionUser{}, err
}

if !user.IsSessionEqual(idTokenUserInfo) {
return data.SessionUser{}, errors.New("session info does not match id token info")
}

return user, nil
}

func (h *HttpHandler) Authorize(next authedHandler) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sessionUser, err := h.CheckAuthoziation(w, r)
if err != nil {
slog.Error("Session or User data is malformed or non existent", "AUTHORIZE_ERROR", err)

if r.URL.Path == "/" {
h.HomePage(w, r, data.SessionUser{})
} else {
http.Redirect(w, r, "/", http.StatusSeeOther)
}
return
}

if r.URL.Path == "/" {
renderComponent(w, r, home.SessionedHome(sessionUser))
return
}

next(w, r, sessionUser)
})
}

func (h *HttpHandler) refreshToken(w http.ResponseWriter, r *http.Request, u data.SessionUser, s data.Session) (data.SessionUser, error) {
var user data.SessionUser

savedSession, err := h.dbQueries.GetSession(r.Context(), u.Email)
if err != nil {
return user, err
}

if savedSession.UserID != u.UserId.String() {
return user, fmt.Errorf("user ID from database=%s, does not match session=%s", savedSession.UserID, u.UserId.String())
}

if savedSession.Provider != s.Provider {
return user, fmt.Errorf("provider from database=%s does not match session=%s", savedSession.Provider, s.Provider)
}

token, err := h.authenticator.RefreshToken(r.Context(), s.Provider, savedSession.RefreshToken)
if err != nil {
return user, err
}

updateSessionParams := database.UpdateSessionParams{
RefreshToken: token.RefreshToken,
AccessToken: token.AccessToken,
ID: savedSession.ID,
}

err = h.dbQueries.UpdateSession(r.Context(), updateSessionParams)
if err != nil {
return user, err
}

s.AccessToken = token.AccessToken
s.Expiry = token.Expiry

err = h.store.SaveSession(w, r, s)
if err != nil {
return user, err
}

return u, nil
}

func (h *HttpHandler) Login(w http.ResponseWriter, r *http.Request) {
sessionVerifier, err := h.store.SaveSessionVerfier(w, r)
if err != nil {
Expand Down Expand Up @@ -139,5 +240,6 @@ func (h *HttpHandler) AuthCallback(w http.ResponseWriter, r *http.Request) {
return
}

w.Header().Set("HX-Push-Url", "/")
renderComponent(w, r, home.SessionedHome(sessionUser))
}
Loading

0 comments on commit 6d22f79

Please sign in to comment.