From cdf47f464bb4ff2f99bae2dacea14719eb06e414 Mon Sep 17 00:00:00 2001 From: Shiming Zhang Date: Mon, 18 Mar 2024 11:23:46 +0800 Subject: [PATCH] Add --disable-keep-alives --- cmd/crproxy/main.go | 3 ++ crproxy.go | 70 +++++++++++++++++++++++++++++++++++---------- 2 files changed, 58 insertions(+), 15 deletions(-) diff --git a/cmd/crproxy/main.go b/cmd/crproxy/main.go index 3d4f20f..f52d139 100644 --- a/cmd/crproxy/main.go +++ b/cmd/crproxy/main.go @@ -19,6 +19,7 @@ import ( var ( address string userpass []string + disableKeepAlives []string blobsSpeedLimit string totalBlobsSpeedLimit string ) @@ -26,6 +27,7 @@ var ( func init() { pflag.StringSliceVarP(&userpass, "user", "u", nil, "host and username and password -u user:pwd@host") pflag.StringVarP(&address, "address", "a", ":8080", "listen on the address") + pflag.StringSliceVar(&disableKeepAlives, "disable-keep-alives", nil, "disable keep alives for the host") pflag.StringVar(&blobsSpeedLimit, "blobs-speed-limit", "", "blobs speed limit per second (default unlimited)") pflag.StringVar(&totalBlobsSpeedLimit, "total-blobs-speed-limit", "", "total blobs speed limit per second (default unlimited)") pflag.Parse() @@ -88,6 +90,7 @@ func main() { } return info }), + crproxy.WithDisableKeepAlives(disableKeepAlives), } if len(userpass) != 0 { diff --git a/crproxy.go b/crproxy.go index f7260ed..2c03345 100644 --- a/crproxy.go +++ b/crproxy.go @@ -28,20 +28,21 @@ type Logger interface { } type CRProxy struct { - baseClient *http.Client - challengeManager challenge.Manager - clientset map[string]*lru.LRU[string, *http.Client] - clientSize int - modify func(info *PathInfo) *PathInfo - insecureDomain map[string]struct{} - domainAlias map[string]string - userAndPass map[string]Userpass - basicCredentials *basicCredentials - mut sync.Mutex - bytesPool sync.Pool - logger Logger - totalBlobsSpeedLimit *geario.Gear - blobsSpeedLimit *geario.B + baseClient *http.Client + challengeManager challenge.Manager + clientset map[string]*lru.LRU[string, *http.Client] + clientSize int + modify func(info *PathInfo) *PathInfo + insecureDomain map[string]struct{} + domainDisableKeepAlives map[string]struct{} + domainAlias map[string]string + userAndPass map[string]Userpass + basicCredentials *basicCredentials + mut sync.Mutex + bytesPool sync.Pool + logger Logger + totalBlobsSpeedLimit *geario.Gear + blobsSpeedLimit *geario.B } type Option func(c *CRProxy) @@ -94,6 +95,15 @@ func WithMaxClientSizeForEachRegistry(clientSize int) Option { } } +func WithDisableKeepAlives(disableKeepAlives []string) Option { + return func(c *CRProxy) { + c.domainDisableKeepAlives = map[string]struct{}{} + for _, v := range disableKeepAlives { + c.domainDisableKeepAlives[v] = struct{}{} + } + } +} + func NewCRProxy(opts ...Option) (*CRProxy, error) { c := &CRProxy{ challengeManager: challenge.NewSimpleManager(), @@ -151,8 +161,19 @@ func (c *CRProxy) getClientset(host string, image string) *http.Client { credentialStore = c.basicCredentials } authHandler := auth.NewTokenHandler(nil, credentialStore, image, "pull") + + tr := c.baseClient.Transport + + if c.domainDisableKeepAlives != nil { + if _, ok := c.domainDisableKeepAlives[host]; ok { + tr = c.disableKeepAlives(tr) + } + } + + tr = transport.NewTransport(tr, auth.NewAuthorizer(c.challengeManager, authHandler)) + client := &http.Client{ - Transport: transport.NewTransport(c.baseClient.Transport, auth.NewAuthorizer(c.challengeManager, authHandler)), + Transport: tr, CheckRedirect: c.baseClient.CheckRedirect, Timeout: c.baseClient.Timeout, Jar: c.baseClient.Jar, @@ -169,6 +190,25 @@ func (c *CRProxy) getClientset(host string, image string) *http.Client { return client } +func (c *CRProxy) disableKeepAlives(rt http.RoundTripper) http.RoundTripper { + if rt == nil { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.DisableKeepAlives = true + return tr + } + if tr, ok := rt.(*http.Transport); ok { + if !tr.DisableKeepAlives { + tr = tr.Clone() + tr.DisableKeepAlives = true + } + return tr + } + if c.logger != nil { + c.logger.Println("failed to disable keep alives") + } + return rt +} + func (c *CRProxy) ping(host string) error { c.mut.Lock() defer c.mut.Unlock()