Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 48 additions & 31 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
//
// You can configure it by passing an option struct to cors.New:
//
// c := cors.New(cors.Options{
// AllowedOrigins: []string{"foo.com"},
// AllowedMethods: []string{"GET", "POST", "DELETE"},
// AllowCredentials: true,
// })
// c := cors.New(cors.Options{
// AllowedOrigins: []string{"foo.com"},
// AllowedMethods: []string{"GET", "POST", "DELETE"},
// AllowCredentials: true,
// })
//
// Then insert the handler in the chain:
//
// handler = c.Handler(handler)
// handler = c.Handler(handler)
//
// See Options documentation for more options.
//
Expand Down Expand Up @@ -69,6 +69,17 @@ type Options struct {

// Debugging flag adds additional output to debug server side CORS issues
Debug bool

// ErrorHandler is a custom error handler function for handling CORS errors.
// if you want to write success response or call next handler, return true.
// If you want to terminate further processing, return false.
ErrorHandler func(w http.ResponseWriter, r *http.Request, cors Cors, err error) bool
}

// defaultErrorHandler is the default error handler for the CORS middleware.
func defaultErrorHandler(_ http.ResponseWriter, _ *http.Request, cors Cors, err error) bool {
cors.logf("%v", err)
return true
}

// Logger generic interface for logger
Expand Down Expand Up @@ -108,6 +119,8 @@ type Cors struct {

allowCredentials bool
optionPassthrough bool

errorHandler func(w http.ResponseWriter, r *http.Request, cors Cors, err error) bool
}

// New creates a new Cors handler with the provided options.
Expand All @@ -118,11 +131,16 @@ func New(options Options) *Cors {
allowCredentials: options.AllowCredentials,
maxAge: options.MaxAge,
optionPassthrough: options.OptionsPassthrough,
errorHandler: options.ErrorHandler,
}
if options.Debug && c.Log == nil {
c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags)
}

if c.errorHandler == nil {
c.errorHandler = defaultErrorHandler
}

// Normalize options
// Note: for origins and methods matching, the spec requires a case-sensitive matching.
// As it may error prone, we chose to ignore the spec here.
Expand Down Expand Up @@ -212,32 +230,35 @@ func (c *Cors) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
c.logf("Handler: Preflight request")
c.handlePreflight(w, r)
ret := c.handlePreflight(w, r)
// Preflight requests are standalone and should stop the chain as some other
// middleware may not handle OPTIONS requests correctly. One typical example
// is authentication middleware ; OPTIONS requests won't carry authentication
// headers (see #1)
if c.optionPassthrough {
next.ServeHTTP(w, r)
if ret {
next.ServeHTTP(w, r)
}
} else {
w.WriteHeader(http.StatusOK)
if ret {
w.WriteHeader(http.StatusOK)
}
}
} else {
c.logf("Handler: Actual request")
c.handleActualRequest(w, r)
next.ServeHTTP(w, r)
if c.handleActualRequest(w, r) {
next.ServeHTTP(w, r)
}
}
})
}

// handlePreflight handles pre-flight CORS requests
func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) bool {
headers := w.Header()
origin := r.Header.Get("Origin")

if r.Method != http.MethodOptions {
c.logf("Preflight aborted: %s!=OPTIONS", r.Method)
return
return c.errorHandler(w, r, *c, &PreflightNotOptionMethodError{Method: r.Method})
}
// Always set Vary headers
// see https://github.com/rs/cors/issues/10,
Expand All @@ -247,23 +268,19 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
headers.Add("Vary", "Access-Control-Request-Headers")

if origin == "" {
c.logf("Preflight aborted: empty origin")
return
return c.errorHandler(w, r, *c, &PreflightEmptyOriginError{})
}
if !c.isOriginAllowed(r, origin) {
c.logf("Preflight aborted: origin '%s' not allowed", origin)
return
return c.errorHandler(w, r, *c, &PreflightNotOriginAllowedError{Origin: origin})
}

reqMethod := r.Header.Get("Access-Control-Request-Method")
if !c.isMethodAllowed(reqMethod) {
c.logf("Preflight aborted: method '%s' not allowed", reqMethod)
return
return c.errorHandler(w, r, *c, &PreflightNotAllowedMethodError{RequestMethod: reqMethod})
}
reqHeaders := parseHeaderList(r.Header.Get("Access-Control-Request-Headers"))
if !c.areHeadersAllowed(reqHeaders) {
c.logf("Preflight aborted: headers '%v' not allowed", reqHeaders)
return
return c.errorHandler(w, r, *c, &PreflightNotHeadersAllowedError{RequestHeaders: reqHeaders})
}
if c.allowedOriginsAll {
headers.Set("Access-Control-Allow-Origin", "*")
Expand All @@ -286,32 +303,30 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
headers.Set("Access-Control-Max-Age", strconv.Itoa(c.maxAge))
}
c.logf("Preflight response headers: %v", headers)

return true
}

// handleActualRequest handles simple cross-origin requests, actual request or redirects
func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {
func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) bool {
headers := w.Header()
origin := r.Header.Get("Origin")

// Always set Vary, see https://github.com/rs/cors/issues/10
headers.Add("Vary", "Origin")
if origin == "" {
c.logf("Actual request no headers added: missing origin")
return
return c.errorHandler(w, r, *c, &ActualMissingOriginError{})
}
if !c.isOriginAllowed(r, origin) {
c.logf("Actual request no headers added: origin '%s' not allowed", origin)
return
return c.errorHandler(w, r, *c, &ActualOriginNotAllowedError{Origin: origin})
}

// Note that spec does define a way to specifically disallow a simple method like GET or
// POST. Access-Control-Allow-Methods is only used for pre-flight requests and the
// spec doesn't instruct to check the allowed methods for simple cross-origin requests.
// We think it's a nice feature to be able to have control on those methods though.
if !c.isMethodAllowed(r.Method) {
c.logf("Actual request no headers added: method '%s' not allowed", r.Method)

return
return c.errorHandler(w, r, *c, &ActualMethodNotAllowedError{RequestMethod: r.Method})
}
if c.allowedOriginsAll {
headers.Set("Access-Control-Allow-Origin", "*")
Expand All @@ -325,6 +340,8 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {
headers.Set("Access-Control-Allow-Credentials", "true")
}
c.logf("Actual response added headers: %v", headers)

return true
}

// convenience method. checks if a logger is set.
Expand Down
Loading