Skip to content

Commit 90fb81c

Browse files
feat: [FFM-12184]: add ability to rotate auth secrets (#385)
* feat: [FFM-12184]: allow legacy auth keys to decode tokens * feat: [FFM-12184]: add prometheus metrics for auth secret vs legacy secret decode metrics
1 parent c3ec00d commit 90fb81c

File tree

6 files changed

+204
-17
lines changed

6 files changed

+204
-17
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ $(GOBIN)/golangci-lint:
161161
# Install goimports to format code
162162
$(GOBIN)/goimports:
163163
@echo "🔘 Installing goimports ... (`date '+%H:%M:%S'`)"
164-
@go install golang.org/x/tools/cmd/goimports@latest
164+
@go install golang.org/x/tools/cmd/goimports@v0.30.0
165165

166166
# Install gocov to parse code coverage
167167
$(GOBIN)/gocov:

cmd/ff-proxy/main.go

+26-1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ var (
5252
clientService string
5353
metricService string
5454
authSecret string
55+
legacyAuthSecrets legacySecrets
5556
metricPostDuration int
5657
heartbeatInterval int
5758
generateOfflineConfig bool
@@ -88,13 +89,30 @@ var (
8889
andRules bool
8990
)
9091

92+
// legacySecrets implements the flag.Value interface and allows us to pass a comma separated
93+
// list of legacy secrets e.g. -legacy-secrets mysecret,mysecret2
94+
type legacySecrets []string
95+
96+
func (l *legacySecrets) String() string {
97+
return strings.Join(*l, ",")
98+
}
99+
100+
func (l *legacySecrets) Set(value string) error {
101+
ss := strings.Split(value, ",")
102+
for _, s := range ss {
103+
*l = append(*l, s)
104+
}
105+
return nil
106+
}
107+
91108
// Environment Variables
92109
const (
93110
// Service Config
94111
proxyKeyEnv = "PROXY_KEY"
95112
clientServiceEnv = "CLIENT_SERVICE"
96113
metricServiceEnv = "METRIC_SERVICE"
97114
authSecretEnv = "AUTH_SECRET"
115+
legacySecretsEnv = "LEGACY_SECRETS"
98116
metricPostDurationEnv = "METRIC_POST_DURATION"
99117
heartbeatIntervalEnv = "HEARTBEAT_INTERVAL"
100118
generateOfflineConfigEnv = "GENERATE_OFFLINE_CONFIG"
@@ -138,6 +156,7 @@ const (
138156
clientServiceFlag = "client-service"
139157
metricServiceFlag = "metric-service"
140158
authSecretFlag = "auth-secret"
159+
legacySecretsFlag = "legacy-secrets"
141160
metricPostDurationFlag = "metric-post-duration"
142161
heartbeatIntervalFlag = "heartbeat-interval"
143162
generateOfflineConfigFlag = "generate-offline-config"
@@ -181,6 +200,7 @@ func init() {
181200
flag.StringVar(&clientService, clientServiceFlag, "https://config.ff.harness.io/api/1.0", "the url of the ff client service")
182201
flag.StringVar(&metricService, metricServiceFlag, "https://events.ff.harness.io/api/1.0", "the url of the ff metric service")
183202
flag.StringVar(&authSecret, authSecretFlag, "secret", "the secret used for signing auth tokens")
203+
flag.Var(&legacyAuthSecrets, legacySecretsFlag, "legacy secrets used to decode auth tokens. If rotating a secret you can place the old auth-secret in here so old tokens will still remain valid while new tokens are issued using the new auth secret")
184204
flag.IntVar(&metricPostDuration, metricPostDurationFlag, 60, "How often in seconds the proxy posts metrics to Harness. Set to 0 to disable.")
185205
flag.IntVar(&heartbeatInterval, heartbeatIntervalFlag, 60, "How often in seconds the proxy polls pings it's health function. Set to 0 to disable.")
186206
flag.BoolVar(&generateOfflineConfig, generateOfflineConfigFlag, false, "if true the proxy will produce offline config in the /config directory then terminate")
@@ -223,6 +243,7 @@ func init() {
223243
clientServiceEnv: clientServiceFlag,
224244
metricServiceEnv: metricServiceFlag,
225245
authSecretEnv: authSecretFlag,
246+
legacySecretsEnv: legacySecretsFlag,
226247
redisAddrEnv: redisAddressFlag,
227248
redisPasswordEnv: redisPasswordFlag,
228249
redisUsernameEnv: redisUsernameFlag,
@@ -561,6 +582,10 @@ func main() {
561582
})
562583

563584
// Configure endpoints and server
585+
legacySecretsBytes := make([][]byte, 0)
586+
for _, s := range legacyAuthSecrets {
587+
legacySecretsBytes = append(legacySecretsBytes, []byte(s))
588+
}
564589
endpoints := transport.NewEndpoints(service)
565590
server := transport.NewHTTPServer(port, endpoints, logger, tlsEnabled, tlsCert, tlsKey)
566591
server.Use(
@@ -569,7 +594,7 @@ func main() {
569594
middleware.NewCorsMiddleware(),
570595
middleware.NewEchoRequestIDMiddleware(),
571596
middleware.NewEchoLoggingMiddleware(logger),
572-
middleware.NewEchoAuthMiddleware(logger, authRepo, []byte(authSecret), bypassAuth),
597+
middleware.NewEchoAuthMiddleware(logger, authRepo, []byte(authSecret), legacySecretsBytes, bypassAuth, promReg),
573598
middleware.ValidateEnvironment(bypassAuth),
574599
)
575600

middleware/middleware.go

+35-14
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@ import (
1111

1212
"github.com/golang-jwt/jwt"
1313
"github.com/google/uuid"
14+
"github.com/harness/ff-proxy/v2/domain"
15+
"github.com/harness/ff-proxy/v2/log"
1416
"github.com/labstack/echo/v4"
1517
"github.com/labstack/echo/v4/middleware"
1618
"github.com/prometheus/client_golang/prometheus"
17-
18-
"github.com/harness/ff-proxy/v2/domain"
19-
"github.com/harness/ff-proxy/v2/log"
2019
)
2120

2221
type requestContextKey string
@@ -52,29 +51,51 @@ func NewEchoLoggingMiddleware(l log.Logger) echo.MiddlewareFunc {
5251
})
5352
}
5453

54+
// validateToken attempts to validate a JWT token with the given secret
55+
func validateToken(tokenStr string, secret []byte) (*jwt.Token, error) {
56+
return jwt.ParseWithClaims(tokenStr, &domain.Claims{}, func(t *jwt.Token) (interface{}, error) {
57+
return secret, nil
58+
})
59+
}
60+
5561
// NewEchoAuthMiddleware returns an echo middleware that checks if auth headers
5662
// are valid
57-
func NewEchoAuthMiddleware(logger log.Logger, authRepo keyLookUp, secret []byte, bypassAuth bool) echo.MiddlewareFunc {
63+
// nolint:cyclop
64+
func NewEchoAuthMiddleware(logger log.Logger, authRepo keyLookUp, secret []byte, legacySecrets [][]byte, bypassAuth bool, reg *prometheus.Registry) echo.MiddlewareFunc {
65+
metrics := newPrometheusAuth(reg)
66+
5867
return middleware.JWTWithConfig(middleware.JWTConfig{
5968
AuthScheme: "Bearer",
6069
TokenLookup: "header:Authorization",
6170
ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) {
6271
if auth == "" {
63-
return nil, errors.New("token was empty")
72+
return nil, errors.New("authorization token is required")
6473
}
6574

66-
token, err := jwt.ParseWithClaims(auth, &domain.Claims{}, func(t *jwt.Token) (interface{}, error) {
67-
return secret, nil
68-
})
69-
if err != nil {
70-
return nil, err
75+
// First try with current secret
76+
token, err := validateToken(auth, secret)
77+
if err == nil {
78+
metrics.currentSecretTokens.Inc()
79+
if claims, ok := token.Claims.(*domain.Claims); ok && token.Valid && isKeyInCache(c.Request().Context(), logger, authRepo, claims) {
80+
c.Set(tokenClaims.String(), claims)
81+
return nil, nil
82+
}
83+
return nil, errors.New("invalid token")
7184
}
7285

73-
if claims, ok := token.Claims.(*domain.Claims); ok && token.Valid && isKeyInCache(c.Request().Context(), logger, authRepo, claims) {
74-
c.Set(tokenClaims.String(), claims)
75-
return nil, nil
86+
// If current secret fails, try legacy secrets
87+
for _, legacySecret := range legacySecrets {
88+
if token, err = validateToken(auth, legacySecret); err == nil {
89+
metrics.legacySecretTokens.Inc()
90+
if claims, ok := token.Claims.(*domain.Claims); ok && token.Valid && isKeyInCache(c.Request().Context(), logger, authRepo, claims) {
91+
c.Set(tokenClaims.String(), claims)
92+
return nil, nil
93+
}
94+
return nil, errors.New("invalid token")
95+
}
7696
}
77-
return nil, errors.New("invalid token")
97+
98+
return nil, err
7899
},
79100
Skipper: func(c echo.Context) bool {
80101
if bypassAuth {

middleware/middleware_test.go

+114
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
package middleware
22

33
import (
4+
"context"
45
"net/http"
56
"net/http/httptest"
67
"net/url"
78
"strings"
89
"testing"
910

11+
"github.com/prometheus/client_golang/prometheus"
12+
13+
"github.com/golang-jwt/jwt"
14+
1015
"github.com/harness/ff-proxy/v2/domain"
1116
"github.com/labstack/echo/v4"
1217
"github.com/stretchr/testify/assert"
18+
"github.com/stretchr/testify/require"
1319
)
1420

1521
func TestAllowQuerySemicolons(t *testing.T) {
@@ -216,3 +222,111 @@ func TestSkipper(t *testing.T) {
216222
})
217223
}
218224
}
225+
226+
type mockKeyLookup struct {
227+
shouldExist bool
228+
}
229+
230+
func (m *mockKeyLookup) Get(_ context.Context, _ domain.AuthAPIKey) (string, bool, error) {
231+
return "", m.shouldExist, nil
232+
}
233+
234+
func TestNewEchoAuthMiddleware(t *testing.T) {
235+
// Create test secrets
236+
currentSecret := []byte("current-secret")
237+
legacySecret := []byte("legacy-secret")
238+
legacySecrets := [][]byte{legacySecret}
239+
240+
// Create test claims
241+
validClaims := &domain.Claims{
242+
APIKey: "test-key",
243+
}
244+
245+
// Generate test tokens
246+
validToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, validClaims).SignedString(currentSecret)
247+
require.NoError(t, err)
248+
legacyToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, validClaims).SignedString(legacySecret)
249+
require.NoError(t, err)
250+
invalidToken := "invalid-token"
251+
252+
tests := []struct {
253+
name string
254+
token string
255+
keyExists bool
256+
bypassAuth bool
257+
wantStatus int
258+
}{
259+
{
260+
name: "Valid token with current secret",
261+
token: validToken,
262+
keyExists: true,
263+
wantStatus: http.StatusOK,
264+
},
265+
{
266+
name: "Valid token with legacy secret",
267+
token: legacyToken,
268+
keyExists: true,
269+
wantStatus: http.StatusOK,
270+
},
271+
{
272+
name: "Invalid token",
273+
token: invalidToken,
274+
keyExists: true,
275+
wantStatus: http.StatusUnauthorized,
276+
},
277+
{
278+
name: "Empty token",
279+
token: "",
280+
keyExists: true,
281+
wantStatus: http.StatusUnauthorized,
282+
},
283+
{
284+
name: "Valid token but key not in cache",
285+
token: validToken,
286+
keyExists: false,
287+
wantStatus: http.StatusUnauthorized,
288+
},
289+
{
290+
name: "Bypass auth enabled",
291+
token: "",
292+
bypassAuth: true,
293+
wantStatus: http.StatusOK,
294+
},
295+
}
296+
297+
for _, tt := range tests {
298+
t.Run(tt.name, func(t *testing.T) {
299+
// Setup
300+
e := echo.New()
301+
req := httptest.NewRequest(http.MethodGet, "/", nil)
302+
if tt.token != "" {
303+
req.Header.Set("Authorization", "Bearer "+tt.token)
304+
}
305+
rec := httptest.NewRecorder()
306+
c := e.NewContext(req, rec)
307+
308+
// Create mock auth repo with desired behavior
309+
mockRepo := &mockKeyLookup{shouldExist: tt.keyExists}
310+
311+
// Create middleware
312+
middleware := NewEchoAuthMiddleware(nil, mockRepo, currentSecret, legacySecrets, tt.bypassAuth, prometheus.NewRegistry())
313+
314+
// Create test handler
315+
handler := middleware(func(c echo.Context) error {
316+
return c.NoContent(http.StatusOK)
317+
})
318+
319+
// Execute
320+
_ = handler(c)
321+
322+
// Assert response status code
323+
assert.Equal(t, tt.wantStatus, rec.Code)
324+
325+
// For successful non-bypassed auth, verify claims are set
326+
if tt.wantStatus == http.StatusOK && !tt.bypassAuth {
327+
claims := c.Get(tokenClaims.String()).(*domain.Claims)
328+
assert.Equal(t, validClaims.APIKey, claims.APIKey)
329+
}
330+
})
331+
}
332+
}

middleware/prometheus.go

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package middleware
2+
3+
import (
4+
"github.com/prometheus/client_golang/prometheus"
5+
)
6+
7+
// prometheusAuth is used for tracking prometheus metrics around auth token validation
8+
type prometheusAuth struct {
9+
currentSecretTokens prometheus.Counter
10+
legacySecretTokens prometheus.Counter
11+
}
12+
13+
func newPrometheusAuth(reg *prometheus.Registry) *prometheusAuth {
14+
p := &prometheusAuth{
15+
currentSecretTokens: prometheus.NewCounter(prometheus.CounterOpts{
16+
Name: "ff_proxy_auth_current_secret_tokens_total",
17+
Help: "The total number of auth tokens decoded with the current secret",
18+
}),
19+
legacySecretTokens: prometheus.NewCounter(prometheus.CounterOpts{
20+
Name: "ff_proxy_auth_legacy_secret_tokens_total",
21+
Help: "The total number of auth tokens decoded with a legacy secret",
22+
}),
23+
}
24+
25+
reg.MustRegister(p.currentSecretTokens, p.legacySecretTokens)
26+
return p
27+
}

transport/http_server_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ func setupHTTPServer(t *testing.T, bypassAuth bool, opts ...setupOpts) *HTTPServ
308308
middleware.AllowQuerySemicolons(),
309309
middleware.NewEchoRequestIDMiddleware(),
310310
middleware.NewEchoLoggingMiddleware(logger),
311-
middleware.NewEchoAuthMiddleware(logger, repo, []byte(`secret`), bypassAuth),
311+
middleware.NewEchoAuthMiddleware(logger, repo, []byte(`secret`), [][]byte{}, bypassAuth, prometheus.NewRegistry()),
312312
middleware.ValidateEnvironment(bypassAuth),
313313
)
314314
return server

0 commit comments

Comments
 (0)