Skip to content

Commit

Permalink
Update manager
Browse files Browse the repository at this point in the history
  • Loading branch information
wzshiming committed Dec 23, 2024
1 parent a3e8728 commit 5bb296d
Show file tree
Hide file tree
Showing 26 changed files with 1,086 additions and 237 deletions.
43 changes: 28 additions & 15 deletions cmd/crproxy/cluster/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type flagpole struct {
AnonymousRateLimitPerSecond uint64
AnonymousNoAllowlist bool

AdminToken string

BlobsURLs []string

DBURL string
Expand Down Expand Up @@ -74,6 +76,8 @@ func NewCommand() *cobra.Command {
cmd.Flags().BoolVar(&flags.AllowAnonymous, "allow-anonymous", flags.AllowAnonymous, "Allow anonymous")
cmd.Flags().Uint64Var(&flags.AnonymousRateLimitPerSecond, "anonymous-rate-limit-per-second", flags.AnonymousRateLimitPerSecond, "Rate limit for anonymous users per second")

cmd.Flags().StringVar(&flags.AdminToken, "admin-token", flags.AdminToken, "Admin token")

cmd.Flags().StringSliceVar(&flags.BlobsURLs, "blobs-url", flags.BlobsURLs, "Blobs urls")

cmd.Flags().StringVar(&flags.DBURL, "db-url", flags.DBURL, "Database URL")
Expand Down Expand Up @@ -129,7 +133,9 @@ func runE(ctx context.Context, flags *flagpole) error {
return fmt.Errorf("failed to ping database: %w", err)
}

mgr = manager.NewManager(privateKey, db, 1*time.Minute)
logger.Info("Connected to DB")

mgr = manager.NewManager(privateKey, flags.AdminToken, db, 1*time.Minute)

mgr.Register(container)

Expand All @@ -139,20 +145,8 @@ func runE(ctx context.Context, flags *flagpole) error {
getHosts := getBlobsURLs(flags.BlobsURLs)

authFunc := func(r *http.Request, userinfo *url.Userinfo, t *token.Token) (token.Attribute, bool) {
if userinfo == nil {
if !flags.AllowAnonymous {
return token.Attribute{}, false
}
t.RateLimitPerSecond = flags.AnonymousRateLimitPerSecond

if !t.Block {
t.BlobsURL = getHosts()
}
return t.Attribute, true
}

var has bool
if flags.SimpleAuthUserpass != nil {
if userinfo != nil && flags.SimpleAuthUserpass != nil {
pass, ok := flags.SimpleAuthUserpass[userinfo.Username()]
if ok {
upass, ok := userinfo.Password()
Expand All @@ -172,10 +166,21 @@ func runE(ctx context.Context, flags *flagpole) error {

if !has {
if mgr == nil {
if userinfo == nil {
if !flags.AllowAnonymous {
return token.Attribute{}, false
}
t.RateLimitPerSecond = flags.AnonymousRateLimitPerSecond

if !t.Block {
t.BlobsURL = getHosts()
}
return t.Attribute, true
}
return token.Attribute{}, false
}

attr, err := mgr.GetToken(r.Context(), userinfo, t)
attr, err := mgr.GetTokenWithUser(r.Context(), userinfo, t)
if err != nil {
logger.Info("Failed to retrieve token", "user", userinfo, "err", err)
return token.Attribute{}, false
Expand All @@ -196,11 +201,19 @@ func runE(ctx context.Context, flags *flagpole) error {
container.Handle("/auth/token", gen)

var handler http.Handler = container

handler = handlers.LoggingHandler(os.Stderr, handler)

if flags.Behind {
handler = handlers.ProxyHeaders(handler)
}

handler = handlers.CORS(
handlers.AllowedMethods([]string{http.MethodHead, http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete}),
handlers.AllowedHeaders([]string{"Authorization", "Accept", "Content-Type", "Origin"}),
handlers.AllowedOrigins([]string{"*"}),
)(handler)

err = server.Run(ctx, flags.Address, handler, flags.AcmeHosts, flags.AcmeCacheDir, flags.CertFile, flags.PrivateKeyFile)
if err != nil {
return fmt.Errorf("failed to run server: %w", err)
Expand Down
22 changes: 8 additions & 14 deletions cmd/crproxy/cluster/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ type flagpole struct {
TokenPublicKeyFile string
TokenURL string

DefaultRegistry string
OverrideDefaultRegistry map[string]string

ReadmeURL string
}

Expand Down Expand Up @@ -78,9 +75,6 @@ func NewCommand() *cobra.Command {
cmd.Flags().StringVar(&flags.TokenPublicKeyFile, "token-public-key-file", flags.TokenPublicKeyFile, "Token public key file")
cmd.Flags().StringVar(&flags.TokenURL, "token-url", flags.TokenURL, "Token url")

cmd.Flags().StringVar(&flags.DefaultRegistry, "default-registry", flags.DefaultRegistry, "Default registry")
cmd.Flags().StringToStringVar(&flags.OverrideDefaultRegistry, "override-default-registry", flags.OverrideDefaultRegistry, "Override default registry")

cmd.Flags().StringVar(&flags.ReadmeURL, "readme-url", flags.ReadmeURL, "Readme url")
return cmd
}
Expand All @@ -94,23 +88,23 @@ func runE(ctx context.Context, flags *flagpole) error {

opts = append(opts,
gateway.WithLogger(logger),
gateway.WithDomainAlias(map[string]string{
"docker.io": "registry-1.docker.io",
"ollama.ai": "registry.ollama.ai",
}),
gateway.WithPathInfoModifyFunc(func(info *gateway.ImageInfo) *gateway.ImageInfo {
if info.Host == "docker.io" {
info.Host = "registry-1.docker.io"
} else if info.Host == "ollama.ai" {
info.Host = "registry.ollama.ai"
}

// docker.io/busybox => docker.io/library/busybox
if info.Host == "docker.io" && !strings.Contains(info.Name, "/") {
if info.Host == "registry-1.docker.io" && !strings.Contains(info.Name, "/") {
info.Name = "library/" + info.Name
}
if info.Host == "ollama.ai" && !strings.Contains(info.Name, "/") {
if info.Host == "registry.ollama.ai" && !strings.Contains(info.Name, "/") {
info.Name = "library/" + info.Name
}
return info
}),
gateway.WithDisableTagsList(flags.DisableTagsList),
gateway.WithDefaultRegistry(flags.DefaultRegistry),
gateway.WithOverrideDefaultRegistry(flags.OverrideDefaultRegistry),
)

if flags.StorageURL != "" {
Expand Down
70 changes: 20 additions & 50 deletions gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,14 @@ type ImageInfo struct {
}

type Gateway struct {
httpClient *http.Client
modify func(info *ImageInfo) *ImageInfo
domainAlias map[string]string
logger *slog.Logger
disableTagsList bool
defaultRegistry string
overrideDefaultRegistry map[string]string
cache *cache.Cache
manifestCache maps.SyncMap[cacheKey, time.Time]
manifestCacheDuration time.Duration
authenticator *token.Authenticator
httpClient *http.Client
modify func(info *ImageInfo) *ImageInfo
logger *slog.Logger
disableTagsList bool
cache *cache.Cache
manifestCache maps.SyncMap[cacheKey, time.Time]
manifestCacheDuration time.Duration
authenticator *token.Authenticator

agent *agent.Agent
}
Expand All @@ -59,18 +56,6 @@ func WithManifestCacheDuration(d time.Duration) Option {
}
}

func WithDefaultRegistry(target string) Option {
return func(c *Gateway) {
c.defaultRegistry = target
}
}

func WithOverrideDefaultRegistry(overrideDefaultRegistry map[string]string) Option {
return func(c *Gateway) {
c.overrideDefaultRegistry = overrideDefaultRegistry
}
}

func WithDisableTagsList(b bool) Option {
return func(c *Gateway) {
c.disableTagsList = b
Expand All @@ -83,12 +68,6 @@ func WithLogger(logger *slog.Logger) Option {
}
}

func WithDomainAlias(domainAlias map[string]string) Option {
return func(c *Gateway) {
c.domainAlias = domainAlias
}
}

func WithPathInfoModifyFunc(modify func(info *ImageInfo) *ImageInfo) Option {
return func(c *Gateway) {
c.modify = modify
Expand Down Expand Up @@ -212,19 +191,23 @@ func (c *Gateway) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}
}

defaultRegistry := c.defaultRegistry
if c.overrideDefaultRegistry != nil {
r, ok := c.overrideDefaultRegistry[r.Host]
if ok {
defaultRegistry = r
}
}
info, ok := parseOriginPathInfo(oriPath, defaultRegistry)
info, ok := parseOriginPathInfo(oriPath)
if !ok {
errcode.ServeJSON(rw, errcode.ErrorCodeDenied)
return
}

if t.Attribute.Host != "" {
info.Host = t.Attribute.Host
}
if info.Host == "" {
errcode.ServeJSON(rw, errcode.ErrorCodeDenied)
return
}
if t.Attribute.Image != "" {
info.Image = t.Attribute.Image
}

if c.modify != nil {
n := c.modify(&ImageInfo{
Host: info.Host,
Expand All @@ -239,8 +222,6 @@ func (c *Gateway) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
return
}

info.Host = c.getDomainAlias(info.Host)

if info.Blobs != "" {
c.blob(rw, r, info, &t, authData)
return
Expand Down Expand Up @@ -330,14 +311,3 @@ func (c *Gateway) errorResponse(rw http.ResponseWriter, r *http.Request, err err

errcode.ServeJSON(rw, err)
}

func (c *Gateway) getDomainAlias(host string) string {
if c.domainAlias == nil {
return host
}
h, ok := c.domainAlias[host]
if !ok {
return host
}
return h
}
27 changes: 18 additions & 9 deletions gateway/manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@ func (c *Gateway) cacheManifestResponse(rw http.ResponseWriter, r *http.Request,
errcode.ServeJSON(rw, errcode.ErrorCodeUnknown)
return
}
r.Header = map[string][]string{
"Accept": {"application/vnd.docker.distribution.manifest.v1+json,application/vnd.docker.distribution.manifest.v1+prettyjws,application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json,application/vnd.docker.distribution.manifest.list.v2+json,application/vnd.oci.image.index.v1+json"},
}

resp, err := c.httpClient.Do(r)
if err != nil {
if c.fallbackServeCachedManifest(rw, r, info) {
return
}
c.logger.Error("failed to request", "host", info.Host, "image", info.Image, "error", err)
c.logger.Error("failed to request", "url", u, "error", err)
errcode.ServeJSON(rw, errcode.ErrorCodeUnknown)
return
}
Expand All @@ -48,26 +51,32 @@ func (c *Gateway) cacheManifestResponse(rw http.ResponseWriter, r *http.Request,
switch resp.StatusCode {
case http.StatusUnauthorized, http.StatusForbidden:
if c.fallbackServeCachedManifest(rw, r, info) {
c.logger.Error("origin manifest response 40x, but hit caches", "host", info.Host, "image", info.Image, "error", err, "response", dumpResponse(resp))
c.logger.Error("origin manifest response 40x, but hit caches", "url", u, "error", err, "response", dumpResponse(resp))
return
}
c.logger.Error("origin manifest response 40x", "host", info.Host, "image", info.Image, "error", err, "response", dumpResponse(resp))
c.logger.Error("origin manifest response 40x", "url", u, "error", err, "response", dumpResponse(resp))
errcode.ServeJSON(rw, errcode.ErrorCodeDenied)
return
}

if resp.StatusCode >= http.StatusBadRequest && resp.StatusCode < http.StatusInternalServerError {
if c.fallbackServeCachedManifest(rw, r, info) {
c.logger.Error("origin manifest response 4xx, but hit caches", "host", info.Host, "image", info.Image, "error", err, "response", dumpResponse(resp))
c.logger.Error("origin manifest response 4xx, but hit caches", "url", u, "error", err, "response", dumpResponse(resp))
return
}
c.logger.Error("origin manifest response 4xx", "url", u, "error", err, "response", dumpResponse(resp))
} else if resp.StatusCode >= http.StatusInternalServerError {
if c.fallbackServeCachedManifest(rw, r, info) {
c.logger.Error("origin manifest response 5xx, but hit caches", "url", u, "error", err, "response", dumpResponse(resp))
return
}
c.logger.Error("origin manifest response 4xx", "host", info.Host, "image", info.Image, "error", err, "response", dumpResponse(resp))
} else if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusInternalServerError {
c.logger.Error("origin manifest response 5xx", "url", u, "error", err, "response", dumpResponse(resp))
} else if resp.StatusCode < http.StatusOK {
if c.fallbackServeCachedManifest(rw, r, info) {
c.logger.Error("origin manifest response 5xx, but hit caches", "host", info.Host, "image", info.Image, "error", err, "response", dumpResponse(resp))
c.logger.Error("origin manifest response 1xx, but hit caches", "url", u, "error", err, "response", dumpResponse(resp))
return
}
c.logger.Error("origin manifest response 5xx", "host", info.Host, "image", info.Image, "error", err, "response", dumpResponse(resp))
c.logger.Error("origin manifest response 1xx", "url", u, "error", err, "response", dumpResponse(resp))
}

resp.Header.Del("Docker-Ratelimit-Source")
Expand All @@ -84,7 +93,7 @@ func (c *Gateway) cacheManifestResponse(rw http.ResponseWriter, r *http.Request,
return
}

if resp.StatusCode >= http.StatusOK || resp.StatusCode < http.StatusMultipleChoices {
if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices {
body, err := io.ReadAll(resp.Body)
if err != nil {
c.errorResponse(rw, r, err)
Expand Down
Loading

0 comments on commit 5bb296d

Please sign in to comment.