Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,27 @@ func initConfig() {
v.SetDefault("openapi-definitions-path", "./bin/definitions")
v.SetDefault("enable-kcp", true)
v.SetDefault("local-development", false)
v.SetDefault("introspection-authentication", false)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved under the gateway section


// Listener
v.SetDefault("listener-apiexport-workspace", ":root")
v.SetDefault("listener-apiexport-name", "kcp.io")

// Gateway
v.SetDefault("gateway-port", "8080")

v.SetDefault("gateway-username-claim", "email")
v.SetDefault("gateway-should-impersonate", true)
v.SetDefault("gateway-introspection-authentication", false)

// Gateway Handler config
v.SetDefault("gateway-handler-pretty", true)
v.SetDefault("gateway-handler-playground", true)
v.SetDefault("gateway-handler-graphiql", true)

// Gateway CORS
v.SetDefault("gateway-cors-enabled", false)
v.SetDefault("gateway-cors-allowed-origins", "*")
v.SetDefault("gateway-cors-allowed-headers", "*")

// Gateway URL
v.SetDefault("gateway-url-virtual-workspace-prefix", "virtual-workspace")
v.SetDefault("gateway-url-default-kcp-workspace", "root")
Expand Down
14 changes: 7 additions & 7 deletions common/config/config.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package config

type Config struct {
OpenApiDefinitionsPath string `mapstructure:"openapi-definitions-path"`
EnableKcp bool `mapstructure:"enable-kcp"`
LocalDevelopment bool `mapstructure:"local-development"`
IntrospectionAuthentication bool `mapstructure:"introspection-authentication"`
OpenApiDefinitionsPath string `mapstructure:"openapi-definitions-path"`
EnableKcp bool `mapstructure:"enable-kcp"`
LocalDevelopment bool `mapstructure:"local-development"`

Url struct {
VirtualWorkspacePrefix string `mapstructure:"gateway-url-virtual-workspace-prefix"`
Expand All @@ -17,9 +16,10 @@ type Config struct {
} `mapstructure:",squash"`

Gateway struct {
Port string `mapstructure:"gateway-port"`
UsernameClaim string `mapstructure:"gateway-username-claim"`
ShouldImpersonate bool `mapstructure:"gateway-should-impersonate"`
Port string `mapstructure:"gateway-port"`
UsernameClaim string `mapstructure:"gateway-username-claim"`
ShouldImpersonate bool `mapstructure:"gateway-should-impersonate"`
IntrospectionAuthentication bool `mapstructure:"gateway-introspection-authentication"`

HandlerCfg struct {
Pretty bool `mapstructure:"gateway-handler-pretty"`
Expand Down
12 changes: 6 additions & 6 deletions common/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func TestConfig_StructInitialization(t *testing.T) {
assert.Empty(t, cfg.OpenApiDefinitionsPath)
assert.False(t, cfg.EnableKcp)
assert.False(t, cfg.LocalDevelopment)
assert.False(t, cfg.IntrospectionAuthentication)
assert.False(t, cfg.Gateway.IntrospectionAuthentication)

// Test nested struct fields
assert.Empty(t, cfg.Url.VirtualWorkspacePrefix)
Expand All @@ -37,10 +37,9 @@ func TestConfig_StructInitialization(t *testing.T) {

func TestConfig_FieldAssignment(t *testing.T) {
cfg := Config{
OpenApiDefinitionsPath: "/path/to/definitions",
EnableKcp: true,
LocalDevelopment: true,
IntrospectionAuthentication: true,
OpenApiDefinitionsPath: "/path/to/definitions",
EnableKcp: true,
LocalDevelopment: true,
}

cfg.Url.VirtualWorkspacePrefix = "workspace"
Expand All @@ -52,6 +51,7 @@ func TestConfig_FieldAssignment(t *testing.T) {
cfg.Gateway.Port = "8080"
cfg.Gateway.UsernameClaim = "email"
cfg.Gateway.ShouldImpersonate = true
cfg.Gateway.IntrospectionAuthentication = true

cfg.Gateway.HandlerCfg.Pretty = true
cfg.Gateway.HandlerCfg.Playground = true
Expand All @@ -65,7 +65,7 @@ func TestConfig_FieldAssignment(t *testing.T) {
assert.Equal(t, "/path/to/definitions", cfg.OpenApiDefinitionsPath)
assert.True(t, cfg.EnableKcp)
assert.True(t, cfg.LocalDevelopment)
assert.True(t, cfg.IntrospectionAuthentication)
assert.True(t, cfg.Gateway.IntrospectionAuthentication)

assert.Equal(t, "workspace", cfg.Url.VirtualWorkspacePrefix)
assert.Equal(t, "default", cfg.Url.DefaultKcpWorkspace)
Expand Down
9 changes: 1 addition & 8 deletions gateway/manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ import (

"github.com/pkg/errors"
"github.com/platform-mesh/golang-commons/logger"
"k8s.io/client-go/rest"

appConfig "github.com/platform-mesh/kubernetes-graphql-gateway/common/config"
"github.com/platform-mesh/kubernetes-graphql-gateway/gateway/manager/roundtripper"
"github.com/platform-mesh/kubernetes-graphql-gateway/gateway/manager/targetcluster"
"github.com/platform-mesh/kubernetes-graphql-gateway/gateway/manager/watcher"
)
Expand All @@ -24,12 +22,7 @@ type Service struct {

// NewGateway creates a new domain-driven Gateway instance
func NewGateway(ctx context.Context, log *logger.Logger, appCfg appConfig.Config) (*Service, error) {
// Create round tripper factory
roundTripperFactory := targetcluster.RoundTripperFactory(func(adminRT http.RoundTripper, tlsConfig rest.TLSClientConfig) http.RoundTripper {
return roundtripper.New(log, appCfg, adminRT, roundtripper.NewUnauthorizedRoundTripper())
})

clusterRegistry := targetcluster.NewClusterRegistry(log, appCfg, roundTripperFactory)
clusterRegistry := targetcluster.NewClusterRegistry(log, appCfg)

schemaWatcher, err := watcher.NewFileWatcher(log, clusterRegistry)
if err != nil {
Expand Down
25 changes: 21 additions & 4 deletions gateway/manager/roundtripper/roundtripper.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package roundtripper

import (
"fmt"
"net/http"
"strings"

"github.com/golang-jwt/jwt/v5"
"github.com/platform-mesh/golang-commons/logger"
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/client-go/rest"
"k8s.io/client-go/transport"

"github.com/platform-mesh/kubernetes-graphql-gateway/common/config"
Expand All @@ -16,15 +19,17 @@ type TokenKey struct{}
type roundTripper struct {
log *logger.Logger
adminRT, unauthorizedRT http.RoundTripper
baseRT http.RoundTripper
appCfg config.Config
}

type unauthorizedRoundTripper struct{}

func New(log *logger.Logger, appCfg config.Config, adminRoundTripper, unauthorizedRT http.RoundTripper) http.RoundTripper {
func New(log *logger.Logger, appCfg config.Config, adminRoundTripper, baseRoundTripper, unauthorizedRT http.RoundTripper) http.RoundTripper {
return &roundTripper{
log: log,
adminRT: adminRoundTripper,
baseRT: baseRoundTripper,
unauthorizedRT: unauthorizedRT,
appCfg: appCfg,
}
Expand All @@ -35,6 +40,18 @@ func NewUnauthorizedRoundTripper() http.RoundTripper {
return &unauthorizedRoundTripper{}
}

// NewBaseRoundTripper creates a base HTTP transport with only TLS configuration (no authentication)
func NewBaseRoundTripper(tlsConfig rest.TLSClientConfig) (http.RoundTripper, error) {
return rest.TransportFor(&rest.Config{
TLSClientConfig: rest.TLSClientConfig{
Insecure: tlsConfig.Insecure,
ServerName: tlsConfig.ServerName,
CAFile: tlsConfig.CAFile,
CAData: tlsConfig.CAData,
},
})
}

func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
rt.log.Info().
Str("req.Host", req.Host).
Expand Down Expand Up @@ -65,13 +82,13 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
}

// No we are going to use token based auth only, so we are reassigning the headers
req = utilnet.CloneRequest(req)
req.Header.Del("Authorization")
req.Header.Set("Authorization", "Bearer "+token)

if !rt.appCfg.Gateway.ShouldImpersonate {
rt.log.Debug().Str("path", req.URL.Path).Msg("Using bearer token authentication")

return rt.adminRT.RoundTrip(req)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We used adminRT here and it worked with any token passed since adminRT works without tokens.

fmt.Println(token)
return transport.NewBearerAuthRoundTripper(token, rt.baseRT).RoundTrip(req)
}

// Impersonation mode: extract user from token and impersonate
Expand Down
60 changes: 9 additions & 51 deletions gateway/manager/roundtripper/roundtripper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func TestRoundTripper_RoundTrip(t *testing.T) {
appCfg.Gateway.ShouldImpersonate = tt.shouldImpersonate
appCfg.Gateway.UsernameClaim = "sub"

rt := roundtripper.New(testlogger.New().Logger, appCfg, mockAdmin, mockUnauthorized)
rt := roundtripper.New(testlogger.New().Logger, appCfg, mockAdmin, mockAdmin, mockUnauthorized)

req := httptest.NewRequest(http.MethodGet, "http://example.com/api/v1/pods", nil)
if tt.token != "" {
Expand Down Expand Up @@ -262,7 +262,7 @@ func TestRoundTripper_DiscoveryRequests(t *testing.T) {
appCfg.Gateway.ShouldImpersonate = false
appCfg.Gateway.UsernameClaim = "sub"

rt := roundtripper.New(testlogger.New().Logger, appCfg, mockAdmin, mockUnauthorized)
rt := roundtripper.New(testlogger.New().Logger, appCfg, mockAdmin, mockAdmin, mockUnauthorized)

req := httptest.NewRequest(tt.method, "http://example.com"+tt.path, nil)

Expand Down Expand Up @@ -376,7 +376,7 @@ func TestRoundTripper_ComprehensiveFunctionality(t *testing.T) {
appCfg.Gateway.ShouldImpersonate = tt.shouldImpersonate
appCfg.Gateway.UsernameClaim = tt.usernameClaim

rt := roundtripper.New(testlogger.New().Logger, appCfg, mockAdmin, mockUnauthorized)
rt := roundtripper.New(testlogger.New().Logger, appCfg, mockAdmin, mockAdmin, mockUnauthorized)

req := httptest.NewRequest(http.MethodGet, "http://example.com/api/v1/pods", nil)
if tt.token != "" {
Expand Down Expand Up @@ -451,7 +451,7 @@ func TestRoundTripper_KCPDiscoveryRequests(t *testing.T) {
appCfg.Gateway.ShouldImpersonate = false
appCfg.Gateway.UsernameClaim = "sub"

rt := roundtripper.New(testlogger.New().Logger, appCfg, mockAdmin, mockUnauthorized)
rt := roundtripper.New(testlogger.New().Logger, appCfg, mockAdmin, mockAdmin, mockUnauthorized)

req := httptest.NewRequest(http.MethodGet, "http://example.com"+tt.path, nil)

Expand Down Expand Up @@ -500,7 +500,7 @@ func TestRoundTripper_InvalidTokenSecurityFix(t *testing.T) {
appCfg.Gateway.ShouldImpersonate = false
appCfg.Gateway.UsernameClaim = "sub"

rt := roundtripper.New(testlogger.New().Logger, appCfg, mockAdmin, mockUnauthorized)
rt := roundtripper.New(testlogger.New().Logger, appCfg, mockAdmin, mockAdmin, mockUnauthorized)

req := httptest.NewRequest(http.MethodGet, "/api/v1/pods", nil)
// Don't set a token to simulate the invalid token case
Expand All @@ -510,46 +510,6 @@ func TestRoundTripper_InvalidTokenSecurityFix(t *testing.T) {
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
}

func TestRoundTripper_ExistingAuthHeadersAreCleanedBeforeTokenAuth(t *testing.T) {
// This test verifies that existing Authorization headers are properly cleaned
// before setting the bearer token, preventing admin credentials from leaking through

mockAdmin := &mocks.MockRoundTripper{}
mockUnauthorized := &mocks.MockRoundTripper{}

// Capture the request that gets sent to adminRT
var capturedRequest *http.Request
mockAdmin.EXPECT().RoundTrip(mock.Anything).Return(&http.Response{StatusCode: http.StatusOK}, nil).Run(func(req *http.Request) {
capturedRequest = req
})

appCfg := appConfig.Config{}
appCfg.Gateway.ShouldImpersonate = false
appCfg.Gateway.UsernameClaim = "sub"

rt := roundtripper.New(testlogger.New().Logger, appCfg, mockAdmin, mockUnauthorized)

req := httptest.NewRequest(http.MethodGet, "/api/v1/pods", nil)

// Set an existing Authorization header that should be cleaned
req.Header.Set("Authorization", "Bearer admin-token-that-should-be-removed")

// Add the token to context
req = req.WithContext(context.WithValue(req.Context(), roundtripper.TokenKey{}, "user-token"))

resp, err := rt.RoundTrip(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)

// Verify that the captured request has the correct Authorization header
require.NotNil(t, capturedRequest)
authHeader := capturedRequest.Header.Get("Authorization")
assert.Equal(t, "Bearer user-token", authHeader)

// Verify that the original admin token was removed
assert.NotContains(t, authHeader, "admin-token-that-should-be-removed")
}

func TestRoundTripper_ExistingAuthHeadersAreCleanedBeforeImpersonation(t *testing.T) {
// This test verifies that existing Authorization headers are properly cleaned
// before setting the bearer token in impersonation mode
Expand All @@ -567,7 +527,7 @@ func TestRoundTripper_ExistingAuthHeadersAreCleanedBeforeImpersonation(t *testin
appCfg.Gateway.ShouldImpersonate = true
appCfg.Gateway.UsernameClaim = "sub"

rt := roundtripper.New(testlogger.New().Logger, appCfg, mockAdmin, mockUnauthorized)
rt := roundtripper.New(testlogger.New().Logger, appCfg, mockAdmin, mockAdmin, mockUnauthorized)

req := httptest.NewRequest(http.MethodGet, "/api/v1/pods", nil)

Expand All @@ -588,15 +548,13 @@ func TestRoundTripper_ExistingAuthHeadersAreCleanedBeforeImpersonation(t *testin
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)

// Verify that the captured request has the correct Authorization header
require.NotNil(t, capturedRequest)
authHeader := capturedRequest.Header.Get("Authorization")
assert.Equal(t, "Bearer "+tokenString, authHeader)

// Verify that the original admin token was removed
// Verify malicious Authorization header was removed
authHeader := capturedRequest.Header.Get("Authorization")
assert.NotContains(t, authHeader, "admin-token-that-should-be-removed")

// Verify that the impersonation header is set
// Verify impersonation header is set (adminRT provides admin auth, not user token)
impersonateHeader := capturedRequest.Header.Get("Impersonate-User")
assert.Equal(t, "test-user", impersonateHeader)
}
26 changes: 18 additions & 8 deletions gateway/manager/targetcluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/platform-mesh/kubernetes-graphql-gateway/common/auth"
appConfig "github.com/platform-mesh/kubernetes-graphql-gateway/common/config"
"github.com/platform-mesh/kubernetes-graphql-gateway/gateway/manager/roundtripper"
"github.com/platform-mesh/kubernetes-graphql-gateway/gateway/resolver"
"github.com/platform-mesh/kubernetes-graphql-gateway/gateway/schema"
)
Expand Down Expand Up @@ -64,7 +65,6 @@ func NewTargetCluster(
schemaFilePath string,
log *logger.Logger,
appCfg appConfig.Config,
roundTripperFactory func(http.RoundTripper, rest.TLSClientConfig) http.RoundTripper,
) (*TargetCluster, error) {
fileData, err := readSchemaFile(schemaFilePath)
if err != nil {
Expand All @@ -78,7 +78,7 @@ func NewTargetCluster(
}

// Connect to cluster - use metadata if available, otherwise fall back to standard config
if err := cluster.connect(appCfg, fileData.ClusterMetadata, roundTripperFactory); err != nil {
if err := cluster.connect(appCfg, fileData.ClusterMetadata); err != nil {
return nil, fmt.Errorf("failed to connect to cluster: %w", err)
}

Expand All @@ -96,7 +96,7 @@ func NewTargetCluster(
}

// connect establishes connection to the target cluster
func (tc *TargetCluster) connect(appCfg appConfig.Config, metadata *ClusterMetadata, roundTripperFactory func(http.RoundTripper, rest.TLSClientConfig) http.RoundTripper) error {
func (tc *TargetCluster) connect(appCfg appConfig.Config, metadata *ClusterMetadata) error {
// All clusters now use metadata from schema files to get kubeconfig
if metadata == nil {
return fmt.Errorf("cluster %s requires cluster metadata in schema file", tc.name)
Expand All @@ -114,11 +114,21 @@ func (tc *TargetCluster) connect(appCfg appConfig.Config, metadata *ClusterMetad
return fmt.Errorf("failed to build config from metadata: %w", err)
}

if roundTripperFactory != nil {
tc.restCfg.Wrap(func(rt http.RoundTripper) http.RoundTripper {
return roundTripperFactory(rt, tc.restCfg.TLSClientConfig)
})
}
tc.restCfg.Wrap(func(adminRT http.RoundTripper) http.RoundTripper {
baseRT, err := roundtripper.NewBaseRoundTripper(tc.restCfg.TLSClientConfig)
if err != nil {
tc.log.Error().Err(err).Msg("Failed to create base transport, falling back to default transport")
baseRT = http.DefaultTransport
}

return roundtripper.New(
tc.log,
tc.appCfg,
adminRT,
baseRT,
roundtripper.NewUnauthorizedRoundTripper(),
)
})

// Create client - use KCP-aware client only for KCP mode, standard client otherwise
if appCfg.EnableKcp {
Expand Down
Loading
Loading