Skip to content

Commit 460e5fb

Browse files
authored
feat: add response wrapper (#27)
1 parent 076eb0f commit 460e5fb

File tree

6 files changed

+206
-13
lines changed

6 files changed

+206
-13
lines changed

context.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
// It holds data related to current HTTP request.
1111
type Context struct {
1212
request *http.Request
13-
response http.ResponseWriter
13+
response ResponseWriter
1414
params Params
1515
storage Map
1616
kid *Kid
@@ -26,7 +26,7 @@ func newContext(k *Kid) *Context {
2626
// reset resets the context.
2727
func (c *Context) reset(request *http.Request, response http.ResponseWriter) {
2828
c.request = request
29-
c.response = response
29+
c.response = newResponse(response)
3030
c.storage = make(Map)
3131
c.params = make(Params)
3232
}
@@ -42,7 +42,7 @@ func (c *Context) Request() *http.Request {
4242
}
4343

4444
// Response returns plain response of current HTTP request.
45-
func (c *Context) Response() http.ResponseWriter {
45+
func (c *Context) Response() ResponseWriter {
4646
return c.response
4747
}
4848

@@ -171,6 +171,7 @@ func (c *Context) HTMLString(code int, tpl string) error {
171171
// NoContent returns an empty response with the given status code.
172172
func (c *Context) NoContent(code int) {
173173
c.response.WriteHeader(code)
174+
c.response.WriteHeaderNow()
174175
}
175176

176177
// writeContentType sets content type header for response.

context_test.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,12 @@ func TestContextReset(t *testing.T) {
3636

3737
req := httptest.NewRequest(http.MethodGet, "/", nil)
3838
res := httptest.NewRecorder()
39+
expectedRes := newResponse(res)
3940

4041
ctx.reset(req, res)
4142

4243
assert.Equal(t, req, ctx.request)
43-
assert.Equal(t, res, ctx.response)
44+
assert.Equal(t, expectedRes, ctx.response)
4445
assert.Equal(t, make(Map), ctx.storage)
4546
assert.Equal(t, make(Params), ctx.params)
4647
}
@@ -59,10 +60,11 @@ func TestContextResponse(t *testing.T) {
5960
ctx := newContext(New())
6061

6162
res := httptest.NewRecorder()
63+
expectedRes := newResponse(res)
6264

6365
ctx.reset(nil, res)
6466

65-
assert.Equal(t, res, ctx.Response())
67+
assert.Equal(t, expectedRes, ctx.Response())
6668
}
6769

6870
func TestContextSetParams(t *testing.T) {
@@ -205,10 +207,6 @@ func TestNoContent(t *testing.T) {
205207

206208
ctx.NoContent(http.StatusNoContent)
207209
assert.Equal(t, http.StatusNoContent, res.Code)
208-
209-
// Once status code is written, it can't be rewritten again.
210-
ctx.NoContent(http.StatusOK)
211-
assert.Equal(t, http.StatusNoContent, res.Code)
212210
}
213211

214212
func TestContextReadJSON(t *testing.T) {

html_renderer/html.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,6 @@ func (r *defaultHTMLRenderer) loadTemplates() error {
128128
return newInternalServerHTTPError(err, err.Error())
129129
}
130130

131-
fmt.Println(r.funcMap)
132-
133131
for _, templateFile := range templateFiles {
134132
name := r.getTemplateName(templateFile)
135133
files := getFilesToParse(templateFile, layoutFiles)

kid_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,13 +474,13 @@ func TestKidRun(t *testing.T) {
474474
})
475475

476476
go func() {
477-
assert.NoError(t, k.Run())
477+
assert.NoError(t, k.Run(":8080"))
478478
}()
479479

480480
// Wait for the server to start
481481
time.Sleep(5 * time.Millisecond)
482482

483-
resp, err := http.Get("http://localhost:2376")
483+
resp, err := http.Get("http://localhost:8080")
484484
assert.NoError(t, err)
485485

486486
defer resp.Body.Close()

response.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package kid
2+
3+
import (
4+
"bufio"
5+
"net"
6+
"net/http"
7+
)
8+
9+
type (
10+
// ResponseWriter is a wrapper for http.ResponseWriter
11+
// to make using http.ResponseWriter's methods easier.
12+
ResponseWriter interface {
13+
http.ResponseWriter
14+
http.Hijacker
15+
http.Flusher
16+
17+
// WriteHeaderNow writes status code.
18+
WriteHeaderNow()
19+
20+
// Size returns number of bytes written to response.
21+
Size() int
22+
23+
// Written returns true if response has already been written otherwise returns false.
24+
Written() bool
25+
}
26+
27+
// response implements ResponseWriter.
28+
response struct {
29+
http.ResponseWriter
30+
written bool
31+
status int
32+
size int
33+
}
34+
)
35+
36+
// Verifying interface compliance.
37+
var _ ResponseWriter = (*response)(nil)
38+
39+
// newResponse returns a new response writer.
40+
func newResponse(w http.ResponseWriter) ResponseWriter {
41+
response := response{
42+
ResponseWriter: w,
43+
status: http.StatusOK,
44+
}
45+
return &response
46+
}
47+
48+
// WriteHeader sets status code.
49+
func (r *response) WriteHeader(code int) {
50+
if r.Written() {
51+
return
52+
}
53+
54+
r.status = code
55+
}
56+
57+
// WriteHeaderNow writes status code.
58+
// Status code should already be specified using response.WriteHeader method.
59+
func (r *response) WriteHeaderNow() {
60+
if r.Written() {
61+
return
62+
}
63+
64+
r.written = true
65+
r.ResponseWriter.WriteHeader(r.status)
66+
}
67+
68+
// Write writes byte data to response.
69+
func (r *response) Write(b []byte) (int, error) {
70+
r.WriteHeaderNow()
71+
72+
n, err := r.ResponseWriter.Write(b)
73+
r.size += n
74+
75+
return n, err
76+
}
77+
78+
// Size returns number of bytes written.
79+
func (r *response) Size() int {
80+
return r.size
81+
}
82+
83+
// Written returns true if response has already been written otherwise returns false.
84+
func (r *response) Written() bool {
85+
return r.written
86+
}
87+
88+
// Flush implements the http.Flusher interface.
89+
func (r *response) Flush() {
90+
r.WriteHeaderNow()
91+
r.ResponseWriter.(http.Flusher).Flush()
92+
}
93+
94+
// Hijack implements the http.Hijacker interface.
95+
func (r *response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
96+
r.written = true
97+
return r.ResponseWriter.(http.Hijacker).Hijack()
98+
}

response_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package kid
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestNewResponse(t *testing.T) {
12+
w := httptest.NewRecorder()
13+
res := newResponse(w).(*response)
14+
15+
assert.Equal(t, w, res.ResponseWriter)
16+
assert.Equal(t, http.StatusOK, res.status)
17+
assert.Zero(t, res.Size())
18+
assert.False(t, res.Written())
19+
}
20+
21+
func TestResponseWriterWriteHeader(t *testing.T) {
22+
w := httptest.NewRecorder()
23+
res := newResponse(w).(*response)
24+
25+
res.WriteHeader(http.StatusAccepted)
26+
27+
assert.Equal(t, http.StatusAccepted, res.status)
28+
assert.False(t, res.Written())
29+
30+
res.WriteHeaderNow()
31+
32+
// Won't write again because header is already written.
33+
res.WriteHeader(http.StatusBadRequest)
34+
35+
assert.Equal(t, http.StatusAccepted, res.status)
36+
}
37+
38+
func TestResponseWriterWriteHeaderNow(t *testing.T) {
39+
w := httptest.NewRecorder()
40+
res := newResponse(w).(*response)
41+
42+
res.WriteHeader(http.StatusAccepted)
43+
res.WriteHeaderNow()
44+
45+
assert.True(t, res.Written())
46+
}
47+
48+
func TestResponseWriterSize(t *testing.T) {
49+
w := httptest.NewRecorder()
50+
res := newResponse(w)
51+
52+
n1, err := res.Write([]byte("Hello"))
53+
assert.NoError(t, err)
54+
55+
n2, err := res.Write([]byte("Bye"))
56+
assert.NoError(t, err)
57+
58+
assert.Equal(t, 8, n1+n2)
59+
assert.Equal(t, n1+n2, res.Size())
60+
}
61+
62+
func TestResponseWriterWritten(t *testing.T) {
63+
w := httptest.NewRecorder()
64+
res := newResponse(w)
65+
66+
assert.False(t, res.Written())
67+
68+
res.WriteHeaderNow()
69+
70+
assert.True(t, res.Written())
71+
}
72+
73+
func TestResponseWriterFlush(t *testing.T) {
74+
k := New()
75+
76+
k.GET("/", func(c *Context) error {
77+
c.Response().WriteHeader(http.StatusBadGateway)
78+
c.Response().Flush()
79+
return nil
80+
})
81+
82+
srv := httptest.NewServer(k)
83+
defer srv.Close()
84+
85+
resp, err := http.Get(srv.URL)
86+
assert.NoError(t, err)
87+
assert.Equal(t, http.StatusBadGateway, resp.StatusCode)
88+
}
89+
90+
func TestResponseWriterHijack(t *testing.T) {
91+
w := httptest.NewRecorder()
92+
res := newResponse(w)
93+
94+
assert.Panics(t, func() {
95+
_, _, _ = res.Hijack()
96+
})
97+
assert.True(t, res.Written())
98+
}

0 commit comments

Comments
 (0)