From b3ea6b39a9a549f72286ef8f01940358a95d4933 Mon Sep 17 00:00:00 2001 From: JP Robinson Date: Tue, 24 Sep 2019 11:53:48 -0400 Subject: [PATCH] [auth/gcp] Adding basic google web auth flow capabilities (#231) * initial WIP * removing IAM verifier, cleaning up verify logic * docs * hooking in custom exception func * renaming file * removing unneeded client * ensuring cookie is cleared * cleaning up var declarations * fixing the logger * Forbids instead of redirecting * cleaned up verify logic * state cleanup * simplifying logic * avoiding a verify of an empty token * adding some initial authenticator tests, callback still needs coverage * adding some callback tests * making notes in README --- README.md | 2 + auth/gcp/authenticator.go | 413 +++++++++++++++++++++++++++++++++ auth/gcp/authenticator_test.go | 399 +++++++++++++++++++++++++++++++ auth/keys.go | 8 +- auth/keys_test.go | 6 +- auth/verify.go | 10 +- auth/verify_test.go | 12 +- 7 files changed, 836 insertions(+), 14 deletions(-) create mode 100644 auth/gcp/authenticator.go create mode 100644 auth/gcp/authenticator_test.go diff --git a/README.md b/README.md index 0887931e1..249c998f0 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,8 @@ The `auth/gcp` package provides 2 Google Cloud Platform based `auth.PublicKeySou * The "Identity" key source and token source rely on GCP's [identity JWT mechanism for asserting instance identities](https://cloud.google.com/compute/docs/instances/verifying-instance-identity). This is the preferred method for asserting instance identity on GCP. * The "IAM" key source and token source rely on GCP's IAM services for [signing](https://cloud.google.com/iam/reference/rest/v1/projects.serviceAccounts/signJwt) and [verifying JWTs](https://cloud.google.com/iam/reference/rest/v1/projects.serviceAccounts.keys/get). This method can be used outside of GCP, if needed and can provide a bridge for users transitioning from the 1st generation App Engine (where Identity tokens are not available) runtime to the 2nd. +The `auth/gcp` package also includes an `Authenticator`, which encapsulates a Google Identity verifier and [`oauth2`](https://godoc.org/golang.org/x/oauth2) credentials [to manage a basic web auth flow.](https://developers.google.com/identity/sign-in/web/backend-auth#verify-the-integrity-of-the-id-token) + #### [`config`](https://godoc.org/github.com/NYTimes/gizmo/config) diff --git a/auth/gcp/authenticator.go b/auth/gcp/authenticator.go new file mode 100644 index 000000000..ef7fd6523 --- /dev/null +++ b/auth/gcp/authenticator.go @@ -0,0 +1,413 @@ +package gcp + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + kms "cloud.google.com/go/kms/apiv1" + "github.com/NYTimes/gizmo/auth" + "github.com/go-kit/kit/log" + "github.com/pkg/errors" + "golang.org/x/oauth2" + kmsv1 "google.golang.org/genproto/googleapis/cloud/kms/v1" +) + +type ( + // Authenticator leans on Google's OAuth user flow to capture a Google Identity JWS + // and use it in a local, short lived HTTP cookie. The `Middleware` function manages + // login redirects, OAuth callbacks, dropping the HTTP cookie and adding the JWS + // claims information to the request context. User information and the JWS token can + // be retrieved from the context via GetInfo function. + // The Authenticator can also be used for checking service-to-service authentication + // via an Authorization header containing a Google Identity JWS, which can be + // generated using this package's IdentityTokenSource. + // The user state in the web login flow is encrypted using Google KMS. Ensure the + // service account being used has permissions to encrypt and decrypt. + Authenticator struct { + cfg AuthenticatorConfig + secureCookie bool + cookieDomain string + callbackPath string + + keyName string + keys *kms.KeyManagementClient + verifier *auth.Verifier + } + + // AuthenticatorConfig encapsulates the needs of the Authenticator. + AuthenticatorConfig struct { + // CookieName will be used for the local HTTP cookie name. + CookieName string + + // KMSKeyName is used by a Google KMS client for encrypting and decrypting state + // tokens within the oauth exchange. + KMSKeyName string + // UnsafeState can be used to skip the encryption of the "state" token + // within the auth flow. + UnsafeState bool + + // AuthConfig is used by Authenticator.Middleware and callback to enable the + // Google OAuth flow. + AuthConfig *oauth2.Config + + // HeaderExceptions can optionally be included. Any requests that include any of + // the headers included will skip all Authenticator.Middlware checks and no + // claims information will be added to the context. + // This can be useful for unspoofable headers like Google App Engine's + // "X-AppEngine-*" headers for Google Task Queues. + HeaderExceptions []string + + // CustomExceptionsFunc allows any custom exceptions based on the request. For + // example, looking for specific URIs. Return true if should be allowed. If + // false is returned, normal cookie-based authentication happens. + CustomExceptionsFunc func(context.Context, *http.Request) bool + + // IDConfig will be used to verify the Google Identity JWS when it is inbound + // in the HTTP cookie. + IDConfig IdentityConfig + // IDVerifyFunc allows developers to add their own verification on the user + // claims. For example, one could enable access for anyone with an email domain + // of "@example.com". + IDVerifyFunc func(context.Context, IdentityClaimSet) bool + + // Logger will be used to log any errors encountered during the auth flow. + Logger log.Logger + } +) + +// NewAuthenticator will instantiate a new Authenticator, which can be used for verifying +// a number of authentication styles within the Google Cloud Platform ecosystem. +func NewAuthenticator(ctx context.Context, cfg AuthenticatorConfig) (Authenticator, error) { + ks, err := NewIdentityPublicKeySource(ctx, cfg.IDConfig) + if err != nil { + return Authenticator{}, errors.Wrap(err, "unable to init key source") + } + u, err := url.Parse(cfg.AuthConfig.RedirectURL) + if err != nil { + return Authenticator{}, errors.Wrap(err, "unable to pasrse redirect URL") + } + var keys *kms.KeyManagementClient + if !cfg.UnsafeState { + keys, err = kms.NewKeyManagementClient(ctx) + if err != nil { + return Authenticator{}, errors.Wrap(err, "unable to init KMS client") + } + } + if cfg.Logger == nil { + cfg.Logger = log.NewNopLogger() + } + return Authenticator{ + cfg: cfg, + keys: keys, + cookieDomain: strings.Split(u.Host, ":")[0], + secureCookie: u.Scheme == "https", + callbackPath: u.Path, + verifier: auth.NewVerifier(ks, IdentityClaimsDecoderFunc, + IdentityVerifyFunc(cfg.IDVerifyFunc)), + }, nil +} + +// LogOut can be used to clear an existing session. It will add an HTTP cookie with a -1 +// "MaxAge" to the response to remove the cookie from the logged in user's browser. +func (c Authenticator) LogOut(w http.ResponseWriter) { + http.SetCookie(w, &http.Cookie{ + Name: c.cfg.CookieName, + Domain: c.cookieDomain, + Secure: c.secureCookie, + Value: "", + Path: "/", + MaxAge: -1, + Expires: time.Unix(0, 0), + }) +} + +func forbidden(w http.ResponseWriter) { + // stop here here to prevent redirect chaos. + code := http.StatusForbidden + http.Error(w, http.StatusText(code), code) +} + +// Middleware will handle login redirects, OAuth callbacks, header exceptions, custom +// exceptions, verifying inbound Google ID or IAM JWS' within HTTP cookies or +// Authorization headers and, if the user passes all checks, it will add the user claims +// to the inbound request context. +func (c Authenticator) Middleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == c.callbackPath { + c.callbackHandler(w, r) + return + } + + // if one of the 'exception' headers exists, let the request pass through + // this is nice for unspoofable headers like 'X-Appengine-*'. + for _, hdr := range c.cfg.HeaderExceptions { + if r.Header.Get(hdr) != "" { + h.ServeHTTP(w, r) + return + } + } + + // if a custom exception func has been configured, passing its inspection + // will bypass Identity auth. + if c.cfg.CustomExceptionsFunc != nil { + if c.cfg.CustomExceptionsFunc(r.Context(), r) { + h.ServeHTTP(w, r) + return + } + } + + // ***all other endpoints must have a cookie or a header*** + + //////////// + // check for an ID Authorization header + // this is for service-to-service auth/authz + //////////// + token, err := auth.GetAuthorizationToken(r) + if err != nil { + c.cfg.Logger.Log("message", "unable to get header, falling back to cookie", + "error", err) + } + + //////////// + // check for an ID HTTP Cookie + // this is for web-based auth from a user + browser + //////////// + if token == "" { + ck, err := r.Cookie(c.cfg.CookieName) + if err != nil { + c.cfg.Logger.Log("message", "unable to get cookie, redirecting", + "error", err) + } else { + token = ck.Value + } + } + + if token == "" { + c.redirect(w, r) + return + } + + verified, err := c.verifier.Verify(r.Context(), token) + if err != nil { + c.cfg.Logger.Log("message", "id verify cookie failure, redirecting", + "error", err) + c.redirect(w, r) + return + } + + // token existed but was invalid, forbid these requests + if !verified { + forbidden(w) + return + } + + claims, err := decodeClaims(token) + if err != nil { + c.redirect(w, r) + return + } + + // add the user claims to the context and call the handlers below + r = r.WithContext(context.WithValue(r.Context(), claimsKey, claims)) + h.ServeHTTP(w, r) + }) +} + +func (c Authenticator) callbackHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + q := r.URL.Query() + + // verify state + uri, ok := c.verifyState(ctx, q.Get("state")) + if !ok { + forbidden(w) + return + } + + code := q.Get("code") + if strings.TrimSpace(code) == "" { + forbidden(w) + return + } + + token, err := c.cfg.AuthConfig.Exchange(ctx, code) + if err != nil { + c.cfg.Logger.Log("error", err, "message", "unable to exchange code") + forbidden(w) + return + } + idI := token.Extra("id_token") + if idI == nil { + forbidden(w) + return + } + id, ok := idI.(string) + if !ok { + c.cfg.Logger.Log("message", "id_token was not a string", + "error", "unexpectected type: "+fmt.Sprintf("%T", idI)) + forbidden(w) + return + } + + // they have authenticated, see if we can authorize them + // via the given verifyFunc + verified, err := c.verifier.Verify(r.Context(), id) + if err != nil || !verified { + forbidden(w) + return + } + + // grab claims so we can use the expiration on our cookie + claims, err := decodeClaims(id) + if err != nil { + c.cfg.Logger.Log("error", err, "message", "unable to decode token") + forbidden(w) + return + } + + http.SetCookie(w, &http.Cookie{ + Name: c.cfg.CookieName, + Secure: c.secureCookie, + Value: id, + Domain: c.cookieDomain, + Expires: time.Unix(claims.Exp, 0), + }) + http.Redirect(w, r, uri, http.StatusTemporaryRedirect) +} + +func (c Authenticator) verifyState(ctx context.Context, state string) (string, bool) { + if state == "" { + return "", false + } + rawState, err := base64.StdEncoding.DecodeString(state) + if err != nil { + return "", false + } + + var data stateData + if c.keys == nil { + err = json.Unmarshal(rawState, &data) + if err != nil { + return "", false + } + return data.verifiedURI() + } + + decRes, err := c.keys.Decrypt(ctx, &kmsv1.DecryptRequest{ + Name: c.cfg.KMSKeyName, + Ciphertext: rawState, + }) + if err != nil { + c.cfg.Logger.Log("error", err, "message", "unable to decrypt state", + "state", state) + return "", false + } + + err = json.Unmarshal(decRes.Plaintext, &data) + if err != nil { + return "", false + } + return data.verifiedURI() +} + +func (s stateData) verifiedURI() (string, bool) { + return s.URI, timeNow().Before(s.Expiry) +} + +type stateData struct { + Expiry time.Time + URI string + Nonce *[24]byte +} + +func newNonce() (*[24]byte, error) { + nonce := &[24]byte{} + _, err := io.ReadFull(rand.Reader, nonce[:]) + if err != nil { + return nonce, errors.Wrap(err, "unable to generate nonce from rand.Reader") + } + return nonce, nil +} + +func (c Authenticator) redirect(w http.ResponseWriter, r *http.Request) { + uri := r.URL.EscapedPath() + if r.URL.RawQuery != "" { + uri += "?" + r.URL.RawQuery + } + // avoid redirect loops + if strings.HasPrefix(uri, c.cfg.AuthConfig.RedirectURL) { + uri = "/" + } + nonce, err := newNonce() + if err != nil { + c.cfg.Logger.Log("error", err, "message", "unable to generate nonce") + http.Error(w, "oauth error", http.StatusInternalServerError) + return + } + const stateExpiryMins = 10 + stateData, err := json.Marshal(stateData{ + Expiry: timeNow().Add(stateExpiryMins * time.Minute), + URI: uri, + Nonce: nonce, + }) + if err != nil { + c.cfg.Logger.Log("error", err, "message", "unable to encode state") + http.Error(w, "oauth error", http.StatusInternalServerError) + return + } + if c.keys != nil { + encRes, err := c.keys.Encrypt(r.Context(), &kmsv1.EncryptRequest{ + Name: c.cfg.KMSKeyName, + Plaintext: stateData, + }) + if err != nil { + c.cfg.Logger.Log("error", err, "message", "unable to encrypt state") + } else { + stateData = encRes.Ciphertext + } + } + state := base64.StdEncoding.EncodeToString(stateData) + + http.Redirect(w, r, c.cfg.AuthConfig.AuthCodeURL(state), + http.StatusTemporaryRedirect) +} + +type key int + +const claimsKey key = 1 + +// GetUserClaims will return the Google identity claim set if it exists in the +// context. This can be used in coordination with the Authenticator.Middleware. +func GetUserClaims(ctx context.Context) (IdentityClaimSet, error) { + var claims IdentityClaimSet + clms := ctx.Value(claimsKey) + if clms == nil { + return claims, errors.New("claims not found") + } + return clms.(IdentityClaimSet), nil +} + +func decodeClaims(token string) (IdentityClaimSet, error) { + var claims IdentityClaimSet + s := strings.Split(token, ".") + if len(s) < 2 { + return claims, errors.New("jws: invalid token received") + } + decoded, err := base64.RawURLEncoding.DecodeString(s[1]) + if err != nil { + return claims, err + } + err = json.Unmarshal(decoded, &claims) + if err != nil { + return claims, err + } + return claims, nil +} diff --git a/auth/gcp/authenticator_test.go b/auth/gcp/authenticator_test.go new file mode 100644 index 000000000..e7905f8e5 --- /dev/null +++ b/auth/gcp/authenticator_test.go @@ -0,0 +1,399 @@ +package gcp + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/NYTimes/gizmo/auth" + "github.com/go-kit/kit/log" + "golang.org/x/oauth2" +) + +func TestAuthCallback(t *testing.T) { + timeNow = func() time.Time { return time.Date(2019, 9, 23, 21, 0, 0, 0, time.UTC) } + auth.TimeNow = timeNow + keyServer, authServer := setupAuthenticatorTest(t) + defer keyServer.Close() + defer authServer.Close() + + auth, err := NewAuthenticator(context.Background(), AuthenticatorConfig{ + CookieName: "example-cookie", + IDConfig: IdentityConfig{ + Audience: "http://example.com", + CertURL: keyServer.URL, + }, + IDVerifyFunc: func(_ context.Context, cs IdentityClaimSet) bool { + if cs.Aud != "http://example.com" { + return false + } + return strings.HasPrefix(cs.Email, "auth-example@") + }, + AuthConfig: &oauth2.Config{ + RedirectURL: "http://localhost/oauthcallback", + Endpoint: oauth2.Endpoint{ + AuthURL: authServer.URL, + TokenURL: authServer.URL, + }, + }, + Logger: log.NewJSONLogger(os.Stdout), + UnsafeState: true, + }) + if err != nil { + t.Fatalf("unable to init authenticator: %s", err) + } + + var passedAuth bool + handler := auth.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + passedAuth = true + w.WriteHeader(http.StatusOK) + })) + + // make a call to out callback endpoint that has no state info + r := httptest.NewRequest(http.MethodGet, "/randoendpoint", nil) + + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if passedAuth { + t.Fatal("request passed the auth layer despite having no known token") + } + + got := w.Result() + if got.StatusCode != http.StatusTemporaryRedirect { + t.Fatalf("expected to be get a 307 but got a status of %d instead", + got.StatusCode) + } + + // try to get the callback to play nice with an added state + r = httptest.NewRequest(http.MethodGet, + "/oauthcallback?state=eyJFeHBpcnkiOiIyMDE5LTA5LTIzVDE3OjUwOjQxLjYxOTc4Ny0wNDowMCIsIlVSSSI6Ii9yYW5kb2VuZHBvaW50IiwiTm9uY2UiOlsxNzUsOTIsMjUzLDQxLDg5LDIzMSwxNTAsMjQyLDk4LDY0LDY4LDE4NSwyMzMsMTM2LDcyLDIwOCwwLDIxLDIzLDg0LDEyMywxMzUsMTM5LDk2XX0=&code=XYZ", + nil) + + w = httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if passedAuth { + t.Fatal("request passed the auth layer despite having no known token") + } + + got = w.Result() + + if got.StatusCode != http.StatusTemporaryRedirect { + t.Fatalf("expected to be get a 200 OK but got a status of %d instead", + got.StatusCode) + } + + gotCookie := got.Header.Get("Set-Cookie") + if gotCookie == "" { + t.Fatal("expected cookie to have been dropped but got none") + } + cookieVals := strings.Split(strings.Split(gotCookie, "; ")[0], "=") + if len(cookieVals) != 2 { + t.Fatalf("cookie has unexpected format: %q", gotCookie) + } + if cookieVals[1] != testAuthToken { + t.Fatalf("expected testAuthToken (%q), got %q", testAuthToken, cookieVals[1]) + } +} + +func TestAuthenticatorTokenReject(t *testing.T) { + timeNow = func() time.Time { + return time.Date(2019, 9, 23, 22, 0, 0, 0, time.UTC) + } + auth.TimeNow = timeNow + + keyServer, authServer := setupAuthenticatorTest(t) + defer keyServer.Close() + defer authServer.Close() + + auth, err := NewAuthenticator(context.Background(), AuthenticatorConfig{ + CookieName: "example-cookie", + IDConfig: IdentityConfig{ + Audience: "http://example.com", + CertURL: keyServer.URL, + }, + IDVerifyFunc: func(_ context.Context, cs IdentityClaimSet) bool { + return false // reject _all_ the things + }, + AuthConfig: &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + AuthURL: authServer.URL, + TokenURL: authServer.URL, + }, + }, + Logger: log.NewJSONLogger(os.Stdout), + UnsafeState: true, + }) + if err != nil { + t.Fatalf("unable to init authenticator: %s", err) + } + + var passedAuth bool + handler := auth.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + passedAuth = true + w.WriteHeader(http.StatusOK) + })) + + // add our known token to the outbound request + r := httptest.NewRequest(http.MethodGet, "/bobloblaw", nil) + r.Header.Set("Authorization", "Bearer "+testAuthToken) + + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if passedAuth { + t.Fatal("request passed the auth layer but all requests should be rejected") + } + + if got := w.Result(); got.StatusCode != http.StatusForbidden { + t.Fatalf("expected to be get a 403 Forbidden but got a status of %d instead", + got.StatusCode) + } +} + +func TestAuthenticatorTokenSuccess(t *testing.T) { + timeNow = func() time.Time { + return time.Date(2019, 9, 23, 22, 0, 0, 0, time.UTC) + } + auth.TimeNow = timeNow + + keyServer, authServer := setupAuthenticatorTest(t) + defer keyServer.Close() + defer authServer.Close() + + auth, err := NewAuthenticator(context.Background(), AuthenticatorConfig{ + CookieName: "example-cookie", + IDConfig: IdentityConfig{ + Audience: "http://example.com", + CertURL: keyServer.URL, + }, + IDVerifyFunc: func(_ context.Context, cs IdentityClaimSet) bool { + if cs.Aud != "http://example.com" { + return false + } + return strings.HasPrefix(cs.Email, "auth-example@") + }, + AuthConfig: &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + AuthURL: authServer.URL, + TokenURL: authServer.URL, + }, + }, + Logger: log.NewJSONLogger(os.Stdout), + UnsafeState: true, + }) + if err != nil { + t.Fatalf("unable to init authenticator: %s", err) + } + + var passedAuth bool + handler := auth.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + passedAuth = true + w.WriteHeader(http.StatusOK) + })) + + // add our known token to the outbound request + r := httptest.NewRequest(http.MethodGet, "/bobloblaw", nil) + r.Header.Set("Authorization", "Bearer "+testAuthToken) + + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if !passedAuth { + t.Fatal("request did not pass the auth layer despite having known token") + } + + if got := w.Result(); got.StatusCode != http.StatusOK { + t.Fatalf("expected to be get a 200 OK but got a status of %d instead", + got.StatusCode) + } + // reset for next run + passedAuth = false + + // add the same token as a cookie to also verify that plays nice + r = httptest.NewRequest(http.MethodGet, "/bobloblaw", nil) + r.Header.Set("Cookie", "example-cookie="+testAuthToken) + + w = httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if !passedAuth { + t.Fatal("request did not pass the auth layer despite having known token within cookie") + } + + if got := w.Result(); got.StatusCode != http.StatusOK { + t.Fatalf("expected to be get a 200 OK but got a status of %d instead", + got.StatusCode) + } +} + +const testAuthToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6IjBiMGJmMTg2NzQzNDcxYTFlZGNhYzMwNjBkMTI1NmY5ZTQwNTBiYTgiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhdWQiOiJodHRwOi8vZXhhbXBsZS5jb20iLCJhenAiOiJhdXRoLWV4YW1wbGVAbnl0LWdvbGFuZy1kZXYuaWFtLmdzZXJ2aWNlYWNjb3VudC5jb20iLCJzdWIiOiIxMDMzNTk3OTYyODUxOTI5NzE4NzQiLCJlbWFpbCI6ImF1dGgtZXhhbXBsZUBueXQtZ29sYW5nLWRldi5pYW0uZ3NlcnZpY2VhY2NvdW50LmNvbSIsImVtYWlsX3ZlcmlmaWVkIjp0cnVlLCJpYXQiOjE1NjkyNzIxNjgsImV4cCI6MTU2OTI3NTc2OH0.YCnNzU8mw_bdHmpmAWjcRc8NKs2A2ugz2XenN3opyEddKl9UxnMx-Y7k3Hd5jIhIZbBLp5_nwUojiWSoWXIYrIG-63MNINUCyoZykxwWMXhQTvTChPk69j0ex0wvwfuR044GrH1SRohYZET5JnlfrBroHjSOK0OqHjpePBp84ezK7EXwnKTgvqTB_lTp5__Xmwguw1DkLKVH9lpnU9RalAdjQZL0_tsK3MWSrVrL8byqP7MyOF6t5Xv-Xrb90feZIuJITPDtNoLvxL-ZXN5B-oGVyBlDK3w6mwTjLV4YQCa5lZKy3SrVHgAa4ucFkZFw0kzCJEnRY_YLkGh7c9eh2w" + +func TestAuthCustomExceptions(t *testing.T) { + timeNow = func() time.Time { + return time.Date(2019, 9, 23, 0, 0, 0, 0, time.UTC) + } + auth.TimeNow = timeNow + + keyServer, authServer := setupAuthenticatorTest(t) + defer keyServer.Close() + defer authServer.Close() + + auth, err := NewAuthenticator(context.Background(), AuthenticatorConfig{ + IDConfig: IdentityConfig{ + Audience: "example.com", + CertURL: keyServer.URL, + }, + CustomExceptionsFunc: func(_ context.Context, r *http.Request) bool { + return r.URL.Path == "/bobloblaw" + }, + AuthConfig: &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + AuthURL: authServer.URL, + TokenURL: authServer.URL, + }, + }, + UnsafeState: true, + }) + if err != nil { + t.Fatalf("unable to init authenticator: %s", err) + } + + var passedAuth bool + handler := auth.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + passedAuth = true + w.WriteHeader(http.StatusOK) + })) + + // hit once without the special path, expect a redirect/no pass + r := httptest.NewRequest(http.MethodGet, "/xyz", nil) + + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if passedAuth { + t.Fatal("request passed the auth layer without hitting the special path") + } + + if got := w.Result(); got.StatusCode != http.StatusTemporaryRedirect { + t.Fatalf("expected to be redirected but got a status of %d instead", + got.StatusCode) + } + + // use special path, expect to get through + r = httptest.NewRequest(http.MethodGet, "/bobloblaw", nil) + + w = httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if !passedAuth { + t.Fatal("request did not pass the auth layer despite hitting the special path") + } + + if got := w.Result(); got.StatusCode != http.StatusOK { + t.Fatalf("expected to be get a 200 OK but got a status of %d instead", + got.StatusCode) + } + +} + +func TestAuthHeaderExceptions(t *testing.T) { + timeNow = func() time.Time { + return time.Date(2019, 9, 23, 0, 0, 0, 0, time.UTC) + } + auth.TimeNow = timeNow + + keyServer, authServer := setupAuthenticatorTest(t) + defer keyServer.Close() + defer authServer.Close() + + auth, err := NewAuthenticator(context.Background(), AuthenticatorConfig{ + IDConfig: IdentityConfig{ + Audience: "example.com", + CertURL: keyServer.URL, + }, + HeaderExceptions: []string{"X-EXAMPLE"}, + AuthConfig: &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + AuthURL: authServer.URL, + TokenURL: authServer.URL, + }, + }, + UnsafeState: true, + }) + if err != nil { + t.Fatalf("unable to init authenticator: %s", err) + } + + var passedAuth bool + handler := auth.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + passedAuth = true + w.WriteHeader(http.StatusOK) + })) + + // hit once without any headers, expect a redirect/no pass + r := httptest.NewRequest(http.MethodGet, "/xyz", nil) + + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if passedAuth { + t.Fatal("request passed the auth layer without including expected headers") + } + + if got := w.Result(); got.StatusCode != http.StatusTemporaryRedirect { + t.Fatalf("expected to be redirected but got a status of %d instead", + got.StatusCode) + } + + // add headers, expect to get through + r = httptest.NewRequest(http.MethodGet, "/xyz", nil) + r.Header.Add("X-EXAMPLE", "1") + + w = httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if !passedAuth { + t.Fatal("request did not pass the auth layer despite including headers") + } + + if got := w.Result(); got.StatusCode != http.StatusOK { + t.Fatalf("expected to be get a 200 OK but got a status of %d instead", + got.StatusCode) + } +} + +func setupAuthenticatorTest(t *testing.T) (*httptest.Server, *httptest.Server) { + t.Helper() + + keyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(auth.JSONKeyResponse{ + Keys: []*auth.JSONKey{ + { + Use: "sig", + Kty: "RSA", + Kid: "0b0bf186743471a1edcac3060d1256f9e4050ba8", + N: "0s9r8J5G5I77VpYWS-ttQ8GBDZBlxN_TZHl4DJHAi1WzvxQcP0hBPdASNqAnAuXA-ZxMpMtW_ovjhwo1Ncqpofd3c0H5mSzA9nsmmiex3AO7ZbkaGIdOcMYr4ttOFKZJn2giZWsfQuTlMEvcGyghViyy6l7t1-dMyxjbNOAVLVn25PHfWLbtffv-5EXFXt0Bp0wf0JjPghy4xXf3GjqqqaG_pOnmY_g2c6s8NwZG8dLymiqq0sta3URCUzDYnEHfx7Ol-grOYBOg6YjQP-gl0r5_uvB9Vl9jXKz-WcUUqVTuLp6S-CBstsOheUpSjX3vVP48KJIS4DX6NFHgjn8ooQ", + E: "AQAB", + }, + }, + }) + })) + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "nah", + "expires_in": 3600, + "scope": "https://www.googleapis.com/auth/userinfo.email", + "token_type": "Bearer", + "id_token": "` + testAuthToken + `" + }`)) + })) + return keyServer, authServer +} diff --git a/auth/keys.go b/auth/keys.go index c06afae07..f78d0faf7 100644 --- a/auth/keys.go +++ b/auth/keys.go @@ -55,7 +55,7 @@ type PublicKeySet struct { // Expired will return true if the current key set is expire according to its Expiry // field. func (ks PublicKeySet) Expired() bool { - return timeNow().After(ks.Expiry) + return TimeNow().After(ks.Expiry) } // GetKey will look for the given key ID in the key set and return it, if it exists. @@ -135,7 +135,7 @@ func NewPublicKeySetFromJSON(payload []byte, ttl time.Duration) (PublicKeySet, e } ks = PublicKeySet{ - Expiry: timeNow().Add(ttl), + Expiry: TimeNow().Add(ttl), Keys: map[string]*rsa.PublicKey{}, } @@ -160,4 +160,6 @@ func NewPublicKeySetFromJSON(payload []byte, ttl time.Duration) (PublicKeySet, e return ks, nil } -var timeNow = func() time.Time { return time.Now() } +// TimeNow is used internally to determine the current time. It has been abstracted to +// this global function as a mechanism to help with testing. +var TimeNow = func() time.Time { return time.Now() } diff --git a/auth/keys_test.go b/auth/keys_test.go index 5942dfe40..c4a42e483 100644 --- a/auth/keys_test.go +++ b/auth/keys_test.go @@ -16,7 +16,7 @@ import ( func TestResuseKeySource(t *testing.T) { testTime := time.Date(2018, 10, 29, 12, 0, 0, 0, time.UTC) - timeNow = func() time.Time { return testTime } + TimeNow = func() time.Time { return testTime } firstKeys, err := NewPublicKeySetFromJSON([]byte(testGoogleCerts), 1*time.Second) if err != nil { @@ -47,7 +47,7 @@ func TestResuseKeySource(t *testing.T) { } // move time forward, expire the first keys - timeNow = func() time.Time { return testTime.Add(1500 * time.Millisecond) } + TimeNow = func() time.Time { return testTime.Add(1500 * time.Millisecond) } gotKeys, err = reuser.Get(context.Background()) if err != nil { @@ -153,7 +153,7 @@ func TestKeySetFromURL(t *testing.T) { } for _, test := range tests { - timeNow = func() time.Time { return testTime } + TimeNow = func() time.Time { return testTime } t.Run(test.name, func(t *testing.T) { srvr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/auth/verify.go b/auth/verify.go index 4311d8c48..473b47ce5 100644 --- a/auth/verify.go +++ b/auth/verify.go @@ -74,7 +74,7 @@ func (c Verifier) VerifyInboundKitContext(ctx context.Context) (bool, error) { // VerifyRequest will pull the token from the "Authorization" header of the inbound // request then decode and verify it. func (c Verifier) VerifyRequest(r *http.Request) (bool, error) { - token, err := parseHeader(r.Header.Get("Authorization")) + token, err := GetAuthorizationToken(r) if err != nil { return false, err } @@ -111,7 +111,7 @@ func (c Verifier) Verify(ctx context.Context, token string) (bool, error) { } claims := clmstr.BaseClaims() - nowUnix := timeNow().Unix() + nowUnix := TimeNow().Unix() if nowUnix < (claims.Iat - c.skewAllowance) { return false, errors.New("invalid issue time") @@ -154,3 +154,9 @@ func parseHeader(hdr string) (string, error) { } return auths[1], nil } + +// GetAuthorizationToken will pull the Authorization header from the given request and +// attempt to retrieve the token within it. +func GetAuthorizationToken(r *http.Request) (string, error) { + return parseHeader(r.Header.Get("Authorization")) +} diff --git a/auth/verify_test.go b/auth/verify_test.go index efbaa2011..712a05976 100644 --- a/auth/verify_test.go +++ b/auth/verify_test.go @@ -61,7 +61,7 @@ func TestVerifyRequest(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - timeNow = func() time.Time { return testTime } + TimeNow = func() time.Time { return testTime } token, err := encode( &jws.Header{Algorithm: "RS256", Typ: "JWT", KeyID: keyID}, @@ -89,7 +89,7 @@ func TestVerifyRequest(t *testing.T) { ks := testKeySource{ keys: PublicKeySet{ - Expiry: timeNow().Add(time.Hour), + Expiry: TimeNow().Add(time.Hour), Keys: map[string]*rsa.PublicKey{ keyID: &prv.PublicKey, }, @@ -163,7 +163,7 @@ func TestVerifyInboundKit(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - timeNow = func() time.Time { return testTime } + TimeNow = func() time.Time { return testTime } token, err := encode( &jws.Header{Algorithm: "RS256", Typ: "JWT", KeyID: keyID}, @@ -191,7 +191,7 @@ func TestVerifyInboundKit(t *testing.T) { ks := testKeySource{ keys: PublicKeySet{ - Expiry: timeNow().Add(time.Hour), + Expiry: TimeNow().Add(time.Hour), Keys: map[string]*rsa.PublicKey{ keyID: &prv.PublicKey, }, @@ -337,7 +337,7 @@ func TestVerify(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - timeNow = func() time.Time { return testTime } + TimeNow = func() time.Time { return testTime } token, err := encode( &jws.Header{Algorithm: "RS256", Typ: "JWT", KeyID: keyID}, @@ -370,7 +370,7 @@ func TestVerify(t *testing.T) { ks := testKeySource{ keys: PublicKeySet{ - Expiry: timeNow().Add(time.Hour), + Expiry: TimeNow().Add(time.Hour), Keys: map[string]*rsa.PublicKey{ kid: &prv.PublicKey, },