Skip to content

Commit a840e0d

Browse files
author
John Boggs
authored
Merge pull request #2 from segmentio/SEC-757-okta-refresh
Sec 757 okta refresh
2 parents db3a898 + 06c6300 commit a840e0d

File tree

8 files changed

+371
-1
lines changed

8 files changed

+371
-1
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,5 @@ _testmain.go
2929

3030
# Editor swap/temp files
3131
.*.swp
32+
33+
local

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Valid providers are :
3737
* [GitHub](#github-auth-provider)
3838
* [GitLab](#gitlab-auth-provider)
3939
* [LinkedIn](#linkedin-auth-provider)
40+
* [Okta](#okta-auth-provider)
4041

4142
The provider can be selected using the `provider` configuration value.
4243

@@ -155,6 +156,10 @@ OpenID Connect is a spec for OAUTH 2.0 + identity that is implemented by many ma
155156
-cookie-secure=false
156157
-email-domain example.com
157158

159+
### Okta Auth Provider
160+
161+
[Okta](https://www.okta.com/) is a hosted SSO provider. You will need to set the `okta-domain` to your organization's Okta domain.
162+
158163
## Email Authentication
159164

160165
To authorize by email domain use `--email-domain=yourcompany.com`. To authorize individual email addresses use `--authenticated-emails-file=/path/to/file` with one email per line. To authorize all email addresses use `--email-domain=*`.

main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ func main() {
4949
flagSet.Var(&googleGroups, "google-group", "restrict logins to members of this google group (may be given multiple times).")
5050
flagSet.String("google-admin-email", "", "the google admin to impersonate for api calls")
5151
flagSet.String("google-service-account-json", "", "the path to the service account json credentials")
52+
flagSet.String("okta-domain", "", "the full domain for which your organization's okta is configured (example.okta.com)")
5253
flagSet.String("client-id", "", "the OAuth Client ID: ie: \"123456.apps.googleusercontent.com\"")
5354
flagSet.String("client-secret", "", "the OAuth Client Secret")
5455
flagSet.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)")

oauthproxy.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,9 +427,16 @@ func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error)
427427

428428
redirect = req.Form.Get("rd")
429429
if redirect == "" || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
430-
redirect = "/"
430+
redirect = req.URL.RequestURI()
431+
}
432+
433+
if req.Header.Get("X-Auth-Request-Redirect") != "" {
434+
redirect = req.Header.Get("X-Auth-Request-Redirect")
431435
}
432436

437+
if redirect == p.SignInPath || redirect == p.OAuthStartPath {
438+
redirect = "/"
439+
}
433440
return
434441
}
435442

options.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type Options struct {
3737
GoogleGroups []string `flag:"google-group" cfg:"google_group"`
3838
GoogleAdminEmail string `flag:"google-admin-email" cfg:"google_admin_email"`
3939
GoogleServiceAccountJSON string `flag:"google-service-account-json" cfg:"google_service_account_json"`
40+
OktaDomain string `flag:"okta-domain" cfg:"okta_domain"`
4041
HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"`
4142
DisplayHtpasswdForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"`
4243
CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir"`
@@ -278,6 +279,8 @@ func parseProviderInfo(o *Options, msgs []string) []string {
278279
} else {
279280
p.Verifier = o.oidcVerifier
280281
}
282+
case *providers.OktaProvider:
283+
p.SetOktaDomain(o.OktaDomain)
281284
}
282285
return msgs
283286
}

providers/okta.go

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
package providers
2+
3+
import (
4+
"bytes"
5+
"encoding/base64"
6+
"encoding/json"
7+
"errors"
8+
"fmt"
9+
"io/ioutil"
10+
"log"
11+
"net/http"
12+
"net/url"
13+
"strings"
14+
"time"
15+
16+
"github.com/bitly/oauth2_proxy/api"
17+
)
18+
19+
type OktaProvider struct {
20+
*ProviderData
21+
}
22+
23+
func NewOktaProvider(p *ProviderData) *OktaProvider {
24+
p.ProviderName = "Okta"
25+
if p.Scope == "" {
26+
p.Scope = "openid profile email offline_access"
27+
}
28+
return &OktaProvider{ProviderData: p}
29+
}
30+
31+
func (p *OktaProvider) SetOktaDomain(domain string) {
32+
if p.LoginURL == nil || p.LoginURL.String() == "" {
33+
p.LoginURL = &url.URL{
34+
Scheme: "https",
35+
Host: domain,
36+
Path: "/oauth2/v1/authorize",
37+
}
38+
}
39+
if p.RedeemURL == nil || p.RedeemURL.String() == "" {
40+
p.RedeemURL = &url.URL{
41+
Scheme: "https",
42+
Host: domain,
43+
Path: "/oauth2/v1/token",
44+
}
45+
}
46+
if p.ValidateURL == nil || p.ValidateURL.String() == "" {
47+
p.ValidateURL = &url.URL{
48+
Scheme: "https",
49+
Host: domain,
50+
Path: "/oauth2/v1/userinfo",
51+
}
52+
}
53+
54+
}
55+
56+
func getOktaHeader(access_token string) http.Header {
57+
header := make(http.Header)
58+
header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token))
59+
return header
60+
}
61+
62+
func emailFromOktaIdToken(idToken string) (string, error) {
63+
64+
// id_token is a base64 encode ID token payload
65+
// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
66+
jwt := strings.Split(idToken, ".")
67+
b, err := base64.RawURLEncoding.DecodeString(jwt[1])
68+
if err != nil {
69+
return "", err
70+
}
71+
72+
var email struct {
73+
Email string `json:"email"`
74+
}
75+
err = json.Unmarshal(b, &email)
76+
if err != nil {
77+
return "", err
78+
}
79+
if email.Email == "" {
80+
return "", errors.New("Okta ID token missing email")
81+
}
82+
83+
return email.Email, nil
84+
}
85+
86+
func (p *OktaProvider) GetUserName(s *SessionState) (string, error) {
87+
req, err := http.NewRequest("GET",
88+
p.ValidateURL.String(), nil)
89+
if err != nil {
90+
log.Printf("failed building request %s", err)
91+
return "", err
92+
}
93+
req.Header = getOktaHeader(s.AccessToken)
94+
json, err := api.Request(req)
95+
if err != nil {
96+
log.Printf("failed making request %s", err)
97+
return "", err
98+
}
99+
return json.Get("preferred_username").String()
100+
}
101+
102+
func (p *OktaProvider) ValidateSessionState(s *SessionState) bool {
103+
return validateToken(p, s.AccessToken, getOktaHeader(s.AccessToken))
104+
}
105+
106+
func (p *OktaProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
107+
if code == "" {
108+
err = errors.New("missing code")
109+
return
110+
}
111+
112+
params := url.Values{}
113+
params.Add("redirect_uri", redirectURL)
114+
params.Add("client_id", p.ClientID)
115+
params.Add("client_secret", p.ClientSecret)
116+
params.Add("code", code)
117+
params.Add("grant_type", "authorization_code")
118+
var req *http.Request
119+
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
120+
if err != nil {
121+
return
122+
}
123+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
124+
125+
resp, err := http.DefaultClient.Do(req)
126+
if err != nil {
127+
return
128+
}
129+
var body []byte
130+
body, err = ioutil.ReadAll(resp.Body)
131+
resp.Body.Close()
132+
if err != nil {
133+
return
134+
}
135+
136+
if resp.StatusCode != 200 {
137+
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
138+
return
139+
}
140+
141+
var jsonResponse struct {
142+
AccessToken string `json:"access_token"`
143+
RefreshToken string `json:"refresh_token"`
144+
ExpiresIn int64 `json:"expires_in"`
145+
IdToken string `json:"id_token"`
146+
}
147+
err = json.Unmarshal(body, &jsonResponse)
148+
if err != nil {
149+
return
150+
}
151+
var email string
152+
email, err = emailFromOktaIdToken(jsonResponse.IdToken)
153+
154+
if err != nil {
155+
return
156+
}
157+
s = &SessionState{
158+
AccessToken: jsonResponse.AccessToken,
159+
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
160+
RefreshToken: jsonResponse.RefreshToken,
161+
Email: email,
162+
}
163+
return
164+
}
165+
166+
func (p *OktaProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
167+
//@todo: Remove from here before accepting PR. This is for local testing
168+
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
169+
return false, nil
170+
}
171+
172+
newToken, duration, err := p.redeemRefreshToken(s.RefreshToken)
173+
if err != nil {
174+
return false, err
175+
}
176+
177+
// re-check that the user is in the proper google group(s)
178+
if !p.ValidateGroup(s.Email) {
179+
return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
180+
}
181+
182+
origExpiration := s.ExpiresOn
183+
s.AccessToken = newToken
184+
s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second)
185+
log.Printf("refreshed access token %s (expired on %s)", s, origExpiration)
186+
return true, nil
187+
}
188+
189+
func (p *OktaProvider) redeemRefreshToken(refreshToken string) (token string, expires time.Duration, err error) {
190+
params := url.Values{}
191+
params.Add("client_id", p.ClientID)
192+
params.Add("client_secret", p.ClientSecret)
193+
params.Add("refresh_token", refreshToken)
194+
params.Add("grant_type", "refresh_token")
195+
var req *http.Request
196+
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
197+
if err != nil {
198+
return
199+
}
200+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
201+
202+
resp, err := http.DefaultClient.Do(req)
203+
if err != nil {
204+
return
205+
}
206+
var body []byte
207+
body, err = ioutil.ReadAll(resp.Body)
208+
resp.Body.Close()
209+
if err != nil {
210+
return
211+
}
212+
213+
if resp.StatusCode != 200 {
214+
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
215+
return
216+
}
217+
218+
var data struct {
219+
AccessToken string `json:"access_token"`
220+
ExpiresIn int64 `json:"expires_in"`
221+
}
222+
err = json.Unmarshal(body, &data)
223+
if err != nil {
224+
return
225+
}
226+
token = data.AccessToken
227+
expires = time.Duration(data.ExpiresIn) * time.Second
228+
return
229+
}

0 commit comments

Comments
 (0)