Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions internal/cmd/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func newDeployCommand() *deployCommand {
deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSCertificatePath, "tls-certificate-path", "", "Configure custom TLS certificate path (PEM format)")
deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSPrivateKeyPath, "tls-private-key-path", "", "Configure custom TLS private key path (PEM format)")
deployCommand.cmd.Flags().BoolVar(&deployCommand.args.ServiceOptions.TLSRedirect, "tls-redirect", true, "Redirect HTTP traffic to HTTPS")
deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.CanonicalHost, "canonical-host", "", "Redirect all requests to this host (e.g., force root or www)")

deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DeployTimeout, "deploy-timeout", server.DefaultDeployTimeout, "Maximum time to wait for the new target to become healthy")
deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DrainTimeout, "drain-timeout", server.DefaultDrainTimeout, "Maximum time to allow existing connections to drain before removing old target")
Expand Down Expand Up @@ -110,5 +111,12 @@ func (c *deployCommand) preRun(cmd *cobra.Command, args []string) error {
}
}

// Validate canonical host is present in hosts when both are specified
if c.args.ServiceOptions.CanonicalHost != "" && len(c.args.ServiceOptions.Hosts) > 0 && c.args.ServiceOptions.Hosts[0] != "" {
if !slices.Contains(c.args.ServiceOptions.Hosts, c.args.ServiceOptions.CanonicalHost) {
return fmt.Errorf("canonical-host '%s' must be present in the hosts list: %v", c.args.ServiceOptions.CanonicalHost, c.args.ServiceOptions.Hosts)
}
}

return nil
}
78 changes: 78 additions & 0 deletions internal/cmd/deploy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package cmd

import (
"testing"

"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestDeployCommand_CanonicalHostValidation(t *testing.T) {
tests := []struct {
name string
hosts []string
canonicalHost string
expectError bool
expectedError string
}{
{
name: "valid canonical host in hosts list",
hosts: []string{"example.com", "www.example.com"},
canonicalHost: "example.com",
expectError: false,
},
{
name: "valid canonical host in hosts list with www",
hosts: []string{"example.com", "www.example.com"},
canonicalHost: "www.example.com",
expectError: false,
},
{
name: "canonical host not in hosts list",
hosts: []string{"example.com", "www.example.com"},
canonicalHost: "api.example.com",
expectError: true,
expectedError: "canonical-host 'api.example.com' must be present in the hosts list: [example.com www.example.com]",
},
{
name: "canonical host empty with hosts",
hosts: []string{"example.com", "www.example.com"},
canonicalHost: "",
expectError: false,
},
{
name: "canonical host with no hosts",
hosts: []string{},
canonicalHost: "example.com",
expectError: false,
},
{
name: "both canonical host and hosts empty",
hosts: []string{},
canonicalHost: "",
expectError: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := newDeployCommand()

cmd.args.ServiceOptions.Hosts = tt.hosts
cmd.args.ServiceOptions.CanonicalHost = tt.canonicalHost
cmd.args.ServiceOptions.TLSEnabled = false

mockCmd := &cobra.Command{}

err := cmd.preRun(mockCmd, []string{"test-service"})

if tt.expectError {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.expectedError)
} else {
require.NoError(t, err)
}
})
}
}
43 changes: 43 additions & 0 deletions internal/server/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"encoding/json"
"crypto/tls"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -209,6 +210,48 @@ func TestRouter_UpdatingOptions(t *testing.T) {
assert.Empty(t, body)
}

func TestRouter_CanonicalHostRedirect(t *testing.T) {
router := testRouter(t)
_, target := testBackend(t, "first", http.StatusOK)

serviceOptions := defaultServiceOptions
serviceOptions.Hosts = []string{"example.com", "www.example.com"}
serviceOptions.CanonicalHost = "example.com"

require.NoError(t, router.DeployService("service1", []string{target}, defaultEmptyReaders, serviceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout))

statusCode, _ := sendGETRequest(router, "http://www.example.com/")
assert.Equal(t, http.StatusMovedPermanently, statusCode)

statusCode, body := sendGETRequest(router, "http://example.com/")
assert.Equal(t, http.StatusOK, statusCode)
assert.Equal(t, "first", body)
}

func TestRouter_CanonicalHostRedirectWithTLS(t *testing.T) {
router := testRouter(t)
_, target := testBackend(t, "first", http.StatusOK)

serviceOptions := defaultServiceOptions
serviceOptions.Hosts = []string{"example.com", "www.example.com"}
serviceOptions.CanonicalHost = "example.com"
serviceOptions.TLSEnabled = true
serviceOptions.TLSRedirect = true

require.NoError(t, router.DeployService("service1", []string{target}, defaultEmptyReaders, serviceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout))

// Should go directly to https://example.com in a single redirect
statusCode, _ := sendGETRequest(router, "http://www.example.com/")
assert.Equal(t, http.StatusMovedPermanently, statusCode)

// HTTPS request to non-canonical host should redirect to canonical host but remain HTTPS
req := httptest.NewRequest(http.MethodGet, "https://www.example.com/", nil)
req.TLS = &tls.ConnectionState{}
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusMovedPermanently, w.Result().StatusCode)
}

func TestRouter_DeploymentsWithErrorsDoNotUpdateService(t *testing.T) {
router := testRouter(t)
_, target := testBackend(t, "first", http.StatusOK)
Expand Down
48 changes: 36 additions & 12 deletions internal/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ type ServiceOptions struct {
TLSCertificatePath string `json:"tls_certificate_path"`
TLSPrivateKeyPath string `json:"tls_private_key_path"`
TLSRedirect bool `json:"tls_redirect"`
CanonicalHost string `json:"canonical_host"`
ACMEDirectory string `json:"acme_directory"`
ACMECachePath string `json:"acme_cache_path"`
ErrorPagePath string `json:"error_page_path"`
Expand Down Expand Up @@ -417,13 +418,12 @@ func (s *Service) createMiddleware(options ServiceOptions, certManager CertManag
func (s *Service) serviceRequestWithTarget(w http.ResponseWriter, r *http.Request) {
LoggingRequestContext(r).Service = s.name

if s.shouldRedirectToHTTPS(r) {
s.redirectToHTTPS(w, r)
if !s.options.TLSEnabled && r.TLS != nil {
SetErrorResponse(w, r, http.StatusServiceUnavailable, nil)
return
}

if !s.options.TLSEnabled && r.TLS != nil {
SetErrorResponse(w, r, http.StatusServiceUnavailable, nil)
if s.handleRedirectsIfNeeded(w, r) {
return
}

Expand All @@ -445,10 +445,6 @@ func (s *Service) startLoadBalancerRequest(w http.ResponseWriter, r *http.Reques
return lb.StartRequest(w, r)
}

func (s *Service) shouldRedirectToHTTPS(r *http.Request) bool {
return s.options.TLSEnabled && s.options.TLSRedirect && r.TLS == nil
}

func (s *Service) handlePausedAndStoppedRequests(w http.ResponseWriter, r *http.Request) bool {
if s.pauseController.GetState() != PauseStateRunning && s.targetOptions.IsHealthCheckRequest(r) {
// When paused or stopped, return success for any health check
Expand All @@ -475,14 +471,42 @@ func (s *Service) handlePausedAndStoppedRequests(w http.ResponseWriter, r *http.
return false
}

func (s *Service) redirectToHTTPS(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Connection", "close")
func (s *Service) handleRedirectsIfNeeded(w http.ResponseWriter, r *http.Request) bool {
if url := s.redirectURLIfNeeded(r); url != "" {
w.Header().Set("Connection", "close")
http.Redirect(w, r, url, http.StatusMovedPermanently)
return true
}
return false
}

// redirectURLIfNeeded returns a full absolute URL to redirect to when either
// TLS redirection or canonical host redirection should occur. If no redirect is
// needed, it returns an empty string.
func (s *Service) redirectURLIfNeeded(r *http.Request) string {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
}

url := "https://" + host + r.URL.RequestURI()
http.Redirect(w, r, url, http.StatusMovedPermanently)
currentScheme := "http"
if r.TLS != nil {
currentScheme = "https"
}

desiredScheme := currentScheme
if s.options.TLSEnabled && s.options.TLSRedirect && currentScheme == "http" {
desiredScheme = "https"
}

desiredHost := host
if s.options.CanonicalHost != "" && host != s.options.CanonicalHost {
desiredHost = s.options.CanonicalHost
}

if desiredScheme != currentScheme || desiredHost != host {
return desiredScheme + "://" + desiredHost + r.URL.RequestURI()
}

return ""
}