Skip to content
This repository was archived by the owner on May 21, 2025. It is now read-only.

Commit 825e2ce

Browse files
authored
Merge pull request #109 from aubelsb2/master
Automatic/passive switching between V1 and V2 API Gateway API requests and responses based on JSON marshal/unmarshal interfaces for Gorillia, & Cookies fixed for v2 & Tests fixed for V2 (wrong structure)
2 parents f6f827b + 9e75c85 commit 825e2ce

11 files changed

+440
-34
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
*.out
1313

1414
.vscode
15+
.idea
1516
vendor/*/
1617
gin/aws-lambda-go-api-proxy-gin
1718
core/aws-lambda-go-api-proxy-core

core/requestv2.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ func (r *RequestAccessorV2) EventToRequest(req events.APIGatewayV2HTTPRequest) (
164164
return nil, err
165165
}
166166

167+
for _, cookie := range req.Cookies {
168+
httpRequest.Header.Add("Cookie", cookie)
169+
}
170+
167171
for headerKey, headerValue := range req.Headers {
168172
for _, val := range strings.Split(headerValue, ",") {
169173
httpRequest.Header.Add(headerKey, strings.Trim(val, " "))

core/requestv2_test.go

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package core_test
33
import (
44
"context"
55
"encoding/base64"
6+
"github.com/onsi/gomega/gstruct"
67
"io/ioutil"
78
"math/rand"
89
"os"
@@ -247,10 +248,10 @@ var _ = Describe("RequestAccessorV2 tests", func() {
247248
})
248249

249250
It("Populates stage variables correctly", func() {
250-
varsRequest := getProxyRequest("orders", "GET")
251+
varsRequest := getProxyRequestV2("orders", "GET")
251252
varsRequest.StageVariables = getStageVariables()
252253

253-
accessor := core.RequestAccessor{}
254+
accessor := core.RequestAccessorV2{}
254255
httpReq, err := accessor.ProxyEventToHTTPRequest(varsRequest)
255256
Expect(err).To(BeNil())
256257

@@ -262,7 +263,7 @@ var _ = Describe("RequestAccessorV2 tests", func() {
262263
Expect("value1").To(Equal(stageVars["var1"]))
263264
Expect("value2").To(Equal(stageVars["var2"]))
264265

265-
stageVars, ok := core.GetStageVarsFromContext(httpReq.Context())
266+
stageVars, ok := core.GetStageVarsFromContextV2(httpReq.Context())
266267
// not present in context
267268
Expect(ok).To(BeFalse())
268269

@@ -273,7 +274,7 @@ var _ = Describe("RequestAccessorV2 tests", func() {
273274
// should not be in headers
274275
Expect(err).ToNot(BeNil())
275276

276-
stageVars, ok = core.GetStageVarsFromContext(httpReq.Context())
277+
stageVars, ok = core.GetStageVarsFromContextV2(httpReq.Context())
277278
Expect(ok).To(BeTrue())
278279
Expect(2).To(Equal(len(stageVars)))
279280
Expect(stageVars["var1"]).ToNot(BeNil())
@@ -284,9 +285,9 @@ var _ = Describe("RequestAccessorV2 tests", func() {
284285

285286
It("Populates the default hostname correctly", func() {
286287

287-
basicRequest := getProxyRequest("orders", "GET")
288-
basicRequest.RequestContext = getRequestContext()
289-
accessor := core.RequestAccessor{}
288+
basicRequest := getProxyRequestV2("orders", "GET")
289+
basicRequest.RequestContext = getRequestContextV2()
290+
accessor := core.RequestAccessorV2{}
290291
httpReq, err := accessor.ProxyEventToHTTPRequest(basicRequest)
291292
Expect(err).To(BeNil())
292293

@@ -297,8 +298,8 @@ var _ = Describe("RequestAccessorV2 tests", func() {
297298
It("Uses a custom hostname", func() {
298299
myCustomHost := "http://my-custom-host.com"
299300
os.Setenv(core.CustomHostVariable, myCustomHost)
300-
basicRequest := getProxyRequest("orders", "GET")
301-
accessor := core.RequestAccessor{}
301+
basicRequest := getProxyRequestV2("orders", "GET")
302+
accessor := core.RequestAccessorV2{}
302303
httpReq, err := accessor.EventToRequestWithContext(context.Background(), basicRequest)
303304
Expect(err).To(BeNil())
304305

@@ -310,15 +311,28 @@ var _ = Describe("RequestAccessorV2 tests", func() {
310311
It("Strips terminating / from hostname", func() {
311312
myCustomHost := "http://my-custom-host.com"
312313
os.Setenv(core.CustomHostVariable, myCustomHost+"/")
313-
basicRequest := getProxyRequest("orders", "GET")
314-
accessor := core.RequestAccessor{}
314+
basicRequest := getProxyRequestV2("orders", "GET")
315+
accessor := core.RequestAccessorV2{}
315316
httpReq, err := accessor.EventToRequestWithContext(context.Background(), basicRequest)
316317
Expect(err).To(BeNil())
317318

318319
Expect(myCustomHost).To(Equal("http://" + httpReq.Host))
319320
Expect(myCustomHost).To(Equal("http://" + httpReq.URL.Host))
320321
os.Unsetenv(core.CustomHostVariable)
321322
})
323+
324+
It("handles cookies okay", func() {
325+
basicRequest := getProxyRequestV2("orders", "GET")
326+
basicRequest.Cookies = []string{
327+
"TestCookie=123",
328+
}
329+
accessor := core.RequestAccessorV2{}
330+
httpReq, err := accessor.EventToRequestWithContext(context.Background(), basicRequest)
331+
Expect(err).To(BeNil())
332+
Expect(httpReq.Cookie("TestCookie")).To(gstruct.PointTo(gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{
333+
"Value": Equal("123"),
334+
})))
335+
})
322336
})
323337
})
324338

core/responsev2.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,13 @@ func (r *ProxyResponseWriterV2) GetProxyResponse() (events.APIGatewayV2HTTPRespo
102102
}
103103

104104
headers := make(map[string]string)
105+
cookies := make([]string, 0)
105106

106107
for headerKey, headerValue := range http.Header(r.headers) {
108+
if strings.EqualFold("set-cookie", headerKey) {
109+
cookies = append(cookies, headerValue...)
110+
continue
111+
}
107112
headers[headerKey] = strings.Join(headerValue, ",")
108113
}
109114

@@ -112,5 +117,6 @@ func (r *ProxyResponseWriterV2) GetProxyResponse() (events.APIGatewayV2HTTPRespo
112117
Headers: headers,
113118
Body: output,
114119
IsBase64Encoded: isBase64,
120+
Cookies: cookies,
115121
}, nil
116122
}

core/responsev2_test.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,27 @@ var _ = Describe("ResponseWriterV2 tests", func() {
158158
})
159159

160160
It("Writes multi-value headers correctly", func() {
161+
response := NewProxyResponseWriterV2()
162+
response.Header().Add("Accepts", "foobar")
163+
response.Header().Add("Accepts", "barfoo")
164+
response.Write([]byte("hello"))
165+
proxyResponse, err := response.GetProxyResponse()
166+
Expect(err).To(BeNil())
167+
168+
Expect(2).To(Equal(len(proxyResponse.Headers)))
169+
Expect("foobar,barfoo").To(Equal(proxyResponse.Headers["Accepts"]))
170+
})
171+
172+
It("Writes cookies correctly", func() {
161173
response := NewProxyResponseWriterV2()
162174
response.Header().Add("Set-Cookie", "csrftoken=foobar")
163175
response.Header().Add("Set-Cookie", "session_id=barfoo")
164176
response.Write([]byte("hello"))
165177
proxyResponse, err := response.GetProxyResponse()
166178
Expect(err).To(BeNil())
167179

168-
Expect(2).To(Equal(len(proxyResponse.Headers)))
169-
Expect("csrftoken=foobar,session_id=barfoo").To(Equal(proxyResponse.Headers["Set-Cookie"]))
180+
Expect(2).To(Equal(len(proxyResponse.Cookies)))
181+
Expect(strings.Split("csrftoken=foobar,session_id=barfoo", ",")).To(Equal(proxyResponse.Cookies))
170182
})
171183
})
172184

core/switchablerequest.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package core
2+
3+
import (
4+
"encoding/json"
5+
"errors"
6+
"github.com/aws/aws-lambda-go/events"
7+
)
8+
9+
type SwitchableAPIGatewayRequest struct {
10+
v interface{} // v is Always nil, or a pointer of APIGatewayProxyRequest or APIGatewayV2HTTPRequest
11+
}
12+
13+
// NewSwitchableAPIGatewayRequestV1 creates a new SwitchableAPIGatewayRequest from APIGatewayProxyRequest
14+
func NewSwitchableAPIGatewayRequestV1(v *events.APIGatewayProxyRequest) *SwitchableAPIGatewayRequest {
15+
return &SwitchableAPIGatewayRequest{
16+
v: v,
17+
}
18+
}
19+
// NewSwitchableAPIGatewayRequestV2 creates a new SwitchableAPIGatewayRequest from APIGatewayV2HTTPRequest
20+
func NewSwitchableAPIGatewayRequestV2(v *events.APIGatewayV2HTTPRequest) *SwitchableAPIGatewayRequest {
21+
return &SwitchableAPIGatewayRequest{
22+
v: v,
23+
}
24+
}
25+
26+
// MarshalJSON is a pass through serialization
27+
func (s *SwitchableAPIGatewayRequest) MarshalJSON() ([]byte, error) {
28+
return json.Marshal(s.v)
29+
}
30+
31+
// UnmarshalJSON is a switching serialization based on the presence of fields in the
32+
// source JSON, multiValueQueryStringParameters for APIGatewayProxyRequest and rawQueryString for
33+
// APIGatewayV2HTTPRequest.
34+
func (s *SwitchableAPIGatewayRequest) UnmarshalJSON(b []byte) error {
35+
delta := map[string]json.RawMessage{}
36+
if err := json.Unmarshal(b, &delta); err != nil {
37+
return err
38+
}
39+
_, v1test := delta["multiValueQueryStringParameters"]
40+
_, v2test := delta["rawQueryString"]
41+
s.v = nil
42+
if v1test && !v2test {
43+
s.v = &events.APIGatewayProxyRequest{}
44+
} else if !v1test && v2test {
45+
s.v = &events.APIGatewayV2HTTPRequest{}
46+
} else {
47+
return errors.New("unable to determine request version")
48+
}
49+
return json.Unmarshal(b, s.v)
50+
}
51+
52+
// Version1 returns the contained events.APIGatewayProxyRequest or nil
53+
func (s *SwitchableAPIGatewayRequest) Version1() *events.APIGatewayProxyRequest {
54+
switch v := s.v.(type) {
55+
case *events.APIGatewayProxyRequest:
56+
return v
57+
case events.APIGatewayProxyRequest:
58+
return &v
59+
}
60+
return nil
61+
}
62+
63+
// Version2 returns the contained events.APIGatewayV2HTTPRequest or nil
64+
func (s *SwitchableAPIGatewayRequest) Version2() *events.APIGatewayV2HTTPRequest {
65+
switch v := s.v.(type) {
66+
case *events.APIGatewayV2HTTPRequest:
67+
return v
68+
case events.APIGatewayV2HTTPRequest:
69+
return &v
70+
}
71+
return nil
72+
}

core/switchablerequest_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package core
2+
3+
import (
4+
"encoding/json"
5+
"github.com/aws/aws-lambda-go/events"
6+
. "github.com/onsi/ginkgo"
7+
. "github.com/onsi/gomega"
8+
)
9+
10+
var _ = Describe("SwitchableAPIGatewayRequest", func() {
11+
Context("Serialization", func() {
12+
It("v1 serialized okay", func() {
13+
e := NewSwitchableAPIGatewayRequestV1(&events.APIGatewayProxyRequest{
14+
MultiValueQueryStringParameters: map[string][]string{},
15+
})
16+
b, err := json.Marshal(e)
17+
Expect(err).To(BeNil())
18+
m := map[string]interface{}{}
19+
err = json.Unmarshal(b, &m)
20+
Expect(err).To(BeNil())
21+
Expect(m["multiValueQueryStringParameters"]).To(Equal(map[string]interface {}{}))
22+
Expect(m["body"]).To(Equal(""))
23+
})
24+
It("v2 serialized okay", func() {
25+
e := NewSwitchableAPIGatewayRequestV2(&events.APIGatewayV2HTTPRequest{})
26+
b, err := json.Marshal(e)
27+
Expect(err).To(BeNil())
28+
m := map[string]interface{}{}
29+
err = json.Unmarshal(b, &m)
30+
Expect(err).To(BeNil())
31+
Expect(m["rawQueryString"]).To(Equal(""))
32+
Expect(m["isBase64Encoded"]).To(Equal(false))
33+
})
34+
})
35+
Context("Deserialization", func() {
36+
It("v1 deserialized okay", func() {
37+
input := &events.APIGatewayProxyRequest{
38+
Body: "234",
39+
MultiValueQueryStringParameters: map[string][]string{
40+
"Test": []string{ "Value1", "Value2", },
41+
},
42+
}
43+
b, _ := json.Marshal(input)
44+
s := SwitchableAPIGatewayRequest{}
45+
err := s.UnmarshalJSON(b)
46+
Expect(err).To(BeNil())
47+
Expect(s.Version2()).To(BeNil())
48+
Expect(s.Version1()).To(BeEquivalentTo(input))
49+
})
50+
It("v2 deserialized okay", func() {
51+
input := &events.APIGatewayV2HTTPRequest{
52+
IsBase64Encoded: true,
53+
RawQueryString: "a=b&c=d",
54+
}
55+
b, _ := json.Marshal(input)
56+
s := SwitchableAPIGatewayRequest{}
57+
err := s.UnmarshalJSON(b)
58+
Expect(err).To(BeNil())
59+
Expect(s.Version1()).To(BeNil())
60+
Expect(s.Version2()).To(BeEquivalentTo(input))
61+
})
62+
})})
63+
64+
func getProxyRequestV2(path string, method string) events.APIGatewayV2HTTPRequest {
65+
return events.APIGatewayV2HTTPRequest{
66+
RequestContext: events.APIGatewayV2HTTPRequestContext{
67+
HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{
68+
Path: path,
69+
Method: method,
70+
},
71+
},
72+
RawPath: path,
73+
}
74+
}

core/switchableresponse.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package core
2+
3+
import (
4+
"encoding/json"
5+
"errors"
6+
"github.com/aws/aws-lambda-go/events"
7+
)
8+
9+
// SwitchableAPIGatewayResponse is a container for an APIGatewayProxyResponse or an APIGatewayV2HTTPResponse object which
10+
// handles serialization and deserialization and switching between the entities based on the presence of fields in the
11+
// source JSON, multiValueQueryStringParameters for APIGatewayProxyResponse and rawQueryString for
12+
// APIGatewayV2HTTPResponse. It also provides some simple switching functions (wrapped type switching.)
13+
type SwitchableAPIGatewayResponse struct {
14+
v interface{}
15+
}
16+
17+
// NewSwitchableAPIGatewayResponseV1 creates a new SwitchableAPIGatewayResponse from APIGatewayProxyResponse
18+
func NewSwitchableAPIGatewayResponseV1(v *events.APIGatewayProxyResponse) *SwitchableAPIGatewayResponse {
19+
return &SwitchableAPIGatewayResponse{
20+
v: v,
21+
}
22+
}
23+
24+
// NewSwitchableAPIGatewayResponseV2 creates a new SwitchableAPIGatewayResponse from APIGatewayV2HTTPResponse
25+
func NewSwitchableAPIGatewayResponseV2(v *events.APIGatewayV2HTTPResponse) *SwitchableAPIGatewayResponse {
26+
return &SwitchableAPIGatewayResponse{
27+
v: v,
28+
}
29+
}
30+
31+
// MarshalJSON is a pass through serialization
32+
func (s *SwitchableAPIGatewayResponse) MarshalJSON() ([]byte, error) {
33+
return json.Marshal(s.v)
34+
}
35+
36+
// UnmarshalJSON is a switching serialization based on the presence of fields in the
37+
// source JSON, statusCode to verify that it's either APIGatewayProxyResponse or APIGatewayV2HTTPResponse and then
38+
// rawQueryString for to determine if it is APIGatewayV2HTTPResponse or not.
39+
func (s *SwitchableAPIGatewayResponse) UnmarshalJSON(b []byte) error {
40+
delta := map[string]json.RawMessage{}
41+
if err := json.Unmarshal(b, &delta); err != nil {
42+
return err
43+
}
44+
_, test := delta["statusCode"]
45+
_, v2test := delta["cookies"]
46+
s.v = nil
47+
if test && !v2test {
48+
s.v = &events.APIGatewayProxyResponse{}
49+
} else if test && v2test {
50+
s.v = &events.APIGatewayV2HTTPResponse{}
51+
} else {
52+
return errors.New("unable to determine response version")
53+
}
54+
return json.Unmarshal(b, s.v)
55+
}
56+
57+
// Version1 returns the contained events.APIGatewayProxyResponse or nil
58+
func (s *SwitchableAPIGatewayResponse) Version1() *events.APIGatewayProxyResponse {
59+
switch v := s.v.(type) {
60+
case *events.APIGatewayProxyResponse:
61+
return v
62+
case events.APIGatewayProxyResponse:
63+
return &v
64+
}
65+
return nil
66+
}
67+
68+
// Version2 returns the contained events.APIGatewayV2HTTPResponse or nil
69+
func (s *SwitchableAPIGatewayResponse) Version2() *events.APIGatewayV2HTTPResponse {
70+
switch v := s.v.(type) {
71+
case *events.APIGatewayV2HTTPResponse:
72+
return v
73+
case events.APIGatewayV2HTTPResponse:
74+
return &v
75+
}
76+
return nil
77+
}

0 commit comments

Comments
 (0)