From ddde90f082836d1e7f6e21fa82621c5ba430579e Mon Sep 17 00:00:00 2001 From: Harold Simpson Date: Fri, 13 Sep 2024 11:55:54 +0200 Subject: [PATCH] PR feedback + remove CloseIdleConnections --- README.md | 1 + pkg/smokescreen/acl/v1/acl.go | 85 ++++++++++++++------------- pkg/smokescreen/acl/v1/acl_test.go | 2 +- pkg/smokescreen/acl/v1/yaml_loader.go | 27 ++++----- pkg/smokescreen/config_loader.go | 8 ++- pkg/smokescreen/smokescreen.go | 64 +++++++++----------- pkg/smokescreen/smokescreen_test.go | 6 +- 7 files changed, 97 insertions(+), 96 deletions(-) diff --git a/README.md b/README.md index b0624f8f..1ab88070 100644 --- a/README.md +++ b/README.md @@ -186,3 +186,4 @@ See [Development.md](Development.md) - Evan Broder - Marc-André Tremblay - Ryan Koppenhaver +- Harold Simpson diff --git a/pkg/smokescreen/acl/v1/acl.go b/pkg/smokescreen/acl/v1/acl.go index 6540cdc8..5c48b643 100644 --- a/pkg/smokescreen/acl/v1/acl.go +++ b/pkg/smokescreen/acl/v1/acl.go @@ -30,8 +30,10 @@ type Rule struct { } type MitmDomain struct { - MitmConfig - Domain string + AddHeaders map[string]string + DetailedHttpLogs bool + DetailedHttpLogsFullHeaders []string + Domain string } type MitmConfig struct { @@ -142,8 +144,12 @@ func (acl *ACL) Decide(service, host, connectProxyHost string) (Decision, error) // if the host matches any of the rule's allowed domains with MITM config, allow for _, dg := range rule.DomainMitmGlobs { if HostMatchesGlob(host, dg.Domain) { - d.Result, d.Reason = Allow, "host matched allowed domain in rule" - d.MitmConfig = (*MitmConfig)(&dg.MitmConfig) + d.Result, d.Reason = Allow, "host matched allowed domain in MITM rule" + d.MitmConfig = &MitmConfig{ + AddHeaders: dg.AddHeaders, + DetailedHttpLogs: dg.DetailedHttpLogs, + DetailedHttpLogsFullHeaders: dg.DetailedHttpLogsFullHeaders, + } return d, nil } } @@ -215,59 +221,58 @@ func (acl *ACL) Validate() error { } func (acl *ACL) ValidateRuleDomainsGlobs(svc string, r Rule) error { - err := acl.ValidateDomainGlobs(svc, r.DomainGlobs) - if err != nil { - return err - } - mitmDomainGlobs := make([]string, len(r.DomainMitmGlobs)) - for i, d := range r.DomainMitmGlobs { - mitmDomainGlobs[i] = d.Domain + var err error + for _, d := range r.DomainGlobs { + err = acl.ValidateDomainGlob(svc, d) + if err != nil { + return err + } } - err = acl.ValidateDomainGlobs(svc, mitmDomainGlobs) - if err != nil { - return err + for _, d := range r.DomainMitmGlobs { + err = acl.ValidateDomainGlob(svc, d.Domain) + if err != nil { + return err + } } return nil } -// ValidateDomainGlobs takes a slice of domain globs and verifies they conform to smokescreen's +// ValidateDomainGlob takes a domain glob and verifies they conform to smokescreen's // domain glob policy. // // Wildcards are valid only at the beginning of a domain glob, and only a single wildcard per glob // pattern is allowed. Globs must include text after a wildcard. // // Domains must use their normalized form (e.g., Punycode) -func (acl *ACL) ValidateDomainGlobs(svc string, globs []string) error { - for _, glob := range globs { - if glob == "" { - return fmt.Errorf("glob cannot be empty") - } +func (*ACL) ValidateDomainGlob(svc string, glob string) error { + if glob == "" { + return fmt.Errorf("glob cannot be empty") + } - if glob == "*" || glob == "*." { - return fmt.Errorf("%v: %v: domain glob must not match everything", svc, glob) - } + if glob == "*" || glob == "*." { + return fmt.Errorf("%v: %v: domain glob must not match everything", svc, glob) + } - if !strings.HasPrefix(glob, "*.") && strings.HasPrefix(glob, "*") { - return fmt.Errorf("%v: %v: domain glob must represent a full prefix (sub)domain", svc, glob) - } + if !strings.HasPrefix(glob, "*.") && strings.HasPrefix(glob, "*") { + return fmt.Errorf("%v: %v: domain glob must represent a full prefix (sub)domain", svc, glob) + } - domainToCheck := strings.TrimPrefix(glob, "*") - if strings.Contains(domainToCheck, "*") { - return fmt.Errorf("%v: %v: domain globs are only supported as prefix", svc, glob) - } + domainToCheck := strings.TrimPrefix(glob, "*") + if strings.Contains(domainToCheck, "*") { + return fmt.Errorf("%v: %v: domain globs are only supported as prefix", svc, glob) + } - normalizedDomain, err := hostport.NormalizeHost(domainToCheck, false) + normalizedDomain, err := hostport.NormalizeHost(domainToCheck, false) - if err != nil { - return fmt.Errorf("%v: %v: incorrect ACL entry: %v", svc, glob, err) - } else if normalizedDomain != domainToCheck { - // There was no error but the config contains a non-normalized form - if strings.HasPrefix(glob, "*.") { - // (Re-add) wildcard if one was provided (for the error message) - normalizedDomain = "*." + normalizedDomain - } - return fmt.Errorf("%v: %v: incorrect ACL entry; use %q", svc, glob, normalizedDomain) + if err != nil { + return fmt.Errorf("%v: %v: incorrect ACL entry: %v", svc, glob, err) + // There was no error but the config contains a non-normalized form + } else if normalizedDomain != domainToCheck { + if strings.HasPrefix(glob, "*.") { + // (Re-add) wildcard if one was provided (for the error message) + normalizedDomain = "*." + normalizedDomain } + return fmt.Errorf("%v: %v: incorrect ACL entry; use %q", svc, glob, normalizedDomain) } return nil } diff --git a/pkg/smokescreen/acl/v1/acl_test.go b/pkg/smokescreen/acl/v1/acl_test.go index c299e700..bbb2b84b 100644 --- a/pkg/smokescreen/acl/v1/acl_test.go +++ b/pkg/smokescreen/acl/v1/acl_test.go @@ -377,7 +377,7 @@ func TestMitmComfig(t *testing.T) { d, err := acl.Decide(mitmService, "example-mitm.com", "") a.NoError(err) a.Equal(Allow, d.Result) - a.Equal("host matched allowed domain in rule", d.Reason) + a.Equal("host matched allowed domain in MITM rule", d.Reason) a.NotNil(d.MitmConfig) a.Equal(true, d.MitmConfig.DetailedHttpLogs) diff --git a/pkg/smokescreen/acl/v1/yaml_loader.go b/pkg/smokescreen/acl/v1/yaml_loader.go index db39803d..13554f30 100644 --- a/pkg/smokescreen/acl/v1/yaml_loader.go +++ b/pkg/smokescreen/acl/v1/yaml_loader.go @@ -89,14 +89,7 @@ func (cfg *YAMLConfig) Load() (*ACL, error) { var allowedHostsMitm []MitmDomain for _, w := range v.AllowedHostsMitm { - mitmDomain := MitmDomain{ - MitmConfig: MitmConfig{ - AddHeaders: w.AddHeaders, - DetailedHttpLogs: w.DetailedHttpLogs, - DetailedHttpLogsFullHeaders: w.DetailedHttpLogsFullHeaders, - }, - Domain: w.Domain, - } + mitmDomain := NewMITMDomain(w) allowedHostsMitm = append(allowedHostsMitm, mitmDomain) } @@ -123,14 +116,7 @@ func (cfg *YAMLConfig) Load() (*ACL, error) { var allowedHostsMitm []MitmDomain for _, w := range cfg.Default.AllowedHostsMitm { - mitmDomain := MitmDomain{ - MitmConfig: MitmConfig{ - AddHeaders: w.AddHeaders, - DetailedHttpLogs: w.DetailedHttpLogs, - DetailedHttpLogsFullHeaders: w.DetailedHttpLogsFullHeaders, - }, - Domain: w.Domain, - } + mitmDomain := NewMITMDomain(w) allowedHostsMitm = append(allowedHostsMitm, mitmDomain) } @@ -155,3 +141,12 @@ func (cfg *YAMLConfig) Load() (*ACL, error) { return &acl, nil } + +func NewMITMDomain(w YAMLMitmRule) MitmDomain { + return MitmDomain{ + AddHeaders: w.AddHeaders, + DetailedHttpLogs: w.DetailedHttpLogs, + DetailedHttpLogsFullHeaders: w.DetailedHttpLogsFullHeaders, + Domain: w.Domain, + } +} diff --git a/pkg/smokescreen/config_loader.go b/pkg/smokescreen/config_loader.go index 3734a26d..e9b6f715 100644 --- a/pkg/smokescreen/config_loader.go +++ b/pkg/smokescreen/config_loader.go @@ -165,10 +165,14 @@ func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { } mitmCa, err := tls.LoadX509KeyPair(yc.MitmCaCertFile, yc.MitmCaKeyFile) if err != nil { - return fmt.Errorf("could not load mitmCa: %v", err) + return fmt.Errorf("mitm_ca_key_file error tls.LoadX509KeyPair: %w", err) + } + // set the leaf certificat to reduce per-handshake processing + if len(mitmCa.Certificate) == 0 { + return errors.New("mitm_ca_key_file error: mitm_ca_key_file contains no certificates") } if mitmCa.Leaf, err = x509.ParseCertificate(mitmCa.Certificate[0]); err != nil { - return fmt.Errorf("could not populate x509 Leaf value: %v", err) + return fmt.Errorf("could not populate x509 Leaf value: %w", err) } c.MitmTLSConfig = goproxy.TLSConfigFromCA(&mitmCa) } diff --git a/pkg/smokescreen/smokescreen.go b/pkg/smokescreen/smokescreen.go index b71b35ba..5dc95135 100644 --- a/pkg/smokescreen/smokescreen.go +++ b/pkg/smokescreen/smokescreen.go @@ -289,10 +289,7 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { conn, err = sctx.cfg.ProxyDialTimeout(ctx, network, d.ResolvedAddr.String(), sctx.cfg.ConnectTimeout) } connTime := time.Since(start) - - fields := logrus.Fields{ - LogFieldConnEstablishMS: connTime.Milliseconds(), - } + sctx.logger = sctx.logger.WithFields(dialContextLoggerFields(pctx, sctx, conn, connTime)) if sctx.cfg.TimeConnect { sctx.cfg.MetricsClient.TimingWithTags("cn.atpt.connect.time", connTime, map[string]string{"domain": sctx.requestedHost}, 1) @@ -307,6 +304,22 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { sctx.cfg.MetricsClient.IncrWithTags("cn.atpt.total", map[string]string{"success": "true"}, 1) sctx.cfg.ConnTracker.RecordAttempt(sctx.requestedHost, true) + // Only wrap CONNECT conns with an InstrumentedConn. Connections used for traditional HTTP proxy + // requests are pooled and reused by net.Transport. + if sctx.proxyType == connectProxy { + ic := sctx.cfg.ConnTracker.NewInstrumentedConnWithTimeout(conn, sctx.cfg.IdleTimeout, sctx.logger, d.role, d.outboundHost, sctx.proxyType) + pctx.ConnErrorHandler = ic.Error + conn = ic + } else { + conn = NewTimeoutConn(conn, sctx.cfg.IdleTimeout) + } + + return conn, nil +} +func dialContextLoggerFields(pctx *goproxy.ProxyCtx, sctx *SmokescreenContext, conn net.Conn, connTime time.Duration) logrus.Fields { + fields := logrus.Fields{ + LogFieldConnEstablishMS: connTime.Milliseconds(), + } if conn != nil { if addr := conn.LocalAddr(); addr != nil { fields[LogFieldOutLocalAddr] = addr.String() @@ -316,30 +329,14 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { fields[LogFieldOutRemoteAddr] = addr.String() } } - sctx.logger = sctx.logger.WithFields(fields) - - // Only wrap CONNECT conns and MITM http conns with an InstrumentedConn. Connections used for traditional HTTP proxy - // requests are pooled and reused by net.Transport. - if sctx.proxyType == connectProxy || pctx.ConnectAction == goproxy.ConnectMitm { - // If we have a MITM and option is enabled, we can add detailed Request log fields - if pctx.ConnectAction == goproxy.ConnectMitm && sctx.Decision.MitmConfig != nil && sctx.Decision.MitmConfig.DetailedHttpLogs { - fields := logrus.Fields{ - LogMitmReqUrl: pctx.Req.URL.String(), - LogMitmReqMethod: pctx.Req.Method, - LogMitmReqHeaders: redactHeaders(pctx.Req.Header, sctx.Decision.MitmConfig.DetailedHttpLogsFullHeaders), - } - - sctx.logger = sctx.logger.WithFields(fields) - - } - ic := sctx.cfg.ConnTracker.NewInstrumentedConnWithTimeout(conn, sctx.cfg.IdleTimeout, sctx.logger, d.role, d.outboundHost, sctx.proxyType) - pctx.ConnErrorHandler = ic.Error - conn = ic - } else { - conn = NewTimeoutConn(conn, sctx.cfg.IdleTimeout) + // If we have a MITM and option is enabled, we can add detailed Request log fields + if pctx.ConnectAction == goproxy.ConnectMitm && sctx.Decision.MitmConfig != nil && sctx.Decision.MitmConfig.DetailedHttpLogs { + fields[LogMitmReqUrl] = pctx.Req.URL.String() + fields[LogMitmReqMethod] = pctx.Req.Method + fields[LogMitmReqHeaders] = redactHeaders(pctx.Req.Header, sctx.Decision.MitmConfig.DetailedHttpLogsFullHeaders) } - return conn, nil + return fields } // HTTPErrorHandler allows returning a custom error response when smokescreen @@ -468,12 +465,14 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer { } } - // Handle traditional HTTP proxy + // Handle traditional HTTP proxy and MITM outgoing requests (smokescreen - remote ) proxy.OnRequest().DoFunc(func(req *http.Request, pctx *goproxy.ProxyCtx) (*http.Request, *http.Response) { // Set this on every request as every request mints a new goproxy.ProxyCtx pctx.RoundTripper = rtFn - // For MITM requests intended for the remote host, the sole requirement was to configure the RoundTripper + // In the context of MITM request. Once the originating request (client - smokescreen) has been allowed + // goproxy/https.go calls proxy.filterRequest on the outgoing request (smokescreen - remote host) which calls this function + // in this case we ony want to configure the RoundTripper if pctx.ConnectAction == goproxy.ConnectMitm { return req, nil } @@ -571,13 +570,8 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer { return rejectResponse(pctx, pctx.Error) } - if pctx.ConnectAction == goproxy.ConnectMitm { - // If the connection is a MITM - // 1 we don't want to log as it will be done in HandleConnectFunc - // 2 we want to close idle connections as they are not closed by default - // and CANONICAL-PROXY-CN-CLOSE is called on InstrumentedConn.Close - proxy.Tr.CloseIdleConnections() - } else { + // We don't want to log if the connection is a MITM as it will be done in HandleConnectFunc + if pctx.ConnectAction != goproxy.ConnectMitm { // In case of an error, this function is called a second time to filter the // response we generate so this logger will be called once. logProxy(pctx) diff --git a/pkg/smokescreen/smokescreen_test.go b/pkg/smokescreen/smokescreen_test.go index 9c8c618c..0f7b0120 100644 --- a/pkg/smokescreen/smokescreen_test.go +++ b/pkg/smokescreen/smokescreen_test.go @@ -1432,9 +1432,10 @@ func TestMitm(t *testing.T) { r.NoError(err) cfg.Listener = l - proxy := proxyServer(cfg) + proxy := BuildProxy(cfg) + httpProxy := httptest.NewServer(proxy) remote := httptest.NewTLSServer(h) - client, err := proxyClient(proxy.URL) + client, err := proxyClient(httpProxy.URL) r.NoError(err) req, err := http.NewRequest("GET", remote.URL, nil) @@ -1480,6 +1481,7 @@ func TestMitm(t *testing.T) { r.NotNil(proxyDecision) r.Contains(proxyDecision.Data, "proxy_type") r.Equal("connect", proxyDecision.Data["proxy_type"]) + proxy.Tr.CloseIdleConnections() // check proxyclose log entry has information about the request headers proxyClose := findCanonicalProxyClose(logHook.AllEntries()) r.NotNil(proxyClose)