Skip to content

Commit a8eb9e9

Browse files
author
Michael Weber
committed
WIP - TrackFilters
TrackFilters allow modifying what is stored in tracks. An example use case is filtering Authorization headers and cookies.
1 parent 6b478b0 commit a8eb9e9

File tree

4 files changed

+172
-1
lines changed

4 files changed

+172
-1
lines changed

cassette.go

+132
Original file line numberDiff line numberDiff line change
@@ -511,3 +511,135 @@ func decompress(data []byte) ([]byte, error) {
511511

512512
return data, nil
513513
}
514+
515+
// Track Filters
516+
517+
// TrackFilter allows modifying information that is being stored in a track
518+
//
519+
// The filters receive a deep copy of the request and response, that can be modified.
520+
type TrackFilter func(*http.Request, *http.Response, error) (*http.Request, *http.Response)
521+
522+
// TrackFilters is a slice of TrackFilter
523+
type TrackFilters []TrackFilter
524+
525+
// TrackRequestDeleteHeaderKeys will delete one or more header keys on the request
526+
// before the request is matched against the cassette.
527+
func TrackRequestDeleteHeaderKeys(keys ...string) TrackFilter {
528+
return func(req *http.Request, resp *http.Response, err error) (*http.Request, *http.Response) {
529+
for _, key := range keys {
530+
req.Header.Del(key)
531+
}
532+
return req, resp
533+
}
534+
}
535+
536+
// TrackResponseDeleteHeaderKeys will delete one or more header keys on the request
537+
// before the request is matched against the cassette.
538+
func TrackResponseDeleteHeaderKeys(keys ...string) TrackFilter {
539+
return func(req *http.Request, resp *http.Response, err error) (*http.Request, *http.Response) {
540+
if resp == nil || resp.Header == nil {
541+
return req, resp
542+
}
543+
for _, key := range keys {
544+
resp.Header.Del(key)
545+
}
546+
return req, resp
547+
}
548+
}
549+
550+
// TrackRequestChangeBody will allows to change the body.
551+
// Supply a function that does input to output transformation.
552+
func TrackRequestChangeBody(fn func(b []byte) []byte) TrackFilter {
553+
return func(req *http.Request, resp *http.Response, err error) (*http.Request, *http.Response) {
554+
body, err2 := readRequestBody(req) // XXX does too much, don't recreate body
555+
if err2 != nil {
556+
return req, resp
557+
}
558+
body = fn(body)
559+
req.Body = toReadCloser(body)
560+
req.ContentLength = int64(len(body))
561+
return req, resp
562+
}
563+
}
564+
565+
// TrackResponseChangeBody will allows to change the body.
566+
// Supply a function that does input to output transformation.
567+
func TrackResponseChangeBody(fn func(b []byte) []byte) TrackFilter {
568+
return func(req *http.Request, resp *http.Response, err error) (*http.Request, *http.Response) {
569+
if resp == nil || resp.Header == nil {
570+
return req, resp
571+
}
572+
body, err2 := readResponseBody(resp) // XXX does too much, don't recreate body
573+
if err2 != nil {
574+
return req, resp
575+
}
576+
body = fn(body)
577+
resp.Body = toReadCloser(body)
578+
resp.ContentLength = int64(len(body))
579+
return req, resp
580+
}
581+
}
582+
583+
// OnMethod will return a new filter that will only apply 'r'
584+
// if the method of the request matches.
585+
// Original filter is unmodified.
586+
func (r TrackFilter) OnMethod(method string) TrackFilter {
587+
return func(req *http.Request, resp *http.Response, err error) (*http.Request, *http.Response) {
588+
if req.Method != method {
589+
return req, resp
590+
}
591+
return r(req, resp, err)
592+
}
593+
}
594+
595+
// OnPath will return a track filter that will only apply 'r'
596+
// if the url string of the request matches the supplied regex.
597+
// Original filter is unmodified.
598+
func (r TrackFilter) OnPath(pathRegEx string) TrackFilter {
599+
if pathRegEx == "" {
600+
pathRegEx = ".*"
601+
}
602+
re := regexp.MustCompile(pathRegEx)
603+
return func(req *http.Request, resp *http.Response, err error) (*http.Request, *http.Response) {
604+
if !re.MatchString(req.URL.String()) {
605+
return req, resp
606+
}
607+
return r(req, resp, err)
608+
}
609+
}
610+
611+
// OnStatus will return a Track filter that will only apply 'r' if the response status matches.
612+
// Original filter is unmodified.
613+
func (r TrackFilter) OnStatus(status int) TrackFilter {
614+
return func(req *http.Request, resp *http.Response, err error) (*http.Request, *http.Response) {
615+
if resp == nil || resp.StatusCode != status {
616+
return req, resp
617+
}
618+
return r(req, resp, err)
619+
}
620+
}
621+
622+
// Add one or more filters at the end of the filter chain.
623+
func (r *TrackFilters) Add(filters ...TrackFilter) {
624+
v := *r
625+
v = append(v, filters...)
626+
*r = v
627+
}
628+
629+
// Prepend one or more filters before the current ones.
630+
func (r *TrackFilters) Prepend(filters ...TrackFilter) {
631+
src := *r
632+
dst := make(TrackFilters, 0, len(filters)+len(src))
633+
dst = append(dst, filters...)
634+
*r = append(dst, src...)
635+
}
636+
637+
// combined returns the filters as a single filter.
638+
func (r TrackFilters) combined() TrackFilter {
639+
return func(req *http.Request, resp *http.Response, err error) (*http.Request, *http.Response) {
640+
for _, filter := range r {
641+
req, resp = filter(req, resp, err)
642+
}
643+
return req, resp
644+
}
645+
}

govcr.go

+31
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ type VCRConfig struct {
3535
// Filter to run before a response is returned.
3636
ResponseFilters ResponseFilters
3737

38+
// Filter to run before storing a request/response pair in a track
39+
TrackFilters TrackFilters
40+
3841
// LongPlay will compress data on cassettes.
3942
LongPlay bool
4043
DisableRecording bool
@@ -85,6 +88,7 @@ func NewVCR(cassetteName string, vcrConfig *VCRConfig) *VCRControlPanel {
8588
Transport: vcrConfig.Client.Transport,
8689
RequestFilter: vcrConfig.RequestFilters.combined(),
8790
ResponseFilter: vcrConfig.ResponseFilters.combined(),
91+
TrackFilter: vcrConfig.TrackFilters.combined(),
8892
Logger: logger,
8993
CassettePath: vcrConfig.CassettePath,
9094
}
@@ -242,6 +246,33 @@ func readRequestBody(req *http.Request) ([]byte, error) {
242246
return bodyData, nil
243247
}
244248

249+
// copyResponse makes a copy an HTTP response.
250+
// It ensures that the original response Body stream is restored to its original state
251+
// and can be read from again.
252+
// TODO: should perform a deep copy of the TLS property as with URL
253+
func copyResponse(resp *http.Response) (*http.Response, error) {
254+
if resp == nil {
255+
return nil, nil
256+
}
257+
258+
// get a shallow copy
259+
copiedResp := *resp
260+
261+
copiedResp.Header = cloneHeader(resp.Header)
262+
263+
// deal with the Body
264+
bodyCopy, err := readResponseBody(resp)
265+
if err != nil {
266+
return nil, err
267+
}
268+
269+
// restore Body stream state
270+
resp.Body = toReadCloser(bodyCopy)
271+
copiedResp.Body = toReadCloser(bodyCopy)
272+
273+
return &copiedResp, nil
274+
}
275+
245276
// readResponseBody reads the Body data stream and restores its states.
246277
// It ensures the stream is restored to its original state and can be read from again.
247278
func readResponseBody(resp *http.Response) ([]byte, error) {

pcb.go

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ type pcb struct {
1212
Transport http.RoundTripper
1313
RequestFilter RequestFilter
1414
ResponseFilter ResponseFilter
15+
TrackFilter TrackFilter
1516
Logger *log.Logger
1617
DisableRecording bool
1718
CassettePath string

vcr_transport.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ func (t *vcrTransport) RoundTrip(req *http.Request) (*http.Response, error) {
2424
t.PCB.Logger.Println(err)
2525
return nil, err
2626
}
27+
copiedReq, _ = t.PCB.TrackFilter(copiedReq, nil, nil)
2728

2829
// attempt to use a track from the cassette that matches
2930
// the request if one exists.
@@ -40,8 +41,14 @@ func (t *vcrTransport) RoundTrip(req *http.Request) (*http.Response, error) {
4041

4142
if !t.PCB.DisableRecording {
4243
// the VCR is not in read-only mode so
44+
copiedResp, errResp := copyResponse(resp)
45+
if errResp != nil {
46+
t.PCB.Logger.Println(errResp)
47+
return nil, errResp
48+
}
49+
copiedReq, copiedResp = t.PCB.TrackFilter(copiedReq, copiedResp, err)
4350
t.PCB.Logger.Printf("INFO - Cassette '%s' - Recording new track for %s %s as %s %s\n", t.Cassette.Name, req.Method, req.URL.String(), copiedReq.Method, copiedReq.URL)
44-
if err := recordNewTrackToCassette(t.Cassette, copiedReq, resp, err); err != nil {
51+
if err := recordNewTrackToCassette(t.Cassette, copiedReq, copiedResp, err); err != nil {
4552
t.PCB.Logger.Println(err)
4653
}
4754
}

0 commit comments

Comments
 (0)