Skip to content

Commit 709bbe3

Browse files
committed
Merge branch 'master' of github.com:ant0ine/go-json-rest
2 parents af20cdf + ed84d40 commit 709bbe3

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

Diff for: rest/recorder.go

+3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ type recorderResponseWriter struct {
4444
// Record the status code.
4545
func (w *recorderResponseWriter) WriteHeader(code int) {
4646
w.ResponseWriter.WriteHeader(code)
47+
if w.wroteHeader {
48+
return
49+
}
4750
w.statusCode = code
4851
w.wroteHeader = true
4952
}

Diff for: rest/recorder_test.go

+42-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
package rest
22

33
import (
4-
"github.com/ant0ine/go-json-rest/rest/test"
54
"testing"
5+
6+
"github.com/ant0ine/go-json-rest/rest/test"
67
)
78

89
func TestRecorderMiddleware(t *testing.T) {
@@ -91,3 +92,43 @@ func TestRecorderAndGzipMiddleware(t *testing.T) {
9192
recorded.CodeIs(200)
9293
recorded.ContentTypeIsJson()
9394
}
95+
96+
//Underlying net/http only allows you to set the status code once
97+
func TestRecorderMiddlewareReportsSameStatusCodeAsResponse(t *testing.T) {
98+
api := NewApi()
99+
const firstCode = 400
100+
const secondCode = 500
101+
102+
// a middleware carrying the Env tests
103+
api.Use(MiddlewareSimple(func(handler HandlerFunc) HandlerFunc {
104+
return func(w ResponseWriter, r *Request) {
105+
106+
handler(w, r)
107+
108+
if r.Env["STATUS_CODE"] == nil {
109+
t.Error("STATUS_CODE is nil")
110+
}
111+
statusCode := r.Env["STATUS_CODE"].(int)
112+
if statusCode != firstCode {
113+
t.Errorf("STATUS_CODE = %d expected, got %d", firstCode, statusCode)
114+
}
115+
}
116+
}))
117+
118+
// the middleware to test
119+
api.Use(&RecorderMiddleware{})
120+
121+
// a simple app
122+
api.SetApp(AppSimple(func(w ResponseWriter, r *Request) {
123+
w.WriteHeader(firstCode)
124+
w.WriteHeader(secondCode)
125+
}))
126+
127+
// wrap all
128+
handler := api.MakeHandler()
129+
130+
req := test.MakeSimpleRequest("GET", "http://localhost/", nil)
131+
recorded := test.RunRequest(t, handler, req)
132+
recorded.CodeIs(firstCode)
133+
recorded.ContentTypeIsJson()
134+
}

0 commit comments

Comments
 (0)