Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 3 additions & 56 deletions middleware/webrpc.go
Original file line number Diff line number Diff line change
@@ -1,65 +1,12 @@
package middleware

import (
"log/slog"
"net/http"
"strconv"
"strings"

"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/httplog/v3"
"github.com/go-chi/metrics"
"github.com/0xsequence/go-libs/middleware/webrpc"
)

var (
// Total number of requests versioned with Webrpc header.
requestsTotal = metrics.CounterWith[labels]("webrpc_requests_total", "Total number of webrpc client requests.")
)

type labels struct {
Gen string `label:"gen"`
Schema string `label:"schema"`
Status string `label:"status"`
}

// WebrpcTelemetry is a middleware that extracts webrpc client information from request headers,
// logs it to request log for traceability, and collects usage metrics for API analytics.
// Deprecated: Use webrpc.Telemetry(webrpc.Opts{Origin: false}) instead.
func WebrpcTelemetry(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
webrpcHeader := r.Header.Get("Webrpc")
if webrpcHeader == "" {
next.ServeHTTP(w, r)
return
}

// Webrpc: webrpc@v0.25.1;gen-golang@v0.19.0;marketplace-api@v25.9.1;...
versions := strings.Split(webrpcHeader, ";")
if len(versions) < 3 {
next.ServeHTTP(w, r)
return
}

webrpcGen, _, _ := strings.Cut(versions[1], "@") // gen-golang@v0.19.0 -> gen-golang
webrpcSchema := versions[2] // marketplace-api@v25.9.1

httplog.SetAttrs(r.Context(),
slog.String("webrpcGen", webrpcGen),
slog.String("webrpcSchema", webrpcSchema),
)

ww, ok := w.(middleware.WrapResponseWriter)
if !ok {
ww = middleware.NewWrapResponseWriter(w, r.ProtoMajor)
}

defer func() {
requestsTotal.Inc(labels{
Gen: webrpcGen,
Schema: webrpcSchema,
Status: strconv.Itoa(ww.Status()),
})
}()

next.ServeHTTP(ww, r)
})
return webrpc.Telemetry(webrpc.Opts{Origin: false})(next)
}
85 changes: 85 additions & 0 deletions middleware/webrpc/webrpc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package webrpc

import (
"log/slog"
"net/http"
"net/url"
"strconv"
"strings"

"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/httplog/v3"
"github.com/go-chi/metrics"
)

// Total number of requests.
var requestsTotal = metrics.CounterWith[labels]("webrpc_requests_total", "Total number of webrpc requests.")

type labels struct {
Gen string `label:"gen"`
Schema string `label:"schema"`
Status string `label:"status"`
Origin string `label:"origin"`
}

type Opts struct {
// Track origin label in metrics.
// NOTE: Cardinality grows with the number of unique origin headers.
Origin bool

Skip func(r *http.Request) bool
}

// Telemetry is a middleware that extracts webrpc client information from request headers,
// logs it to request log for traceability, and collects usage metrics for API analytics.
func Telemetry(opts Opts) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if opts.Skip != nil && opts.Skip(r) {
next.ServeHTTP(w, r)
return
}

ww, ok := w.(middleware.WrapResponseWriter)
if !ok {
ww = middleware.NewWrapResponseWriter(w, r.ProtoMajor)
}

var labels labels
defer func() {
labels.Status = strconv.Itoa(ww.Status())
requestsTotal.Inc(labels)
}()

webrpcGen, webrpcSchema := parseWebrpcHeader(r.Header.Get("Webrpc"))
if webrpcSchema != "" {
labels.Gen = webrpcGen
labels.Schema = webrpcSchema
httplog.SetAttrs(r.Context(),
slog.String("webrpcGen", webrpcGen),
slog.String("webrpcSchema", webrpcSchema),
)
}

if opts.Origin {
if origin := strings.TrimSpace(r.Header.Get("Origin")); origin != "" && origin != "null" {
if u, err := url.Parse(origin); err == nil && u.Scheme != "" && u.Host != "" {
labels.Origin = u.Host
}
}
}

next.ServeHTTP(ww, r)
})
}
}

func parseWebrpcHeader(header string) (string, string) {
versions := strings.Split(header, ";")
if len(versions) < 3 {
return "", ""
}
webrpcGen, _, _ := strings.Cut(versions[1], "@") // gen-golang@v0.19.0 -> gen-golang
webrpcSchema := versions[2] // marketplace-api@v25.9.1
return webrpcGen, webrpcSchema
}
152 changes: 152 additions & 0 deletions middleware/webrpc/webrpc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package webrpc_test

import (
"bytes"
"net/http"
"net/http/httptest"
"testing"

"github.com/go-chi/chi/v5"
"github.com/go-chi/metrics"
dto "github.com/prometheus/client_model/go"
"github.com/prometheus/common/expfmt"
"github.com/test-go/testify/assert"

"github.com/0xsequence/go-libs/middleware/webrpc"
)

func TestWebrpcTelemetry(t *testing.T) {
t.Run("no origin label", func(t *testing.T) {
r := chi.NewRouter()
r.Use(metrics.Collector(metrics.CollectorOpts{
Host: false,
Proto: true,
Skip: func(r *http.Request) bool {
return r.Method != "OPTIONS"
},
}))
r.Use(webrpc.Telemetry(webrpc.Opts{}))
r.Get("/ok", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(422)
})
r.Handle("/metrics", metrics.Handler())

req := httptest.NewRequest(http.MethodGet, "/ok", nil)
req.Header.Set("Origin", "https://disabled-test-telemetry.example/path?x=y#z")
rr := httptest.NewRecorder()
r.ServeHTTP(rr, req)

assert.Equal(t, 422, rr.Code)

mfs := scrapeMetrics(t, r)
mf := mfs["webrpc_requests_total"]
assert.NotNil(t, mf)
assert.True(t, metricHasLabels(mf, map[string]string{"status": "422", "origin": ""}))
assert.False(t, metricHasLabels(mf, map[string]string{"status": "422", "origin": "disabled-test-telemetry.example"}))
})

t.Run("origin label with host only", func(t *testing.T) {
r := chi.NewRouter()
r.Use(metrics.Collector(metrics.CollectorOpts{
Host: false,
Proto: true,
Skip: func(r *http.Request) bool {
return r.Method != "OPTIONS"
},
}))
r.Use(webrpc.Telemetry(webrpc.Opts{Origin: true}))
r.Get("/ok", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(422)
})
r.Handle("/metrics", metrics.Handler())

req := httptest.NewRequest(http.MethodGet, "/ok", nil)
req.Header.Set("Origin", "https://enabled-test-telemetry.example/some/path?x=y#z")
rr := httptest.NewRecorder()
r.ServeHTTP(rr, req)

assert.Equal(t, 422, rr.Code)

mfs := scrapeMetrics(t, r)
mf := mfs["webrpc_requests_total"]
assert.NotNil(t, mf)
assert.True(t, metricHasLabels(mf, map[string]string{"status": "422", "origin": "enabled-test-telemetry.example"}))
assert.False(t, metricHasLabels(mf, map[string]string{"status": "422", "origin": "enabled-test-telemetry.example/some/path?x=y#z"}))
})

t.Run("OPTIONS preflight can be skipped", func(t *testing.T) {
r := chi.NewRouter()
r.Use(metrics.Collector(metrics.CollectorOpts{
Host: false,
Proto: true,
Skip: func(r *http.Request) bool {
return r.Method != "OPTIONS"
},
}))
r.Use(webrpc.Telemetry(webrpc.Opts{
Origin: true,
Skip: func(r *http.Request) bool {
// Typical CORS preflight signal; avoids dropping legitimate OPTIONS.
return r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != ""
},
}))
r.Get("/ok", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(422)
})
r.Handle("/metrics", metrics.Handler())

req := httptest.NewRequest(http.MethodOptions, "/ok", nil)
req.Header.Set("Access-Control-Request-Method", http.MethodGet)
req.Header.Set("Origin", "https://no-header.example/")
rr := httptest.NewRecorder()
r.ServeHTTP(rr, req)

// Handler is not invoked because chi won't route OPTIONS to GET; status isn't important here.

mfs := scrapeMetrics(t, r)
mf := mfs["webrpc_requests_total"]
// We skipped this request entirely, so there should be no series with origin/no-header.example.
assert.False(t, metricHasLabels(mf, map[string]string{"origin": "no-header.example"}))
})
}

func scrapeMetrics(t *testing.T, r http.Handler) map[string]*dto.MetricFamily {
t.Helper()

req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
rr := httptest.NewRecorder()
r.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)

var p expfmt.TextParser
mfs, err := p.TextToMetricFamilies(bytes.NewReader(rr.Body.Bytes()))
assert.NoError(t, err)

return mfs
}

func metricHasLabels(mf *dto.MetricFamily, labels map[string]string) bool {
if mf == nil {
return false
}
for _, m := range mf.GetMetric() {
ok := true
for wantName, wantValue := range labels {
found := false
for _, lp := range m.GetLabel() {
if lp.GetName() == wantName && lp.GetValue() == wantValue {
found = true
break
}
}
if !found {
ok = false
break
}
}
if ok {
return true
}
}
return false
}