Skip to content

Commit f07ff4e

Browse files
committed
revproxy: add a caching HTTP reverse proxy
Add a new --revproxy flag to the serve command, that exports a minimal HTTP caching reverse proxy at the specified address. This proxy caches the results of successful GET requests for a set of hosts matched by --revproxy-targets.
1 parent 90ea5c1 commit f07ff4e

File tree

6 files changed

+590
-5
lines changed

6 files changed

+590
-5
lines changed

cmd/go-cache-plugin/commands.go

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/creachadair/taskgroup"
2626
"github.com/goproxy/goproxy"
2727
"github.com/tailscale/go-cache-plugin/internal/s3util"
28+
"github.com/tailscale/go-cache-plugin/revproxy"
2829
"github.com/tailscale/go-cache-plugin/s3cache"
2930
"github.com/tailscale/go-cache-plugin/s3proxy"
3031
"tailscale.com/tsweb"
@@ -120,6 +121,8 @@ func runDirect(env *command.Env) error {
120121
// serveFlags holds the command-line flag values for the serve command. Each
// field's `flag` tag gives the flag name, the environment variable supplying
// its default, and the help text.
var serveFlags struct {
	// Socket is the socket path; the help text marks it as required.
	Socket string `flag:"socket,default=$GOCACHE_SOCKET,Socket path (required)"`

	// ModProxy is the module proxy service address.
	ModProxy string `flag:"modproxy,default=$GOCACHE_MODPROXY,Module proxy service address ([host]:port)"`

	// RevProxy is the reverse proxy service address.
	RevProxy string `flag:"revproxy,default=$GOCACHE_REVPROXY,Reverse proxy service address ([host]:port)"`

	// Targets is the comma-separated list of reverse proxy targets.
	Targets string `flag:"revproxy-targets,default=$GOCACHE_REVPROXY_TARGETS,Reverse proxy targets (comma-separated)"`

	// SumDB is the comma-separated list of SumDB servers to proxy for.
	SumDB string `flag:"sumdb,default=$GOCACHE_SUMDB,SumDB servers to proxy for (comma-separated)"`
}
125128

@@ -149,14 +152,15 @@ func runServe(env *command.Env) error {
149152

150153
ctx, cancel := signal.NotifyContext(env.Context(), syscall.SIGINT, syscall.SIGTERM)
151154
defer cancel()
152-
go func() {
155+
156+
var g taskgroup.Group
157+
g.Go(taskgroup.NoError(func() {
153158
<-ctx.Done()
154159
log.Printf("signal received, closing listener")
155160
lst.Close()
156-
}()
161+
}))
157162

158163
// If a module proxy is enabled, start it.
159-
var g taskgroup.Group
160164
if serveFlags.ModProxy != "" {
161165
modCachePath := filepath.Join(flags.CacheDir, "module")
162166
if err := os.MkdirAll(modCachePath, 0700); err != nil {
@@ -202,11 +206,48 @@ func runServe(env *command.Env) error {
202206
}
203207
g.Go(srv.ListenAndServe)
204208
vprintf("started module proxy at %q", serveFlags.ModProxy)
205-
go func() {
209+
g.Go(taskgroup.NoError(func() {
206210
<-ctx.Done()
207211
vprintf("signal received, stopping module proxy")
208212
srv.Shutdown(context.Background())
209-
}()
213+
}))
214+
}
215+
216+
// If a reverse proxy is enabled, start it.
217+
if serveFlags.RevProxy != "" {
218+
if serveFlags.Targets == "" {
219+
return env.Usagef("must provide --revproxy-targets when --revproxy is set")
220+
}
221+
revCachePath := filepath.Join(flags.CacheDir, "revproxy")
222+
if err := os.MkdirAll(revCachePath, 0700); err != nil {
223+
lst.Close()
224+
return fmt.Errorf("create revproxy cache: %w", err)
225+
}
226+
proxy := &revproxy.Server{
227+
Targets: strings.Split(serveFlags.Targets, ","),
228+
Local: revCachePath,
229+
S3Client: s3c,
230+
KeyPrefix: path.Join(flags.KeyPrefix, "revproxy"),
231+
Logf: vprintf,
232+
}
233+
expvar.Publish("revcache", proxy.Metrics())
234+
235+
mux := http.NewServeMux()
236+
mux.Handle("/", proxy)
237+
if serveFlags.ModProxy == "" {
238+
tsweb.Debugger(mux) // attach debugger if --modproxy doesn't already have it
239+
}
240+
srv := &http.Server{
241+
Addr: serveFlags.RevProxy,
242+
Handler: mux,
243+
}
244+
g.Go(srv.ListenAndServe)
245+
vprintf("started reverse proxy at %q for %s", serveFlags.RevProxy, strings.Join(proxy.Targets, ", "))
246+
g.Go(taskgroup.NoError(func() {
247+
<-ctx.Done()
248+
vprintf("signal received, stopping reverse proxy")
249+
srv.Shutdown(context.Background())
250+
}))
210251
}
211252

212253
for {

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ require (
4040
github.com/beorn7/perks v1.0.1 // indirect
4141
github.com/cespare/xxhash/v2 v2.2.0 // indirect
4242
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 // indirect
43+
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
4344
github.com/google/uuid v1.6.0 // indirect
4445
github.com/prometheus/client_golang v1.19.1 // indirect
4546
github.com/prometheus/client_model v0.5.0 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk
6060
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
6161
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f35nQbASLnvxEde4XOBL+Sn7rFuV+FOJqkljg=
6262
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA=
63+
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
64+
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
6365
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
6466
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
6567
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=

revproxy/cache.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package revproxy
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"errors"
7+
"fmt"
8+
"io"
9+
"io/fs"
10+
"net/http"
11+
"os"
12+
"path/filepath"
13+
"strings"
14+
"time"
15+
16+
"github.com/creachadair/atomicfile"
17+
"github.com/creachadair/taskgroup"
18+
)
19+
20+
// cacheLoadLocal reads cached headers and body from the local cache.
21+
func (s *Server) cacheLoadLocal(hash string) ([]byte, http.Header, error) {
22+
data, err := os.ReadFile(s.makePath(hash))
23+
if err != nil {
24+
return nil, nil, err
25+
}
26+
return parseCacheObject(data)
27+
}
28+
29+
// cacheStoreLocal writes the contents of body to the local cache.
30+
//
31+
// The file format is a plain-text section at the top recording a subset of the
32+
// response headers, followed by "\n\n", followed by the response body.
33+
func (s *Server) cacheStoreLocal(hash string, hdr http.Header, body []byte) error {
34+
path := s.makePath(hash)
35+
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
36+
return err
37+
}
38+
return atomicfile.Tx(s.makePath(hash), 0600, func(f *atomicfile.File) error {
39+
return writeCacheObject(f, hdr, body)
40+
})
41+
}
42+
43+
// cacheLoadS3 reads cached headers and body from the remote S3 cache.
44+
func (s *Server) cacheLoadS3(ctx context.Context, hash string) ([]byte, http.Header, error) {
45+
data, err := s.S3Client.GetData(ctx, s.makeKey(hash))
46+
if err != nil {
47+
return nil, nil, err
48+
}
49+
return parseCacheObject(data)
50+
}
51+
52+
// cacheStoreS3 returns a task that writes the contents of body to the remote
53+
// S3 cache.
54+
func (s *Server) cacheStoreS3(hash string, hdr http.Header, body []byte) taskgroup.Task {
55+
var buf bytes.Buffer
56+
writeCacheObject(&buf, hdr, body)
57+
nb := buf.Len()
58+
return func() error {
59+
sctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
60+
defer cancel()
61+
62+
if err := s.S3Client.Put(sctx, s.makeKey(hash), &buf); err != nil {
63+
s.logf("[s3] put %q failed: %v", hash, err)
64+
s.rspPushError.Add(1)
65+
} else {
66+
s.rspPush.Add(1)
67+
s.rspPushBytes.Add(int64(nb))
68+
}
69+
return nil
70+
}
71+
}
72+
73+
// cacheLoadMemory reads cached headers and body from the memory cache.
74+
func (s *Server) cacheLoadMemory(hash string) ([]byte, http.Header, error) {
75+
s.mcacheMu.Lock()
76+
defer s.mcacheMu.Unlock()
77+
v, ok := s.mcache.Get(hash)
78+
if !ok {
79+
return nil, nil, fs.ErrNotExist
80+
}
81+
entry := v.(memCacheEntry)
82+
if time.Now().After(entry.expires) {
83+
s.mcache.Remove(hash)
84+
return nil, nil, errors.New("entry expired")
85+
}
86+
return entry.body, entry.header, nil
87+
}
88+
89+
// cacheStoreMemory writes the contents of body to the memory cache.
90+
func (s *Server) cacheStoreMemory(hash string, maxAge time.Duration, hdr http.Header, body []byte) {
91+
s.mcacheMu.Lock()
92+
defer s.mcacheMu.Unlock()
93+
s.mcache.Add(hash, memCacheEntry{
94+
header: hdr,
95+
body: body,
96+
expires: time.Now().Add(maxAge),
97+
})
98+
}
99+
100+
// parseCacheDbject parses cached object data to extract the body and headers.
101+
func parseCacheObject(data []byte) ([]byte, http.Header, error) {
102+
hdr, rest, ok := bytes.Cut(data, []byte("\n\n"))
103+
if !ok {
104+
return nil, nil, errors.New("invalid cache object: missing header")
105+
}
106+
h := make(http.Header)
107+
for _, line := range strings.Split(string(hdr), "\n") {
108+
name, value, ok := strings.Cut(line, ": ")
109+
if ok {
110+
h.Add(name, value)
111+
}
112+
}
113+
return rest, h, nil
114+
}
115+
116+
// writeCacheObject writes the specified response data into a cache object at w.
//
// The object consists of a plain-text header section, a blank line, and then
// body verbatim. The header section records Content-Type (falling back to
// "application/octet-stream" when h has none), plus Date and Etag when h has
// them.
//
// The header section is assembled in memory and written to w in a single call,
// so a failure to write it is reported to the caller rather than being lost.
// The first write error encountered is returned.
func writeCacheObject(w io.Writer, h http.Header, body []byte) error {
	var hdr bytes.Buffer
	hprintf(&hdr, h, "Content-Type", "application/octet-stream")
	hprintf(&hdr, h, "Date", "")
	hprintf(&hdr, h, "Etag", "")
	hdr.WriteByte('\n') // blank line separating headers from body

	if _, err := w.Write(hdr.Bytes()); err != nil {
		return err
	}
	_, err := w.Write(body)
	return err
}

// hprintf writes a "name: value" header line to w, using the value of name
// from h if it is non-empty, or else fallback if that is non-empty. If both
// are empty, nothing is written. Write errors are not reported.
func hprintf(w io.Writer, h http.Header, name, fallback string) {
	if v := h.Get(name); v != "" {
		fmt.Fprintf(w, "%s: %s\n", name, v)
	} else if fallback != "" {
		fmt.Fprintf(w, "%s: %s\n", name, fallback)
	}
}
133+
134+
// setXCacheInfo adds cache-specific headers to h.
//
// X-Cache is always set to result. If hash is non-empty, X-Cache-Id is set to
// the first 12 bytes of hash, or to all of hash if it is shorter than that.
func setXCacheInfo(h http.Header, result, hash string) {
	const idLen = 12 // maximum length of the abbreviated cache ID

	h.Set("X-Cache", result)
	if hash == "" {
		return
	}
	// Guard the slice: a short hash must not panic the request handler.
	id := hash
	if len(id) > idLen {
		id = id[:idLen]
	}
	h.Set("X-Cache-Id", id)
}
141+
142+
// memCacheEntry is the value type stored in the memory cache: one cached
// response together with the time after which it is no longer served.
type memCacheEntry struct {
	// header holds the response headers recorded with the entry.
	header http.Header

	// body holds the response body.
	body []byte

	// expires is the instant after which lookups treat the entry as expired.
	expires time.Time
}

revproxy/internal_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package revproxy
2+
3+
import (
4+
"net/url"
5+
"testing"
6+
)
7+
8+
func TestCheckTarget(t *testing.T) {
9+
s := &Server{
10+
Targets: []string{"foo.com", "*.bar.com"},
11+
}
12+
tests := []struct {
13+
input string
14+
want bool
15+
}{
16+
{"", false},
17+
{"nonesuch.org", false},
18+
{"foo.com", true},
19+
{"other.foo.com", false},
20+
{"bar.com", true},
21+
{"other.bar.com", true},
22+
{"some.other.bar.com", true},
23+
}
24+
for _, tc := range tests {
25+
u := &url.URL{Host: "localhost", Path: tc.input}
26+
if got := s.checkTarget(u); got != tc.want {
27+
t.Errorf("Check %q: got %v, want %v", tc.input, got, tc.want)
28+
}
29+
}
30+
}

0 commit comments

Comments
 (0)