Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 39 additions & 59 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
//
// You can configure it by passing an option struct to cors.New:
//
// c := cors.New(cors.Options{
// AllowedOrigins: []string{"foo.com"},
// AllowedMethods: []string{"GET", "POST", "DELETE"},
// AllowCredentials: true,
// })
// c := cors.New(cors.Options{
// AllowedOrigins: []string{"foo.com"},
// AllowedMethods: []string{"GET", "POST", "DELETE"},
// AllowCredentials: true,
// })
//
// Then insert the handler in the chain:
//
// handler = c.Handler(handler)
// handler = c.Handler(handler)
//
// See Options documentation for more options.
//
Expand All @@ -24,6 +24,8 @@ import (
"os"
"strconv"
"strings"

"github.com/scylladb/go-set/strset"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid introducing a new external package?

Can we simply use map[string]struct{}?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could, but if we did I would be copy-pasting significant part of of that package anyway, so I don't think it would be better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not introduce external dependencies to chi core package, unless necessary.

)

// Options is a configuration container to setup the CORS middleware.
Expand Down Expand Up @@ -81,23 +83,23 @@ type Cors struct {
// Debug logger
Log Logger

// Normalized list of plain allowed origins
allowedOrigins []string
// Normalized set of plain allowed origins
allowedOrigins *strset.Set

// List of allowed origins containing wildcards
allowedWOrigins []wildcard

// Optional origin validator function
allowOriginFunc func(r *http.Request, origin string) bool

// Normalized list of allowed headers
allowedHeaders []string
// Normalized set of allowed headers
allowedHeaders *strset.Set

// Normalized list of allowed methods
allowedMethods []string
// Normalized set of allowed methods
allowedMethods *strset.Set

// Normalized list of exposed headers
exposedHeaders []string
// Normalized set of exposed headers
exposedHeaders *strset.Set
maxAge int

// Set to true when allowed origins contains a "*"
Expand All @@ -113,7 +115,7 @@ type Cors struct {
// New creates a new Cors handler with the provided options.
func New(options Options) *Cors {
c := &Cors{
exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey),
exposedHeaders: strset.New(convert(options.ExposedHeaders, http.CanonicalHeaderKey)...),
allowOriginFunc: options.AllowOriginFunc,
allowCredentials: options.AllowCredentials,
maxAge: options.MaxAge,
Expand All @@ -134,49 +136,46 @@ func New(options Options) *Cors {
c.allowedOriginsAll = true
}
} else {
c.allowedOrigins = []string{}
c.allowedOrigins = strset.NewWithSize(len(options.AllowedOrigins))
c.allowedWOrigins = []wildcard{}
for _, origin := range options.AllowedOrigins {
// Normalize
origin = strings.ToLower(origin)
if origin == "*" {
// If "*" is present in the list, turn the whole list into a match all
c.allowedOriginsAll = true
c.allowedOrigins = nil
c.allowedOrigins.Clear()
c.allowedWOrigins = nil
break
} else if i := strings.IndexByte(origin, '*'); i >= 0 {
// Split the origin in two: start and end string without the *
w := wildcard{origin[0:i], origin[i+1:]}
c.allowedWOrigins = append(c.allowedWOrigins, w)
} else {
c.allowedOrigins = append(c.allowedOrigins, origin)
c.allowedOrigins.Add(origin)
}
}
}

// Allowed Headers
if len(options.AllowedHeaders) == 0 {
// Use sensible defaults
c.allowedHeaders = []string{"Origin", "Accept", "Content-Type"}
c.allowedHeaders = strset.New("Origin", "Accept", "Content-Type")
} else {
// Origin is always appended as some browsers will always request for this header at preflight
c.allowedHeaders = convert(append(options.AllowedHeaders, "Origin"), http.CanonicalHeaderKey)
for _, h := range options.AllowedHeaders {
if h == "*" {
c.allowedHeadersAll = true
c.allowedHeaders = nil
break
}
c.allowedHeaders = strset.New(convert(append(options.AllowedHeaders, "Origin"), http.CanonicalHeaderKey)...)
if c.allowedHeaders.Has("*") {
c.allowedHeadersAll = true
c.allowedHeaders.Clear()
}
}

// Allowed Methods
if len(options.AllowedMethods) == 0 {
// Default is spec's "simple" methods
c.allowedMethods = []string{http.MethodGet, http.MethodPost, http.MethodHead}
c.allowedMethods = strset.New(http.MethodGet, http.MethodPost, http.MethodHead)
} else {
c.allowedMethods = convert(options.AllowedMethods, strings.ToUpper)
c.allowedMethods = strset.New(convert(options.AllowedMethods, strings.ToUpper)...)
}

return c
Expand Down Expand Up @@ -273,11 +272,11 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
// Spec says: Since the list of methods can be unbounded, simply returning the method indicated
// by Access-Control-Request-Method (if supported) can be enough
headers.Set("Access-Control-Allow-Methods", strings.ToUpper(reqMethod))
if len(reqHeaders) > 0 {
if !reqHeaders.IsEmpty() {

// Spec says: Since the list of headers can be unbounded, simply returning supported headers
// from Access-Control-Request-Headers can be enough
headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", "))
headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders.List(), ", "))
}
if c.allowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
Expand Down Expand Up @@ -318,8 +317,8 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {
} else {
headers.Set("Access-Control-Allow-Origin", origin)
}
if len(c.exposedHeaders) > 0 {
headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", "))
if !c.exposedHeaders.IsEmpty() {
headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders.List(), ", "))
}
if c.allowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
Expand All @@ -344,11 +343,10 @@ func (c *Cors) isOriginAllowed(r *http.Request, origin string) bool {
return true
}
origin = strings.ToLower(origin)
for _, o := range c.allowedOrigins {
if o == origin {
return true
}
if c.allowedOrigins.Has(origin) {
return true
}

for _, w := range c.allowedWOrigins {
if w.match(origin) {
return true
Expand All @@ -360,7 +358,7 @@ func (c *Cors) isOriginAllowed(r *http.Request, origin string) bool {
// isMethodAllowed checks if a given method can be used as part of a cross-domain request
// on the endpoint
func (c *Cors) isMethodAllowed(method string) bool {
if len(c.allowedMethods) == 0 {
if c.allowedMethods.IsEmpty() {
// If no method allowed, always return false, even for preflight request
return false
}
Expand All @@ -369,32 +367,14 @@ func (c *Cors) isMethodAllowed(method string) bool {
// Always allow preflight requests
return true
}
for _, m := range c.allowedMethods {
if m == method {
return true
}
}
return false
return c.allowedMethods.Has(method)
}

// areHeadersAllowed checks if a given list of headers are allowed to used within
// a cross-domain request.
func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool {
if c.allowedHeadersAll || len(requestedHeaders) == 0 {
func (c *Cors) areHeadersAllowed(requestedHeaders *strset.Set) bool {
if c.allowedHeadersAll || requestedHeaders.IsEmpty() {
return true
}
for _, header := range requestedHeaders {
header = http.CanonicalHeaderKey(header)
found := false
for _, h := range c.allowedHeaders {
if h == header {
found = true
break
}
}
if !found {
return false
}
}
return true
return c.allowedHeaders.IsSubset(requestedHeaders)
}
14 changes: 13 additions & 1 deletion cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ import (
"net/http"
"net/http/httptest"
"regexp"
"sort"
"strings"
"testing"

"github.com/scylladb/go-set/strset"
)

var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -26,6 +29,15 @@ func assertHeaders(t *testing.T, resHeaders http.Header, expHeaders map[string]s
for _, name := range allHeaders {
got := strings.Join(resHeaders[name], ", ")
want := expHeaders[name]
if name == "Access-Control-Allow-Headers" || name == "Access-Control-Expose-Headers" {
gSplit := strings.Split(got, ", ")
sort.Strings(gSplit)
got = strings.Join(gSplit, ", ")

wSplit := strings.Split(want, ", ")
sort.Strings(wSplit)
want = strings.Join(wSplit, ", ")
}
if got != want {
t.Errorf("Response header %q = %q, want %q", name, got, want)
}
Expand Down Expand Up @@ -488,7 +500,7 @@ func TestIsMethodAllowedReturnsFalseWithNoMethods(t *testing.T) {
s := New(Options{
// Intentionally left blank.
})
s.allowedMethods = []string{}
s.allowedMethods = strset.New()
if s.isMethodAllowed("") {
t.Error("IsMethodAllowed should return false when c.allowedMethods is nil.")
}
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/go-chi/cors

go 1.14

require github.com/scylladb/go-set v1.0.2
13 changes: 9 additions & 4 deletions utils.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package cors

import "strings"
import (
"net/http"
"strings"

"github.com/scylladb/go-set/strset"
)

const toLower = 'a' - 'A'

Expand All @@ -25,7 +30,7 @@ func convert(s []string, c converter) []string {
}

// parseHeaderList tokenize + normalize a string containing a list of headers
func parseHeaderList(headerList string) []string {
func parseHeaderList(headerList string) *strset.Set {
l := len(headerList)
h := make([]byte, 0, l)
upper := true
Expand All @@ -36,7 +41,7 @@ func parseHeaderList(headerList string) []string {
t++
}
}
headers := make([]string, 0, t)
headers := strset.NewWithSize(t)
for i := 0; i < l; i++ {
b := headerList[i]
if b >= 'a' && b <= 'z' {
Expand All @@ -58,7 +63,7 @@ func parseHeaderList(headerList string) []string {
if b == ' ' || b == ',' || i == l-1 {
if len(h) > 0 {
// Flush the found header
headers = append(headers, string(h))
headers.Add(http.CanonicalHeaderKey(string(h)))
h = h[:0]
upper = true
}
Expand Down
10 changes: 6 additions & 4 deletions utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package cors
import (
"strings"
"testing"

"github.com/scylladb/go-set/strset"
)

func TestWildcard(t *testing.T) {
Expand Down Expand Up @@ -33,17 +35,17 @@ func TestConvert(t *testing.T) {

func TestParseHeaderList(t *testing.T) {
h := parseHeaderList("header, second-header, THIRD-HEADER, Numb3r3d-H34d3r, Header_with_underscore Header.with.full.stop")
e := []string{"Header", "Second-Header", "Third-Header", "Numb3r3d-H34d3r", "Header_with_underscore", "Header.with.full.stop"}
if h[0] != e[0] || h[1] != e[1] || h[2] != e[2] || h[3] != e[3] || h[4] != e[4] || h[5] != e[5] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

slices.Equal or maps.Equal will do

e := strset.New("Header", "Second-Header", "Third-Header", "Numb3r3d-H34d3r", "Header_with_underscore", "Header.with.full.stop")
if !h.IsEqual(h) {
t.Errorf("%v != %v", h, e)
}
}

func TestParseHeaderListEmpty(t *testing.T) {
if len(parseHeaderList("")) != 0 {
if !parseHeaderList("").IsEmpty() {
t.Error("should be empty slice")
}
if len(parseHeaderList(" , ")) != 0 {
if !parseHeaderList(" , ").IsEmpty() {
t.Error("should be empty slice")
}
}
Expand Down
Loading