Skip to content

Commit

Permalink
Optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
wzshiming committed Dec 27, 2024
1 parent 83d2914 commit 71e9838
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 54 deletions.
41 changes: 20 additions & 21 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (c *Agent) Serve(rw http.ResponseWriter, r *http.Request, info *BlobInfo, t
select {
case <-ctx.Done():
err := ctx.Err().Error()
c.logger.Error("context done", "error", err)
c.logger.Warn("context done", "error", err)
http.Error(rw, err, http.StatusInternalServerError)
return
case <-closeCh:
Expand Down Expand Up @@ -176,12 +176,7 @@ func (c *Agent) Serve(rw http.ResponseWriter, r *http.Request, info *BlobInfo, t
sleepDuration(float64(size), float64(t.RateLimitPerSecond))
}

err = c.redirectOrRedirect(rw, r, info.Blobs, info, size)
if err != nil {
c.logger.Error("failed to redirect", "digest", info.Blobs, "error", err)
c.errorResponse(rw, r, ctx.Err())
return
}
c.redirectOrRedirect(rw, r, info.Blobs, info, size)
return
}
c.logger.Info("Cache miss", "digest", info.Blobs)
Expand All @@ -194,7 +189,7 @@ func (c *Agent) Serve(rw http.ResponseWriter, r *http.Request, info *BlobInfo, t

go func() {
defer doneCache()
size, err := c.cacheBlob(r, info, func(size int64) {
size, err := c.cacheBlob(info, func(size int64) {
signalCh <- signal{
size: size,
}
Expand Down Expand Up @@ -227,11 +222,12 @@ func (c *Agent) Serve(rw http.ResponseWriter, r *http.Request, info *BlobInfo, t
select {
case <-ctx.Done():
return
case <-signalCh:
err = c.redirectOrRedirect(rw, r, info.Blobs, info, signal.size)
if err != nil {
c.logger.Error("failed to redirect", "digest", info.Blobs, "error", err)
case signal := <-signalCh:
if signal.err != nil {
c.errorResponse(rw, r, signal.err)
return
}
c.redirectOrRedirect(rw, r, info.Blobs, info, signal.size)
}
return
}
Expand All @@ -248,18 +244,18 @@ func sleepDuration(size, limit float64) {
}
}

func (c *Agent) cacheBlob(r *http.Request, info *BlobInfo, stats func(int64)) (int64, error) {
u := url.URL{
func (c *Agent) cacheBlob(info *BlobInfo, stats func(int64)) (int64, error) {
u := &url.URL{
Scheme: "https",
Host: info.Host,
Path: fmt.Sprintf("/v2/%s/blobs/%s", info.Image, info.Blobs),
}
r, err := http.NewRequestWithContext(context.Background(), http.MethodGet, u.String(), nil)
forwardReq, err := http.NewRequestWithContext(context.Background(), http.MethodGet, u.String(), nil)
if err != nil {
return 0, err
}

resp, err := c.httpClient.Do(r)
resp, err := c.httpClient.Do(forwardReq)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -296,16 +292,18 @@ func (c *Agent) errorResponse(rw http.ResponseWriter, r *http.Request, err error
errcode.ServeJSON(rw, err)
}

func (c *Agent) redirectOrRedirect(rw http.ResponseWriter, r *http.Request, blob string, info *BlobInfo, size int64) error {
func (c *Agent) redirectOrRedirect(rw http.ResponseWriter, r *http.Request, blob string, info *BlobInfo, size int64) {
if int64(c.blobsLENoAgent) > size {
data, err := c.cache.GetBlobContent(r.Context(), info.Blobs)
if err != nil {
return err
c.logger.Error("failed to get blob", "digest", info.Blobs, "error", err)
c.errorResponse(rw, r, err)
return
}
rw.Header().Set("Content-Length", strconv.FormatInt(size, 10))
rw.Header().Set("Content-Type", "application/octet-stream")
rw.Write(data)
return nil
return
}

referer := r.RemoteAddr
Expand All @@ -315,9 +313,10 @@ func (c *Agent) redirectOrRedirect(rw http.ResponseWriter, r *http.Request, blob

u, err := c.cache.RedirectBlob(r.Context(), blob, referer)
if err != nil {
return err
c.logger.Error("failed to get redirect", "digest", info.Blobs, "error", err)
c.errorResponse(rw, r, err)
return
}
c.logger.Info("Cache hit", "digest", blob, "url", u)
http.Redirect(rw, r, u, http.StatusTemporaryRedirect)
return nil
}
10 changes: 6 additions & 4 deletions gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type Gateway struct {
overrideDefaultRegistry map[string]string

acceptsItems []string
acceptsStr string
accepts map[string]struct{}

blobsLENoAgent int
Expand Down Expand Up @@ -127,6 +128,7 @@ func NewGateway(opts ...Option) (*Gateway, error) {
for _, item := range c.acceptsItems {
c.accepts[item] = struct{}{}
}
c.acceptsStr = strings.Join(c.acceptsItems, ",")

for _, opt := range opts {
opt(c)
Expand Down Expand Up @@ -288,19 +290,19 @@ func (c *Gateway) forward(rw http.ResponseWriter, r *http.Request, info *PathInf
errcode.ServeJSON(rw, errcode.ErrorCodeUnknown)
return
}
u := url.URL{
u := &url.URL{
Scheme: "https",
Host: info.Host,
Path: path,
}
r, err = http.NewRequestWithContext(r.Context(), r.Method, u.String(), nil)
forwardReq, err := http.NewRequestWithContext(r.Context(), r.Method, u.String(), nil)
if err != nil {
c.logger.Warn("failed to new request", "error", err)
errcode.ServeJSON(rw, errcode.ErrorCodeUnknown)
return
}

resp, err := c.httpClient.Do(r)
resp, err := c.httpClient.Do(forwardReq)
if err != nil {
c.logger.Warn("failed to request", "host", info.Host, "image", info.Image, "error", err)
errcode.ServeJSON(rw, errcode.ErrorCodeUnknown)
Expand Down Expand Up @@ -333,7 +335,7 @@ func (c *Gateway) forward(rw http.ResponseWriter, r *http.Request, info *PathInf
}
rw.WriteHeader(resp.StatusCode)

if r.Method != http.MethodHead {
if forwardReq.Method != http.MethodHead {
var body io.Reader = resp.Body

if t.RateLimitPerSecond > 0 {
Expand Down
47 changes: 18 additions & 29 deletions gateway/manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ import (
"net/textproto"
"net/url"
"strconv"
"strings"
"time"

"github.com/daocloud/crproxy/token"
"github.com/docker/distribution/registry/api/errcode"
)

func (c *Gateway) cacheManifestResponse(rw http.ResponseWriter, r *http.Request, info *PathInfo, t *token.Token) {
if c.tryFirstServeCachedManifest(rw, r, info) {
done, fallback := c.tryFirstServeCachedManifest(rw, r, info)
if done {
return
}

Expand All @@ -41,15 +41,16 @@ func (c *Gateway) cacheManifestResponse(rw http.ResponseWriter, r *http.Request,
if info.IsDigestManifests {
forwardReq.Header.Set("Accept", r.Header.Get("Accept"))
} else {
forwardReq.Header.Set("Accept", strings.Join(c.acceptsItems, ","))
forwardReq.Header.Set("Accept", c.acceptsStr)
}

resp, err := c.httpClient.Do(forwardReq)
if err != nil {
if c.fallbackServeCachedManifest(rw, r, info) {
if fallback && c.fallbackServeCachedManifest(rw, r, info) {
c.logger.Warn("failed to request, but hit caches", "url", u.String(), "error", err)
return
}
c.logger.Error("failed to request", "url", u, "error", err)
c.logger.Error("failed to request", "url", u.String(), "error", err)
errcode.ServeJSON(rw, errcode.ErrorCodeUnknown)
return
}
Expand All @@ -59,33 +60,21 @@ func (c *Gateway) cacheManifestResponse(rw http.ResponseWriter, r *http.Request,

switch resp.StatusCode {
case http.StatusUnauthorized, http.StatusForbidden:
if c.fallbackServeCachedManifest(rw, r, info) {
c.logger.Error("origin manifest response 40x, but hit caches", "url", u, "response", dumpResponse(resp))
if fallback && c.fallbackServeCachedManifest(rw, r, info) {
c.logger.Warn("origin manifest response, but hit caches", "statusCode", resp.StatusCode, "url", u.String(), "response", dumpResponse(resp))
return
}
c.logger.Error("origin manifest response 40x", "url", u, "response", dumpResponse(resp))
c.logger.Error("origin manifest response", "statusCode", resp.StatusCode, "url", u.String(), "response", dumpResponse(resp))
errcode.ServeJSON(rw, errcode.ErrorCodeDenied)
return
}

if resp.StatusCode >= http.StatusBadRequest && resp.StatusCode < http.StatusInternalServerError {
if c.fallbackServeCachedManifest(rw, r, info) {
c.logger.Error("origin manifest response 4xx, but hit caches", "url", u, "response", dumpResponse(resp))
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
if fallback && c.fallbackServeCachedManifest(rw, r, info) {
c.logger.Warn("origin manifest response, but hit caches", "statusCode", resp.StatusCode, "url", u.String(), "response", dumpResponse(resp))
return
}
c.logger.Error("origin manifest response 4xx", "url", u)
} else if resp.StatusCode >= http.StatusInternalServerError {
if c.fallbackServeCachedManifest(rw, r, info) {
c.logger.Error("origin manifest response 5xx, but hit caches", "url", u, "response", dumpResponse(resp))
return
}
c.logger.Error("origin manifest response 5xx", "url", u)
} else if resp.StatusCode < http.StatusOK {
if c.fallbackServeCachedManifest(rw, r, info) {
c.logger.Error("origin manifest response 1xx, but hit caches", "url", u, "response", dumpResponse(resp))
return
}
c.logger.Error("origin manifest response 1xx", "url", u)
c.logger.Error("origin manifest response", "statusCode", resp.StatusCode, "url", u.String())
}

resp.Header.Del("Docker-Ratelimit-Source")
Expand Down Expand Up @@ -130,19 +119,19 @@ func (c *Gateway) cacheManifestResponse(rw http.ResponseWriter, r *http.Request,
}
}

func (c *Gateway) tryFirstServeCachedManifest(rw http.ResponseWriter, r *http.Request, info *PathInfo) bool {
func (c *Gateway) tryFirstServeCachedManifest(rw http.ResponseWriter, r *http.Request, info *PathInfo) (done bool, fallback bool) {
if !info.IsDigestManifests && c.manifestCacheDuration > 0 {
last, ok := c.manifestCache.Load(manifestCacheKey(info))
if !ok {
return false
return false, true
}

if time.Since(last) > c.manifestCacheDuration {
return false
return false, true
}
}

return c.serveCachedManifest(rw, r, info)
return c.serveCachedManifest(rw, r, info), false
}

func (c *Gateway) fallbackServeCachedManifest(rw http.ResponseWriter, r *http.Request, info *PathInfo) bool {
Expand All @@ -162,7 +151,7 @@ func (c *Gateway) serveCachedManifest(rw http.ResponseWriter, r *http.Request, i
return false
}

c.logger.Info("Manifest cache hit", "digest", digest)
c.logger.Info("Manifest cache hit", "host", info.Host, "image", info.Blobs, "manifest", info.Manifests, "digest", digest)
rw.Header().Set("Docker-Content-Digest", digest)
rw.Header().Set("Content-Type", mediaType)
rw.Header().Set("Content-Length", strconv.FormatInt(int64(len(content)), 10))
Expand Down

0 comments on commit 71e9838

Please sign in to comment.