Skip to content

revproxy: add a caching HTTP reverse proxy #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions cmd/go-cache-plugin/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/creachadair/taskgroup"
"github.com/goproxy/goproxy"
"github.com/tailscale/go-cache-plugin/internal/s3util"
"github.com/tailscale/go-cache-plugin/revproxy"
"github.com/tailscale/go-cache-plugin/s3cache"
"github.com/tailscale/go-cache-plugin/s3proxy"
"tailscale.com/tsweb"
Expand Down Expand Up @@ -120,6 +121,8 @@ func runDirect(env *command.Env) error {
var serveFlags struct {
Socket string `flag:"socket,default=$GOCACHE_SOCKET,Socket path (required)"`
ModProxy string `flag:"modproxy,default=$GOCACHE_MODPROXY,Module proxy service address ([host]:port)"`
RevProxy string `flag:"revproxy,default=$GOCACHE_REVPROXY,Reverse proxy service address ([host]:port)"`
Targets string `flag:"revproxy-targets,default=$GOCACHE_REVPROXY_TARGETS,Reverse proxy targets (comma-separated)"`
SumDB string `flag:"sumdb,default=$GOCACHE_SUMDB,SumDB servers to proxy for (comma-separated)"`
}

Expand Down Expand Up @@ -149,14 +152,15 @@ func runServe(env *command.Env) error {

ctx, cancel := signal.NotifyContext(env.Context(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
go func() {

var g taskgroup.Group
g.Go(taskgroup.NoError(func() {
<-ctx.Done()
log.Printf("signal received, closing listener")
lst.Close()
}()
}))

// If a module proxy is enabled, start it.
var g taskgroup.Group
if serveFlags.ModProxy != "" {
modCachePath := filepath.Join(flags.CacheDir, "module")
if err := os.MkdirAll(modCachePath, 0700); err != nil {
Expand Down Expand Up @@ -202,11 +206,48 @@ func runServe(env *command.Env) error {
}
g.Go(srv.ListenAndServe)
vprintf("started module proxy at %q", serveFlags.ModProxy)
go func() {
g.Go(taskgroup.NoError(func() {
<-ctx.Done()
vprintf("signal received, stopping module proxy")
srv.Shutdown(context.Background())
}()
}))
}

// If a reverse proxy is enabled, start it.
if serveFlags.RevProxy != "" {
if serveFlags.Targets == "" {
return env.Usagef("must provide --revproxy-targets when --revproxy is set")
}
revCachePath := filepath.Join(flags.CacheDir, "revproxy")
if err := os.MkdirAll(revCachePath, 0700); err != nil {
lst.Close()
return fmt.Errorf("create revproxy cache: %w", err)
}
proxy := &revproxy.Server{
Targets: strings.Split(serveFlags.Targets, ","),
Local: revCachePath,
S3Client: s3c,
KeyPrefix: path.Join(flags.KeyPrefix, "revproxy"),
Logf: vprintf,
}
expvar.Publish("revcache", proxy.Metrics())

mux := http.NewServeMux()
mux.Handle("/", proxy)
if serveFlags.ModProxy == "" {
tsweb.Debugger(mux) // attach debugger if --modproxy doesn't already have it
}
srv := &http.Server{
Addr: serveFlags.RevProxy,
Handler: mux,
}
g.Go(srv.ListenAndServe)
vprintf("started reverse proxy at %q for %s", serveFlags.RevProxy, strings.Join(proxy.Targets, ", "))
g.Go(taskgroup.NoError(func() {
<-ctx.Done()
vprintf("signal received, stopping reverse proxy")
srv.Shutdown(context.Background())
}))
}

for {
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/prometheus/client_golang v1.19.1 // indirect
github.com/prometheus/client_model v0.5.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f35nQbASLnvxEde4XOBL+Sn7rFuV+FOJqkljg=
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
Expand Down
147 changes: 147 additions & 0 deletions revproxy/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package revproxy

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"io/fs"
"net/http"
"os"
"path/filepath"
"strings"
"time"

"github.com/creachadair/atomicfile"
"github.com/creachadair/taskgroup"
)

// cacheLoadLocal reads cached headers and body from the local cache.
func (s *Server) cacheLoadLocal(hash string) ([]byte, http.Header, error) {
data, err := os.ReadFile(s.makePath(hash))
if err != nil {
return nil, nil, err
}
return parseCacheObject(data)
}

// cacheStoreLocal writes the contents of body to the local cache.
//
// The file format is a plain-text section at the top recording a subset of the
// response headers, followed by "\n\n", followed by the response body.
func (s *Server) cacheStoreLocal(hash string, hdr http.Header, body []byte) error {
path := s.makePath(hash)
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
return err
}
return atomicfile.Tx(s.makePath(hash), 0600, func(f *atomicfile.File) error {
return writeCacheObject(f, hdr, body)
})
}

// cacheLoadS3 reads cached headers and body from the remote S3 cache.
func (s *Server) cacheLoadS3(ctx context.Context, hash string) ([]byte, http.Header, error) {
data, err := s.S3Client.GetData(ctx, s.makeKey(hash))
if err != nil {
return nil, nil, err
}
return parseCacheObject(data)
}

// cacheStoreS3 returns a task that writes the contents of body to the remote
// S3 cache.
func (s *Server) cacheStoreS3(hash string, hdr http.Header, body []byte) taskgroup.Task {
var buf bytes.Buffer
writeCacheObject(&buf, hdr, body)
nb := buf.Len()
return func() error {
sctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()

if err := s.S3Client.Put(sctx, s.makeKey(hash), &buf); err != nil {
s.logf("[s3] put %q failed: %v", hash, err)
s.rspPushError.Add(1)
} else {
s.rspPush.Add(1)
s.rspPushBytes.Add(int64(nb))
}
return nil
}
}

// cacheLoadMemory reads cached headers and body from the memory cache.
func (s *Server) cacheLoadMemory(hash string) ([]byte, http.Header, error) {
s.mcacheMu.Lock()
defer s.mcacheMu.Unlock()
v, ok := s.mcache.Get(hash)
if !ok {
return nil, nil, fs.ErrNotExist
}
entry := v.(memCacheEntry)
if time.Now().After(entry.expires) {
s.mcache.Remove(hash)
return nil, nil, errors.New("entry expired")
}
return entry.body, entry.header, nil
}

// cacheStoreMemory writes the contents of body to the memory cache.
func (s *Server) cacheStoreMemory(hash string, maxAge time.Duration, hdr http.Header, body []byte) {
s.mcacheMu.Lock()
defer s.mcacheMu.Unlock()
s.mcache.Add(hash, memCacheEntry{
header: hdr,
body: body,
expires: time.Now().Add(maxAge),
})
}

// parseCacheDbject parses cached object data to extract the body and headers.
func parseCacheObject(data []byte) ([]byte, http.Header, error) {
hdr, rest, ok := bytes.Cut(data, []byte("\n\n"))
if !ok {
return nil, nil, errors.New("invalid cache object: missing header")
}
h := make(http.Header)
for _, line := range strings.Split(string(hdr), "\n") {
name, value, ok := strings.Cut(line, ": ")
if ok {
h.Add(name, value)
}
}
return rest, h, nil
}

// writeCacheObject writes the specified response data into a cache object at w.
func writeCacheObject(w io.Writer, h http.Header, body []byte) error {
hprintf(w, h, "Content-Type", "application/octet-stream")
hprintf(w, h, "Date", "")
hprintf(w, h, "Etag", "")
fmt.Fprint(w, "\n")
_, err := w.Write(body)
return err
}

func hprintf(w io.Writer, h http.Header, name, fallback string) {
if v := h.Get(name); v != "" {
fmt.Fprintf(w, "%s: %s\n", name, v)
} else if fallback != "" {
fmt.Fprintf(w, "%s: %s\n", name, fallback)
}
}

// setXCacheInfo adds cache-specific headers to h.
func setXCacheInfo(h http.Header, result, hash string) {
h.Set("X-Cache", result)
if hash != "" {
h.Set("X-Cache-Id", hash[:12])
}
}

// memCacheEntry is the format of entries in the memory cache.
type memCacheEntry struct {
header http.Header
body []byte
expires time.Time
}
30 changes: 30 additions & 0 deletions revproxy/internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package revproxy

import (
"net/url"
"testing"
)

func TestCheckTarget(t *testing.T) {
s := &Server{
Targets: []string{"foo.com", "*.bar.com"},
}
tests := []struct {
input string
want bool
}{
{"", false},
{"nonesuch.org", false},
{"foo.com", true},
{"other.foo.com", false},
{"bar.com", true},
{"other.bar.com", true},
{"some.other.bar.com", true},
}
for _, tc := range tests {
u := &url.URL{Host: "localhost", Path: tc.input}
if got := s.checkTarget(u); got != tc.want {
t.Errorf("Check %q: got %v, want %v", tc.input, got, tc.want)
}
}
}
Loading
Loading