Skip to content
This repository was archived by the owner on Jan 24, 2019. It is now read-only.

Add Authorization Bearer <jwt> style headers #534

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ func main() {
flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header")
flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header")
flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream")
flagSet.Bool("pass-authorization-header", false, "pass the Authorization Header to upstream")
flagSet.Bool("set-authorization-header", false, "set Authorization response headers (useful in Nginx auth_request mode)")
flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)")
flagSet.Bool("skip-provider-button", false, "will skip sign-in-page to directly reach the next step: oauth/start")
flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests")
Expand Down
116 changes: 104 additions & 12 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ type OAuthProxy struct {
PassUserHeaders bool
BasicAuthPassword string
PassAccessToken bool
SetAuthorization bool
PassAuthorization bool
CookieCipher *cookie.Cipher
skipAuthRegex []string
skipAuthPreflight bool
Expand Down Expand Up @@ -163,7 +165,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, opts.CookieDomain, refresh)

var cipher *cookie.Cipher
if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) {
if opts.PassAccessToken || opts.SetAuthorization || opts.PassAuthorization || (opts.CookieRefresh != time.Duration(0)) {
var err error
cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret))
if err != nil {
Expand Down Expand Up @@ -202,6 +204,8 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
PassUserHeaders: opts.PassUserHeaders,
BasicAuthPassword: opts.BasicAuthPassword,
PassAccessToken: opts.PassAccessToken,
SetAuthorization: opts.SetAuthorization,
PassAuthorization: opts.PassAuthorization,
SkipProviderButton: opts.SkipProviderButton,
CookieCipher: cipher,
templates: loadTemplates(opts.CustomTemplatesDir),
Expand Down Expand Up @@ -254,15 +258,92 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
return
}

func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie {
if value != "" {
value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
if len(value) > 4096 {
// Cookies cannot be larger than 4kb
log.Printf("WARNING - Cookie Size: %d bytes", len(value))
}
c := p.makeCookie(req, p.CookieName, value, expiration, now)
if len(c.Value) > 4096-len(p.CookieName) {
return splitCookie(c)
}
return []*http.Cookie{c}
}

func copyCookie(c *http.Cookie) *http.Cookie {
return &http.Cookie{
Name: c.Name,
Value: c.Value,
Path: c.Path,
Domain: c.Domain,
Expires: c.Expires,
RawExpires: c.RawExpires,
MaxAge: c.MaxAge,
Secure: c.Secure,
HttpOnly: c.HttpOnly,
Raw: c.Raw,
Unparsed: c.Unparsed,
}
}

func splitCookie(c *http.Cookie) []*http.Cookie {
if len(c.Value) < 3840 {
return []*http.Cookie{c}
}
cookies := []*http.Cookie{}
valueBytes := []byte(c.Value)
count := 0
for len(valueBytes) > 0 {
new := copyCookie(c)
new.Name = fmt.Sprintf("%s-%d", c.Name, count)
count++
if len(valueBytes) < 3840 {
new.Value = string(valueBytes)
valueBytes = []byte{}
} else {
newValue := valueBytes[:3840]
valueBytes = valueBytes[3840:]
new.Value = string(newValue)
}
cookies = append(cookies, new)
}
return cookies
}

func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) {
if len(cookies) == 0 {
return nil, fmt.Errorf("Could not load cookie.")
}
if len(cookies) == 1 {
return cookies[0], nil
}
c := copyCookie(cookies[0])
for i := 1; i < len(cookies); i++ {
c.Value += cookies[i].Value
}
c.Name = strings.TrimRight(c.Name, "-0")
return c, nil
}

func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) {
c, err := req.Cookie(cookieName)
if err == nil {
return c, nil
}
cookies := []*http.Cookie{}
err = nil
count := 0
for err == nil {
var c *http.Cookie
c, err = req.Cookie(fmt.Sprintf("%s-%d", cookieName, count))
if err == nil {
cookies = append(cookies, c)
count++
}
}
return p.makeCookie(req, p.CookieName, value, expiration, now)
if len(cookies) == 0 {
return nil, fmt.Errorf("Could not find cookie %s", cookieName)
}
return joinCookies(cookies)
}

func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
Expand Down Expand Up @@ -292,6 +373,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex
}

func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) {

http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now()))
}

Expand All @@ -300,24 +382,28 @@ func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, va
}

func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) {
clr := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now())
http.SetCookie(rw, clr)
cookies := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now())
for _, clr := range cookies {
http.SetCookie(rw, clr)
}

// ugly hack because default domain changed
if p.CookieDomain == "" {
clr2 := *clr
if p.CookieDomain == "" && len(cookies) > 0 {
clr2 := *cookies[0]
clr2.Domain = req.Host
http.SetCookie(rw, &clr2)
}
}

func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()))
for _, c := range p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()) {
http.SetCookie(rw, c)
}
}

func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
var age time.Duration
c, err := req.Cookie(p.CookieName)
c, err := loadCookie(req, p.CookieName)
if err != nil {
// always http.ErrNoCookie
return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName)
Expand Down Expand Up @@ -698,6 +784,12 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
if p.PassAccessToken && session.AccessToken != "" {
req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken}
}
if p.PassAuthorization && session.IdToken != "" {
req.Header["Authorization"] = []string{fmt.Sprintf("Bearer %s", session.IdToken)}
}
if p.SetAuthorization && session.IdToken != "" {
rw.Header().Set("Authorization", fmt.Sprintf("Bearer %s", session.IdToken))
}
if session.Email == "" {
rw.Header().Set("GAP-Auth", session.User)
} else {
Expand Down
79 changes: 75 additions & 4 deletions oauthproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"io/ioutil"
"log"
"math/rand"
"net"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -92,6 +93,73 @@ func TestRobotsTxt(t *testing.T) {
assert.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String())
}

func randomString(length int) string {
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, length)
for i := range b {
b[i] = charset[seededRand.Intn(len(charset))]
}
return string(b)
}

func TestSplitCookie(t *testing.T) {
c1 := &http.Cookie{
Name: "cookie-name",
Value: randomString(5120),
Path: "/",
Domain: "foo.bar",
HttpOnly: true,
Secure: true,
Expires: time.Now(),
}
cookies := splitCookie(c1)
assert.Equal(t, 2, len(cookies))

assert.Equal(t, c1.Name+"-0", cookies[0].Name)
assert.Equal(t, c1.Name+"-1", cookies[1].Name)

assert.Equal(t, 3840, len(cookies[0].Value))
assert.Equal(t, 5120-3840, len(cookies[1].Value))

c2 := &http.Cookie{
Name: "cookie-name",
Value: randomString(3000),
Path: "/",
Domain: "foo.bar",
HttpOnly: true,
Secure: true,
Expires: time.Now(),
}

cookies2 := splitCookie(c2)
assert.Equal(t, 1, len(cookies2))

assert.Equal(t, c2.Name, cookies2[0].Name)
assert.Equal(t, c2.Value, cookies2[0].Value)
}

func TestJoinCookies(t *testing.T) {
c1 := &http.Cookie{
Name: "cookie-name",
Value: randomString(5120),
Path: "/",
Domain: "foo.bar",
HttpOnly: true,
Secure: true,
Expires: time.Now(),
}
// Split Cookies
cookies := splitCookie(c1)
assert.Equal(t, 2, len(cookies))

// join cookies should be the ivnerse
c2, _ := joinCookies(cookies)

assert.Equal(t, c1.Name, c2.Name)
assert.Equal(t, c1.Value, c2.Value)
}

type TestProvider struct {
*providers.ProviderData
EmailAddress string
Expand Down Expand Up @@ -504,7 +572,7 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
})
}

func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie {
func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie {
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
}

Expand All @@ -513,7 +581,9 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time
if err != nil {
return err
}
p.req.AddCookie(p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref))
for _, c := range p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref) {
p.req.AddCookie(c)
}
return nil
}

Expand Down Expand Up @@ -802,8 +872,9 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
if err != nil {
panic(err)
}
cookie := proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now())
req.AddCookie(cookie)
for _, c := range proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) {
req.AddCookie(c)
}
// This is used by the upstream to validate the signature.
st.authenticator.auth = hmacauth.NewHmacAuth(
crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders)
Expand Down
4 changes: 4 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ type Options struct {
PassUserHeaders bool `flag:"pass-user-headers" cfg:"pass_user_headers"`
SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"`
SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest"`
SetAuthorization bool `flag:"set-authorization-header" cfg:"set_authorization_header"`
PassAuthorization bool `flag:"pass-authorization-header" cfg:"pass_authorization_header"`
SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"`

// These options allow for other providers besides Google, with
Expand Down Expand Up @@ -110,6 +112,8 @@ func NewOptions() *Options {
PassUserHeaders: true,
PassAccessToken: false,
PassHostHeader: true,
SetAuthorization: false,
PassAuthorization: false,
ApprovalPrompt: "force",
RequestLogging: true,
RequestLoggingFormat: defaultRequestLoggingFormat,
Expand Down
1 change: 1 addition & 0 deletions providers/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err
}
s = &SessionState{
AccessToken: jsonResponse.AccessToken,
IdToken: jsonResponse.IdToken,
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
RefreshToken: jsonResponse.RefreshToken,
Email: email,
Expand Down
1 change: 1 addition & 0 deletions providers/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er

s = &SessionState{
AccessToken: token.AccessToken,
IdToken: rawIDToken,
RefreshToken: token.RefreshToken,
ExpiresOn: token.Expiry,
Email: claims.Email,
Expand Down
Loading