Skip to content

issues/179 Add option to use base64 URL-safe encoding format for CSRF token #180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion csrf.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package csrf

import (
"encoding/base64"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -100,6 +101,7 @@ type options struct {
// http.Cookie field instead of the "correct" HTTPOnly name that golint suggests.
HttpOnly bool
Secure bool
URLSafe bool
SameSite SameSiteMode
RequestHeader string
FieldName string
Expand Down Expand Up @@ -248,7 +250,11 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

// Save the masked token to the request context
r = contextSave(r, tokenKey, mask(realToken, r))
encoding := base64.StdEncoding
if cs.opts.URLSafe {
encoding = base64.URLEncoding
}
r = contextSave(r, tokenKey, mask(realToken, r, encoding))
// Save the field name to the request context
r = contextSave(r, formKey, cs.opts.FieldName)

Expand Down
12 changes: 9 additions & 3 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TemplateField(r *http.Request) template.HTML {
// token and returning them together as a 64-byte slice. This effectively
// randomises the token on a per-request basis without breaking multiple browser
// tabs/windows.
func mask(realToken []byte, _ *http.Request) string {
func mask(realToken []byte, _ *http.Request, encoding *base64.Encoding) string {
otp, err := generateRandomBytes(tokenLength)
if err != nil {
return ""
Expand All @@ -83,7 +83,7 @@ func mask(realToken []byte, _ *http.Request) string {
// XOR the OTP with the real token to generate a masked token. Append the
// OTP to the front of the masked token to allow unmasking in the subsequent
// request.
return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...))
return encoding.EncodeToString(append(otp, xorToken(otp, realToken)...))
}

// unmask splits the issued token (one-time-pad + masked token) and returns the
Expand Down Expand Up @@ -129,7 +129,13 @@ func (cs *csrf) requestToken(r *http.Request) ([]byte, error) {

// Decode the "issued" (pad + masked) token sent in the request. Return a
// nil byte slice on a decoding error (this will fail upstream).
decoded, err := base64.StdEncoding.DecodeString(issued)
encoding := base64.StdEncoding

if cs.opts.URLSafe {
encoding = base64.URLEncoding
}

decoded, err := encoding.DecodeString(issued)
if err != nil {
return nil, err
}
Expand Down
6 changes: 4 additions & 2 deletions helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,10 @@ func TestMaskUnmaskTokens(t *testing.T) {
t.Fatal(err)
}

issued := mask(realToken, nil)
decoded, err := base64.StdEncoding.DecodeString(issued)
encoding := base64.StdEncoding

issued := mask(realToken, nil, encoding)
decoded, err := encoding.DecodeString(issued)
if err != nil {
t.Fatal(err)
}
Expand Down
7 changes: 7 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ func TrustedOrigins(origins []string) Option {
}
}

// URLSafe changes the base64 encoding format ( URL safe ) of the CSRF token.
func URLSafe(s bool) Option {
return func(cs *csrf) {
cs.opts.URLSafe = s
}
}

// setStore sets the store used by the CSRF middleware.
// Note: this is private (for now) to allow for internal API changes.
func setStore(s store) Option {
Expand Down
20 changes: 20 additions & 0 deletions options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,23 @@ func TestMaxAge(t *testing.T) {
})

}

func TestURLSafe(t *testing.T) {
t.Run("Ensure the default URLSafe is applied", func(t *testing.T) {
handler := Protect(testKey)(nil)
cs := handler.(*csrf)

if cs.opts.URLSafe != false {
t.Fatalf("default URLSafe not applied: got %v (want %v)", cs.opts.URLSafe, false)
}
})

t.Run("Support an explicit URLSafe of true", func(t *testing.T) {
handler := Protect(testKey, URLSafe(true))(nil)
cs := handler.(*csrf)

if cs.opts.URLSafe != true {
t.Fatalf("URLSafe not applied: got %v (want %v)", cs.opts.URLSafe, true)
}
})
}