diff --git a/agent/agent.go b/agent/agent.go new file mode 100644 index 0000000..ad9878d --- /dev/null +++ b/agent/agent.go @@ -0,0 +1,291 @@ +package agent + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/daocloud/crproxy/cache" + "github.com/daocloud/crproxy/token" + "github.com/docker/distribution/registry/api/errcode" +) + +type BlobInfo struct { + Host string + Image string + + Blobs string +} + +type Agent struct { + mutCache sync.Map + httpClient *http.Client + logger *slog.Logger + cache *cache.Cache + authenticator *token.Authenticator +} + +type Option func(c *Agent) error + +func WithCache(cache *cache.Cache) Option { + return func(c *Agent) error { + c.cache = cache + return nil + } +} + +func WithLogger(logger *slog.Logger) Option { + return func(c *Agent) error { + c.logger = logger + return nil + } +} + +func WithAuthenticator(authenticator *token.Authenticator) Option { + return func(c *Agent) error { + c.authenticator = authenticator + return nil + } +} + +func WithClient(client *http.Client) Option { + return func(c *Agent) error { + c.httpClient = client + return nil + } +} + +func NewAgent(opts ...Option) (*Agent, error) { + c := &Agent{ + logger: slog.Default(), + httpClient: http.DefaultClient, + } + + for _, opt := range opts { + opt(c) + } + + return c, nil +} + +// /v2/{source}/{path...}/blobs/sha256:{digest} + +func parsePath(path string) (string, string, string, bool) { + path = strings.TrimPrefix(path, "/v2/") + parts := strings.Split(path, "/") + if len(parts) < 4 { + return "", "", "", false + } + if parts[len(parts)-2] != "blobs" { + return "", "", "", false + } + source := parts[0] + image := strings.Join(parts[1:len(parts)-2], "/") + digest := parts[len(parts)-1] + if !strings.HasPrefix(digest, "sha256:") { + return "", "", "", false + } + return source, image, digest, true +} + +func (c *Agent) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + source, image, digest, ok := parsePath(r.URL.Path) + if !ok { + http.NotFound(rw, r) + return + } + + info := &BlobInfo{ + Host: source, + Image: image, + Blobs: digest, + } + + var t token.Token + var err error + if c.authenticator != nil { + t, err = c.authenticator.Authorization(r) + if err != nil { + c.errorResponse(rw, r, errcode.ErrorCodeDenied.WithMessage(err.Error())) + return + } + } + + if t.Block { + if t.BlockMessage != "" { + errcode.ServeJSON(rw, errcode.ErrorCodeDenied.WithMessage(t.BlockMessage)) + } else { + errcode.ServeJSON(rw, errcode.ErrorCodeDenied) + } + return + } + + c.Serve(rw, r, info, &t) +} + +func (c *Agent) Serve(rw http.ResponseWriter, r *http.Request, info *BlobInfo, t *token.Token) { + ctx := r.Context() + + closeValue, loaded := c.mutCache.LoadOrStore(info.Blobs, make(chan struct{})) + closeCh := closeValue.(chan struct{}) + for loaded { + select { + case <-ctx.Done(): + err := ctx.Err().Error() + c.logger.Error("context done", "error", err) + http.Error(rw, err, http.StatusInternalServerError) + return + case <-closeCh: + } + closeValue, loaded = c.mutCache.LoadOrStore(info.Blobs, make(chan struct{})) + closeCh = closeValue.(chan struct{}) + } + + doneCache := func() { + c.mutCache.Delete(info.Blobs) + close(closeCh) + } + + stat, err := c.cache.StatBlob(ctx, info.Blobs) + if err == nil { + doneCache() + + size := stat.Size() + if r.Method == http.MethodHead { + rw.Header().Set("Content-Length", strconv.FormatInt(size, 10)) + rw.Header().Set("Content-Type", "application/octet-stream") + return + } + + if !t.NoRateLimit { + sleepDuration(float64(size), float64(t.RateLimitPerSecond)) + } + + err = c.redirect(rw, r, info.Blobs, info) + if err == nil { + return + } + c.errorResponse(rw, r, ctx.Err()) + return + } + c.logger.Info("Cache miss", "digest", info.Blobs) + + type signal struct { + err error + size int64 + } + signalCh := make(chan signal, 1) + + go func() { + defer doneCache() + size, err := c.cacheBlob(r, info, func(size int64) { + signalCh <- signal{ + size: size, + } + }) + signalCh <- signal{ + err: err, + size: size, + } + }() + + select { + case <-ctx.Done(): + c.errorResponse(rw, r, ctx.Err()) + return + case signal := <-signalCh: + if signal.err != nil { + c.errorResponse(rw, r, signal.err) + return + } + if r.Method == http.MethodHead { + rw.Header().Set("Content-Length", strconv.FormatInt(signal.size, 10)) + rw.Header().Set("Content-Type", "application/octet-stream") + return + } + + if !t.NoRateLimit { + sleepDuration(float64(signal.size), float64(t.RateLimitPerSecond)) + } + + select { + case <-ctx.Done(): + return + case <-signalCh: + err = c.redirect(rw, r, info.Blobs, info) + if err != nil { + c.logger.Error("failed to redirect", "digest", info.Blobs, "error", err) + } + } + return + } +} + +func sleepDuration(size, limit float64) { + if limit <= 0 { + return + } + + sd := time.Duration(size / limit * float64(time.Second)) + if sd > time.Second/10 { + time.Sleep(sd) + } +} + +func (c *Agent) cacheBlob(r *http.Request, info *BlobInfo, stats func(int64)) (int64, error) { + resp, err := c.httpClient.Do(r.WithContext(context.Background())) + if err != nil { + return 0, err + } + defer func() { + resp.Body.Close() + }() + + switch resp.StatusCode { + case http.StatusUnauthorized, http.StatusForbidden: + return 0, errcode.ErrorCodeDenied + } + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return 0, errcode.ErrorCodeUnknown.WithMessage(fmt.Sprintf("Source response code %d", resp.StatusCode)) + } + + if stats != nil { + stats(resp.ContentLength) + } + + return c.cache.PutBlob(context.Background(), info.Blobs, resp.Body) +} + +func (c *Agent) errorResponse(rw http.ResponseWriter, r *http.Request, err error) { + if err != nil { + e := err.Error() + c.logger.Warn("error response", "remoteAddr", r.RemoteAddr, "error", e) + } + + if err == nil { + err = errcode.ErrorCodeUnknown + } + + errcode.ServeJSON(rw, err) +} + +func (c *Agent) redirect(rw http.ResponseWriter, r *http.Request, blob string, info *BlobInfo) error { + referer := r.RemoteAddr + if info != nil { + referer += fmt.Sprintf(":%s/%s", info.Host, info.Image) + } + + u, err := c.cache.RedirectBlob(r.Context(), blob, referer, r.RemoteAddr) + if err != nil { + return err + } + c.logger.Info("Cache hit", "digest", blob, "url", u) + http.Redirect(rw, r, u, http.StatusFound) + return nil +} diff --git a/cmd/crproxy/cluster/agent/agent.go b/cmd/crproxy/cluster/agent/agent.go new file mode 100644 index 0000000..fea659a --- /dev/null +++ b/cmd/crproxy/cluster/agent/agent.go @@ -0,0 +1,170 @@ +package agent + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "os" + "time" + + "github.com/daocloud/crproxy/agent" + "github.com/daocloud/crproxy/cache" + "github.com/daocloud/crproxy/internal/pki" + "github.com/daocloud/crproxy/internal/server" + "github.com/daocloud/crproxy/signing" + "github.com/daocloud/crproxy/token" + "github.com/daocloud/crproxy/transport" + "github.com/docker/distribution/registry/storage/driver/factory" + "github.com/gorilla/handlers" + "github.com/spf13/cobra" +) + +type flagpole struct { + StorageDriver string + StorageParameters map[string]string + LinkExpires time.Duration + + Userpass []string + Retry int + RetryInterval time.Duration + + Behind bool + Address string + AcmeHosts []string + AcmeCacheDir string + CertFile string + PrivateKeyFile string + + TokenPublicKeyFile string + TokenURL string +} + +func NewCommand() *cobra.Command { + flags := &flagpole{ + Address: ":18002", + } + + cmd := &cobra.Command{ + Use: "agent", + Short: "Agent", + RunE: func(cmd *cobra.Command, args []string) error { + return runE(cmd.Context(), flags) + }, + } + + cmd.Flags().StringVar(&flags.StorageDriver, "storage-driver", flags.StorageDriver, "Storage driver") + cmd.Flags().StringToStringVar(&flags.StorageParameters, "storage-parameters", flags.StorageParameters, "Storage parameters") + cmd.Flags().DurationVar(&flags.LinkExpires, "link-expires", flags.LinkExpires, "Link expires") + + cmd.Flags().StringSliceVarP(&flags.Userpass, "user", "u", flags.Userpass, "host and username and password -u user:pwd@host") + cmd.Flags().IntVar(&flags.Retry, "retry", flags.Retry, "Retry") + cmd.Flags().DurationVar(&flags.RetryInterval, "retry-interval", flags.RetryInterval, "Retry interval") + + cmd.Flags().BoolVar(&flags.Behind, "behind", flags.Behind, "Behind") + cmd.Flags().StringVar(&flags.Address, "address", flags.Address, "Address") + cmd.Flags().StringSliceVar(&flags.AcmeHosts, "acme-hosts", flags.AcmeHosts, "Acme hosts") + cmd.Flags().StringVar(&flags.AcmeCacheDir, "acme-cache-dir", flags.AcmeCacheDir, "Acme cache dir") + cmd.Flags().StringVar(&flags.CertFile, "cert-file", flags.CertFile, "Cert file") + cmd.Flags().StringVar(&flags.PrivateKeyFile, "private-key-file", flags.PrivateKeyFile, "Private key file") + + cmd.Flags().StringVar(&flags.TokenPublicKeyFile, "token-public-key-file", flags.TokenPublicKeyFile, "Token public key file") + cmd.Flags().StringVar(&flags.TokenURL, "token-url", flags.TokenURL, "Token url") + + return cmd +} + +func runE(ctx context.Context, flags *flagpole) error { + mux := http.NewServeMux() + + opts := []agent.Option{} + + logger := slog.New(slog.NewJSONHandler(os.Stderr, nil)) + + cacheOpts := []cache.Option{} + + parameters := map[string]interface{}{} + for k, v := range flags.StorageParameters { + parameters[k] = v + } + sd, err := factory.Create(flags.StorageDriver, parameters) + if err != nil { + return fmt.Errorf("create storage driver failed: %w", err) + } + cacheOpts = append(cacheOpts, cache.WithStorageDriver(sd)) + if flags.LinkExpires > 0 { + cacheOpts = append(cacheOpts, cache.WithLinkExpires(flags.LinkExpires)) + } + + cache, err := cache.NewCache(cacheOpts...) + if err != nil { + return fmt.Errorf("create cache failed: %w", err) + } + + opts = append(opts, + agent.WithCache(cache), + agent.WithLogger(logger), + ) + + if flags.TokenPublicKeyFile != "" { + publicKeyData, err := os.ReadFile(flags.TokenPublicKeyFile) + if err != nil { + return fmt.Errorf("failed to read token public key file: %w", err) + } + publicKey, err := pki.DecodePublicKey(publicKeyData) + if err != nil { + return fmt.Errorf("failed to decode token public key: %w", err) + } + + authenticator := token.NewAuthenticator(token.NewDecoder(signing.NewVerifier(publicKey)), flags.TokenURL) + opts = append(opts, agent.WithAuthenticator(authenticator)) + } + + transportOpts := []transport.Option{ + transport.WithLogger(logger), + } + + tp, err := transport.NewTransport(transportOpts...) + if err != nil { + return fmt.Errorf("create clientset failed: %w", err) + } + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) > 10 { + return http.ErrUseLastResponse + } + s := make([]string, 0, len(via)+1) + for _, v := range via { + s = append(s, v.URL.String()) + } + + lastRedirect := req.URL.String() + s = append(s, lastRedirect) + logger.Info("redirect", "redirects", s) + + return nil + }, + Transport: tp, + } + opts = append(opts, agent.WithClient(client)) + + a, err := agent.NewAgent(opts...) + if err != nil { + return fmt.Errorf("create agent failed: %w", err) + } + + mux.Handle("/v2/", a) + + var handler http.Handler = mux + handler = handlers.LoggingHandler(os.Stderr, handler) + if flags.Behind { + handler = handlers.ProxyHeaders(handler) + } + + err = server.Run(ctx, flags.Address, handler, flags.AcmeHosts, flags.AcmeCacheDir, flags.CertFile, flags.PrivateKeyFile) + if err != nil { + return fmt.Errorf("failed to run server: %w", err) + } + return nil +} diff --git a/cmd/crproxy/cluster/auth/auth.go b/cmd/crproxy/cluster/auth/auth.go new file mode 100644 index 0000000..f04682c --- /dev/null +++ b/cmd/crproxy/cluster/auth/auth.go @@ -0,0 +1,165 @@ +package auth + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "net/url" + "os" + "sync/atomic" + + "github.com/daocloud/crproxy/internal/pki" + "github.com/daocloud/crproxy/internal/server" + "github.com/daocloud/crproxy/signing" + "github.com/daocloud/crproxy/token" + "github.com/gorilla/handlers" + "github.com/spf13/cobra" +) + +type flagpole struct { + Behind bool + Address string + AcmeHosts []string + AcmeCacheDir string + CertFile string + PrivateKeyFile string + + TokenPrivateKeyFile string + TokenPublicKeyFile string + + SimpleAuthUserpass map[string]string + + AllowAnonymous bool + AnonymousRateLimitPerSecond uint64 + + BlobsURLs []string +} + +func NewCommand() *cobra.Command { + flags := &flagpole{ + Address: ":18000", + } + + cmd := &cobra.Command{ + Use: "auth", + Short: "Auth", + RunE: func(cmd *cobra.Command, args []string) error { + return runE(cmd.Context(), flags) + }, + } + + cmd.Flags().BoolVar(&flags.Behind, "behind", flags.Behind, "Behind") + cmd.Flags().StringVar(&flags.Address, "address", flags.Address, "Address") + cmd.Flags().StringSliceVar(&flags.AcmeHosts, "acme-hosts", flags.AcmeHosts, "Acme hosts") + cmd.Flags().StringVar(&flags.AcmeCacheDir, "acme-cache-dir", flags.AcmeCacheDir, "Acme cache dir") + cmd.Flags().StringVar(&flags.CertFile, "cert-file", flags.CertFile, "Cert file") + cmd.Flags().StringVar(&flags.PrivateKeyFile, "private-key-file", flags.PrivateKeyFile, "Private key file") + + cmd.Flags().StringVar(&flags.TokenPrivateKeyFile, "token-private-key-file", "", "private key file") + cmd.Flags().StringVar(&flags.TokenPublicKeyFile, "token-public-key-file", "", "public key file") + + cmd.Flags().StringToStringVar(&flags.SimpleAuthUserpass, "simple-auth-userpass", flags.SimpleAuthUserpass, "Simple auth userpass") + + cmd.Flags().BoolVar(&flags.AllowAnonymous, "allow-anonymous", flags.AllowAnonymous, "Allow anonymous") + + cmd.Flags().StringSliceVar(&flags.BlobsURLs, "blobs-url", flags.BlobsURLs, "Blobs urls") + + return cmd +} + +func runE(ctx context.Context, flags *flagpole) error { + mux := http.NewServeMux() + + logger := slog.New(slog.NewJSONHandler(os.Stderr, nil)) + + privateKeyData, err := os.ReadFile(flags.TokenPrivateKeyFile) + if err != nil { + logger.Error("failed to ReadFile", "file", flags.TokenPrivateKeyFile, "error", err) + os.Exit(1) + } + privateKey, err := pki.DecodePrivateKey(privateKeyData) + if err != nil { + logger.Error("failed to DecodePrivateKey", "file", flags.TokenPrivateKeyFile, "error", err) + os.Exit(1) + } + + if flags.TokenPublicKeyFile != "" { + publicKeyData, err := pki.EncodePublicKey(&privateKey.PublicKey) + if err != nil { + return fmt.Errorf("failed to encode public key: %w", err) + } + + err = os.WriteFile(flags.TokenPublicKeyFile, publicKeyData, 0644) + if err != nil { + return fmt.Errorf("failed to write token public key file: %w", err) + } + } + + getHosts := getBlobsURLs(flags.BlobsURLs) + + authFunc := func(r *http.Request, userinfo *url.Userinfo, t *token.Token) (token.Attribute, bool) { + if userinfo == nil { + if !flags.AllowAnonymous { + return token.Attribute{}, false + } + attr := token.Attribute{ + RateLimitPerSecond: flags.AnonymousRateLimitPerSecond, + } + + if !attr.Block { + attr.BlobsURL = getHosts() + } + return attr, true + } + if flags.SimpleAuthUserpass == nil { + return token.Attribute{}, false + } + pass, ok := flags.SimpleAuthUserpass[userinfo.Username()] + if !ok { + return token.Attribute{}, false + } + upass, ok := userinfo.Password() + if !ok { + return token.Attribute{}, false + } + if upass != pass { + return token.Attribute{}, false + } + return token.Attribute{ + NoRateLimit: true, + NoAllowlist: true, + NoBlock: true, + AllowTagsList: true, + BlobsURL: getHosts(), + }, true + } + + gen := token.NewGenerator(token.NewEncoder(signing.NewSigner(privateKey)), authFunc, logger) + mux.Handle("/auth/token", gen) + + var handler http.Handler = mux + handler = handlers.LoggingHandler(os.Stderr, handler) + if flags.Behind { + handler = handlers.ProxyHeaders(handler) + } + + err = server.Run(ctx, flags.Address, handler, flags.AcmeHosts, flags.AcmeCacheDir, flags.CertFile, flags.PrivateKeyFile) + if err != nil { + return fmt.Errorf("failed to run server: %w", err) + } + return nil +} + +func getBlobsURLs(urls []string) func() string { + if len(urls) == 0 { + return func() string { + return "" + } + } + var index uint64 + return func() string { + n := atomic.AddUint64(&index, 1) + return urls[n%uint64(len(urls))] + } +} diff --git a/cmd/crproxy/cluster/cluster.go b/cmd/crproxy/cluster/cluster.go new file mode 100644 index 0000000..daeab36 --- /dev/null +++ b/cmd/crproxy/cluster/cluster.go @@ -0,0 +1,24 @@ +package cluster + +import ( + "github.com/spf13/cobra" + + "github.com/daocloud/crproxy/cmd/crproxy/cluster/agent" + "github.com/daocloud/crproxy/cmd/crproxy/cluster/auth" + "github.com/daocloud/crproxy/cmd/crproxy/cluster/gateway" +) + +func NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Args: cobra.NoArgs, + Use: "cluster", + Short: "Cluster commands", + RunE: func(cmd *cobra.Command, args []string) error { + return cmd.Usage() + }, + } + cmd.AddCommand(agent.NewCommand()) + cmd.AddCommand(gateway.NewCommand()) + cmd.AddCommand(auth.NewCommand()) + return cmd +} diff --git a/cmd/crproxy/cluster/gateway/gateway.go b/cmd/crproxy/cluster/gateway/gateway.go new file mode 100644 index 0000000..8732fdc --- /dev/null +++ b/cmd/crproxy/cluster/gateway/gateway.go @@ -0,0 +1,209 @@ +package gateway + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "os" + "strings" + "time" + + "github.com/daocloud/crproxy/cache" + "github.com/daocloud/crproxy/gateway" + "github.com/daocloud/crproxy/internal/pki" + "github.com/daocloud/crproxy/internal/server" + "github.com/daocloud/crproxy/signing" + "github.com/daocloud/crproxy/token" + "github.com/daocloud/crproxy/transport" + "github.com/docker/distribution/registry/storage/driver/factory" + "github.com/gorilla/handlers" + "github.com/spf13/cobra" +) + +type flagpole struct { + StorageDriver string + StorageParameters map[string]string + + ManifestCacheDuration time.Duration + + Userpass []string + Retry int + RetryInterval time.Duration + DisableTagsList bool + + Behind bool + Address string + AcmeHosts []string + AcmeCacheDir string + CertFile string + PrivateKeyFile string + + TokenPublicKeyFile string + TokenURL string + + DefaultRegistry string + OverrideDefaultRegistry map[string]string + + ReadmeURL string +} + +func NewCommand() *cobra.Command { + flags := &flagpole{ + Address: ":18001", + } + + cmd := &cobra.Command{ + Use: "gateway", + Short: "Gateway", + RunE: func(cmd *cobra.Command, args []string) error { + return runE(cmd.Context(), flags) + }, + } + + cmd.Flags().StringVar(&flags.StorageDriver, "storage-driver", flags.StorageDriver, "Storage driver") + cmd.Flags().StringToStringVar(&flags.StorageParameters, "storage-parameters", flags.StorageParameters, "Storage parameters") + + cmd.Flags().DurationVar(&flags.ManifestCacheDuration, "manifest-cache-duration", flags.ManifestCacheDuration, "Manifest cache duration") + + cmd.Flags().StringSliceVarP(&flags.Userpass, "user", "u", flags.Userpass, "host and username and password -u user:pwd@host") + cmd.Flags().IntVar(&flags.Retry, "retry", flags.Retry, "Retry") + cmd.Flags().DurationVar(&flags.RetryInterval, "retry-interval", flags.RetryInterval, "Retry interval") + cmd.Flags().BoolVar(&flags.DisableTagsList, "disable-tags-list", flags.DisableTagsList, "Disable tags list") + + cmd.Flags().BoolVar(&flags.Behind, "behind", flags.Behind, "Behind") + cmd.Flags().StringVar(&flags.Address, "address", flags.Address, "Address") + cmd.Flags().StringSliceVar(&flags.AcmeHosts, "acme-hosts", flags.AcmeHosts, "Acme hosts") + cmd.Flags().StringVar(&flags.AcmeCacheDir, "acme-cache-dir", flags.AcmeCacheDir, "Acme cache dir") + cmd.Flags().StringVar(&flags.CertFile, "cert-file", flags.CertFile, "Cert file") + cmd.Flags().StringVar(&flags.PrivateKeyFile, "private-key-file", flags.PrivateKeyFile, "Private key file") + + cmd.Flags().StringVar(&flags.TokenPublicKeyFile, "token-public-key-file", flags.TokenPublicKeyFile, "Token public key file") + cmd.Flags().StringVar(&flags.TokenURL, "token-url", flags.TokenURL, "Token url") + + cmd.Flags().StringVar(&flags.DefaultRegistry, "default-registry", flags.DefaultRegistry, "Default registry") + cmd.Flags().StringToStringVar(&flags.OverrideDefaultRegistry, "override-default-registry", flags.OverrideDefaultRegistry, "Override default registry") + + cmd.Flags().StringVar(&flags.ReadmeURL, "readme-url", flags.ReadmeURL, "Readme url") + return cmd +} + +func runE(ctx context.Context, flags *flagpole) error { + mux := http.NewServeMux() + + opts := []gateway.Option{} + + logger := slog.New(slog.NewJSONHandler(os.Stderr, nil)) + + opts = append(opts, + gateway.WithLogger(logger), + gateway.WithDomainAlias(map[string]string{ + "docker.io": "registry-1.docker.io", + "ollama.ai": "registry.ollama.ai", + }), + gateway.WithPathInfoModifyFunc(func(info *gateway.ImageInfo) *gateway.ImageInfo { + // docker.io/busybox => docker.io/library/busybox + if info.Host == "docker.io" && !strings.Contains(info.Name, "/") { + info.Name = "library/" + info.Name + } + if info.Host == "ollama.ai" && !strings.Contains(info.Name, "/") { + info.Name = "library/" + info.Name + } + return info + }), + gateway.WithDisableTagsList(flags.DisableTagsList), + gateway.WithDefaultRegistry(flags.DefaultRegistry), + gateway.WithOverrideDefaultRegistry(flags.OverrideDefaultRegistry), + ) + + if flags.StorageDriver != "" { + cacheOpts := []cache.Option{} + + parameters := map[string]interface{}{} + for k, v := range flags.StorageParameters { + parameters[k] = v + } + sd, err := factory.Create(flags.StorageDriver, parameters) + if err != nil { + return fmt.Errorf("create storage driver failed: %w", err) + } + cacheOpts = append(cacheOpts, cache.WithStorageDriver(sd)) + + cache, err := cache.NewCache(cacheOpts...) + if err != nil { + return fmt.Errorf("create cache failed: %w", err) + } + opts = append(opts, gateway.WithCache(cache)) + opts = append(opts, gateway.WithManifestCacheDuration(flags.ManifestCacheDuration)) + } + + if flags.TokenPublicKeyFile != "" { + if flags.TokenURL == "" { + return fmt.Errorf("token url is required") + } + publicKeyData, err := os.ReadFile(flags.TokenPublicKeyFile) + if err != nil { + return fmt.Errorf("failed to read token public key file: %w", err) + } + publicKey, err := pki.DecodePublicKey(publicKeyData) + if err != nil { + return fmt.Errorf("failed to decode token public key: %w", err) + } + + authenticator := token.NewAuthenticator(token.NewDecoder(signing.NewVerifier(publicKey)), flags.TokenURL) + opts = append(opts, gateway.WithAuthenticator(authenticator)) + } + + transportOpts := []transport.Option{ + transport.WithLogger(logger), + } + + tp, err := transport.NewTransport(transportOpts...) + if err != nil { + return fmt.Errorf("create clientset failed: %w", err) + } + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) > 10 { + return http.ErrUseLastResponse + } + s := make([]string, 0, len(via)+1) + for _, v := range via { + s = append(s, v.URL.String()) + } + + lastRedirect := req.URL.String() + s = append(s, lastRedirect) + logger.Info("redirect", "redirects", s) + + return nil + }, + Transport: tp, + } + opts = append(opts, gateway.WithClient(client)) + + a, err := gateway.NewGateway(opts...) + if err != nil { + return fmt.Errorf("create gateway failed: %w", err) + } + + if flags.ReadmeURL != "" { + mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) { + http.Redirect(rw, r, flags.ReadmeURL, http.StatusFound) + }) + } + mux.Handle("/v2/", a) + + var handler http.Handler = mux + handler = handlers.LoggingHandler(os.Stderr, handler) + if flags.Behind { + handler = handlers.ProxyHeaders(handler) + } + + err = server.Run(ctx, flags.Address, handler, flags.AcmeHosts, flags.AcmeCacheDir, flags.CertFile, flags.PrivateKeyFile) + if err != nil { + return fmt.Errorf("failed to run server: %w", err) + } + return nil +} diff --git a/cmd/crproxy/main.go b/cmd/crproxy/main.go index bbdf50b..40ae029 100644 --- a/cmd/crproxy/main.go +++ b/cmd/crproxy/main.go @@ -21,6 +21,7 @@ import ( "time" "github.com/daocloud/crproxy/cache" + "github.com/daocloud/crproxy/cmd/crproxy/cluster" csync "github.com/daocloud/crproxy/cmd/crproxy/sync" "github.com/daocloud/crproxy/transport" "github.com/docker/distribution/registry/storage/driver/factory" @@ -144,6 +145,7 @@ func init() { pflag.StringVar(&tokenPublicKeyFile, "token-public-key-file", "", "public key file") cmd.AddCommand(csync.NewCommand()) + cmd.AddCommand(cluster.NewCommand()) } var ( @@ -452,10 +454,10 @@ func run(ctx context.Context) { opts = append(opts, crproxy.WithOverrideDefaultRegistry(overrideDefaultRegistry)) } - var auth func(r *http.Request, userinfo *url.Userinfo) (token.Attribute, bool) + var authFunc func(r *http.Request, userinfo *url.Userinfo, t *token.Token) (token.Attribute, bool) if len(simpleAuthUserpass) != 0 { - auth = func(r *http.Request, userinfo *url.Userinfo) (token.Attribute, bool) { + authFunc = func(r *http.Request, userinfo *url.Userinfo, t *token.Token) (token.Attribute, bool) { if userinfo == nil { return token.Attribute{}, simpleAuthAllowAnonymous } @@ -526,7 +528,7 @@ func run(ctx context.Context) { opts = append(opts, crproxy.WithAuthenticator(authenticator)) if privateKey != nil { - gen := token.NewGenerator(token.NewEncoder(signing.NewSigner(privateKey)), auth, logger) + gen := token.NewGenerator(token.NewEncoder(signing.NewSigner(privateKey)), authFunc, logger) mux.Handle("/auth/token", gen) } } diff --git a/gateway/blob.go b/gateway/blob.go new file mode 100644 index 0000000..edd260b --- /dev/null +++ b/gateway/blob.go @@ -0,0 +1,29 @@ +package gateway + +import ( + "fmt" + "net/http" + + "github.com/daocloud/crproxy/agent" + "github.com/daocloud/crproxy/token" +) + +func (c *Gateway) blob(rw http.ResponseWriter, r *http.Request, info *PathInfo, t *token.Token) { + if t.Attribute.BlobsURL != "" { + referer := r.RemoteAddr + blobURL := fmt.Sprintf("%s/v2/%s/%s/blobs/%s?referer=%s", t.Attribute.BlobsURL, info.Host, info.Image, info.Blobs, referer) + http.Redirect(rw, r, blobURL, http.StatusTemporaryRedirect) + return + } + + if c.agent != nil { + c.agent.Serve(rw, r, &agent.BlobInfo{ + Host: info.Host, + Image: info.Image, + Blobs: info.Blobs, + }, t) + return + } + + c.forward(rw, r, info, t) +} diff --git a/gateway/gateway.go b/gateway/gateway.go new file mode 100644 index 0000000..d74c471 --- /dev/null +++ b/gateway/gateway.go @@ -0,0 +1,340 @@ +package gateway + +import ( + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/textproto" + "net/url" + "strings" + "time" + + "github.com/daocloud/crproxy/agent" + "github.com/daocloud/crproxy/cache" + "github.com/daocloud/crproxy/internal/maps" + "github.com/daocloud/crproxy/token" + "github.com/docker/distribution/registry/api/errcode" + "github.com/wzshiming/geario" +) + +var ( + prefix = "/v2/" + catalog = prefix + "_catalog" +) + +type ImageInfo struct { + Host string + Name string +} + +type Gateway struct { + httpClient *http.Client + modify func(info *ImageInfo) *ImageInfo + domainAlias map[string]string + logger *slog.Logger + disableTagsList bool + defaultRegistry string + overrideDefaultRegistry map[string]string + cache *cache.Cache + manifestCache maps.SyncMap[cacheKey, time.Time] + manifestCacheDuration time.Duration + authenticator *token.Authenticator + + agent *agent.Agent +} + +type Option func(c *Gateway) + +func WithClient(client *http.Client) Option { + return func(c *Gateway) { + c.httpClient = client + } +} + +func WithManifestCacheDuration(d time.Duration) Option { + return func(c *Gateway) { + c.manifestCacheDuration = d + } +} + +func WithDefaultRegistry(target string) Option { + return func(c *Gateway) { + c.defaultRegistry = target + } +} + +func WithOverrideDefaultRegistry(overrideDefaultRegistry map[string]string) Option { + return func(c *Gateway) { + c.overrideDefaultRegistry = overrideDefaultRegistry + } +} + +func WithDisableTagsList(b bool) Option { + return func(c *Gateway) { + c.disableTagsList = b + } +} + +func WithLogger(logger *slog.Logger) Option { + return func(c *Gateway) { + c.logger = logger + } +} + +func WithDomainAlias(domainAlias map[string]string) Option { + return func(c *Gateway) { + c.domainAlias = domainAlias + } +} + +func WithPathInfoModifyFunc(modify func(info *ImageInfo) *ImageInfo) Option { + return func(c *Gateway) { + c.modify = modify + } +} + +func WithAuthenticator(authenticator *token.Authenticator) Option { + return func(c *Gateway) { + c.authenticator = authenticator + } +} + +func WithCache(cache *cache.Cache) Option { + return func(c *Gateway) { + c.cache = cache + } +} + +func NewGateway(opts ...Option) (*Gateway, error) { + c := &Gateway{ + logger: slog.Default(), + } + + for _, opt := range opts { + opt(c) + } + + if c.authenticator == nil { + return nil, fmt.Errorf("no authenticator provided") + } + + if c.cache != nil { + a, err := agent.NewAgent( + agent.WithClient(c.httpClient), + agent.WithAuthenticator(c.authenticator), + agent.WithLogger(c.logger), + agent.WithCache(c.cache), + ) + if err != nil { + return nil, fmt.Errorf("failed to create agent: %w", err) + } + c.agent = a + } + return c, nil +} + +func apiBase(w http.ResponseWriter, r *http.Request) { + const emptyJSON = "{}" + // Provide a simple /v2/ 200 OK response with empty json response. + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Length", fmt.Sprint(len(emptyJSON))) + + fmt.Fprint(w, emptyJSON) +} + +func emptyTagsList(w http.ResponseWriter, r *http.Request) { + const emptyTagsList = `{"name":"disable-list-tags","tags":[]}` + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Length", fmt.Sprint(len(emptyTagsList))) + fmt.Fprint(w, emptyTagsList) +} + +func getIP(str string) string { + host, _, err := net.SplitHostPort(str) + if err == nil && host != "" { + return host + } + return str +} + +func (c *Gateway) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + oriPath := r.URL.Path + if !strings.HasPrefix(oriPath, prefix) { + http.NotFound(rw, r) + return + } + + if r.Method != http.MethodGet && r.Method != http.MethodHead { + errcode.ServeJSON(rw, errcode.ErrorCodeUnsupported) + return + } + + if oriPath == catalog { + errcode.ServeJSON(rw, errcode.ErrorCodeUnsupported) + return + } + + r.RemoteAddr = getIP(r.RemoteAddr) + + var t token.Token + var err error + if c.authenticator != nil { + t, err = c.authenticator.Authorization(r) + if err != nil { + c.authenticator.Authenticate(rw, r) + return + } + } + + if oriPath == prefix { + apiBase(rw, r) + return + } + + if c.authenticator != nil { + if t.Scope == "" { + c.authenticator.Authenticate(rw, r) + return + } + if t.Block { + if t.BlockMessage != "" { + errcode.ServeJSON(rw, errcode.ErrorCodeDenied.WithMessage(t.BlockMessage)) + } else { + errcode.ServeJSON(rw, errcode.ErrorCodeDenied) + } + return + } + } + + defaultRegistry := c.defaultRegistry + if c.overrideDefaultRegistry != nil { + r, ok := c.overrideDefaultRegistry[r.Host] + if ok { + defaultRegistry = r + } + } + info, ok := parseOriginPathInfo(oriPath, defaultRegistry) + if !ok { + errcode.ServeJSON(rw, errcode.ErrorCodeDenied) + return + } + + if c.modify != nil { + n := c.modify(&ImageInfo{ + Host: info.Host, + Name: info.Image, + }) + info.Host = n.Host + info.Image = n.Name + } + + if c.disableTagsList && info.TagsList && !t.AllowTagsList { + emptyTagsList(rw, r) + return + } + + info.Host = c.getDomainAlias(info.Host) + + if info.Blobs != "" { + c.blob(rw, r, info, &t) + return + } + + if info.Manifests != "" { + if c.cache != nil { + c.cacheManifestResponse(rw, r, info, &t) + return + } + } + c.forward(rw, r, info, &t) +} + +func (c *Gateway) forward(rw http.ResponseWriter, r *http.Request, info *PathInfo, t *token.Token) { + path, err := info.Path() + if err != nil { + c.logger.Warn("failed to get path", "error", err) + errcode.ServeJSON(rw, errcode.ErrorCodeUnknown) + return + } + u := url.URL{ + Scheme: "https", + Host: info.Host, + Path: path, + } + r, err = http.NewRequestWithContext(r.Context(), r.Method, u.String(), nil) + if err != nil { + c.logger.Warn("failed to new request", "error", err) + errcode.ServeJSON(rw, errcode.ErrorCodeUnknown) + return + } + + resp, err := c.httpClient.Do(r) + if err != nil { + c.logger.Warn("failed to request", "host", info.Host, "image", info.Image, "error", err) + errcode.ServeJSON(rw, errcode.ErrorCodeUnknown) + return + } + defer func() { + resp.Body.Close() + }() + + switch resp.StatusCode { + case http.StatusUnauthorized, http.StatusForbidden: + c.logger.Warn("origin direct response 40x", "host", info.Host, "image", info.Image, "response", dumpResponse(resp)) + errcode.ServeJSON(rw, errcode.ErrorCodeDenied) + return + } + + resp.Header.Del("Docker-Ratelimit-Source") + + if resp.StatusCode == http.StatusOK { + oldLink := resp.Header.Get("Link") + if oldLink != "" { + resp.Header.Set("Link", addPrefixToImageForPagination(oldLink, info.Host)) + } + } + + header := rw.Header() + for k, v := range resp.Header { + key := textproto.CanonicalMIMEHeaderKey(k) + header[key] = v + } + rw.WriteHeader(resp.StatusCode) + + if r.Method != http.MethodHead { + var body io.Reader = resp.Body + + if t.RateLimitPerSecond > 0 { + body = geario.NewGear(time.Second, geario.B(t.RateLimitPerSecond)).Reader(body) + } + + io.Copy(rw, body) + } +} + +func (c *Gateway) errorResponse(rw http.ResponseWriter, r *http.Request, err error) { + if err != nil { + e := err.Error() + c.logger.Warn("error response", "remoteAddr", r.RemoteAddr, "error", e) + } + + if err == nil { + err = errcode.ErrorCodeUnknown + } + + errcode.ServeJSON(rw, err) +} + +func (c *Gateway) getDomainAlias(host string) string { + if c.domainAlias == nil { + return host + } + h, ok := c.domainAlias[host] + if !ok { + return host + } + return h +} diff --git a/gateway/manifest.go b/gateway/manifest.go new file mode 100644 index 0000000..4770bf3 --- /dev/null +++ b/gateway/manifest.go @@ -0,0 +1,152 @@ +package gateway + +import ( + "context" + "io" + "net/http" + "net/textproto" + "strconv" + "strings" + "time" + + "github.com/daocloud/crproxy/token" + "github.com/docker/distribution/registry/api/errcode" +) + +func (c *Gateway) cacheManifestResponse(rw http.ResponseWriter, r *http.Request, info *PathInfo, t *token.Token) { + if c.tryFirstServeCachedManifest(rw, r, info) { + return + } + + resp, err := c.httpClient.Do(r.WithContext(context.Background())) + if err != nil { + if c.fallbackServeCachedManifest(rw, r, info) { + return + } + c.logger.Error("failed to request", "host", info.Host, "image", info.Image, "error", err) + errcode.ServeJSON(rw, errcode.ErrorCodeUnknown) + return + } + defer func() { + resp.Body.Close() + }() + + switch resp.StatusCode { + case http.StatusUnauthorized, http.StatusForbidden: + if c.fallbackServeCachedManifest(rw, r, info) { + c.logger.Error("origin manifest response 40x, but hit caches", "host", info.Host, "image", info.Image, "error", err, "response", dumpResponse(resp)) + return + } + c.logger.Error("origin manifest response 40x", "host", info.Host, "image", info.Image, "error", err, "response", dumpResponse(resp)) + errcode.ServeJSON(rw, errcode.ErrorCodeDenied) + return + } + + if resp.StatusCode >= http.StatusBadRequest && resp.StatusCode < http.StatusInternalServerError { + if c.fallbackServeCachedManifest(rw, r, info) { + c.logger.Error("origin manifest response 4xx, but hit caches", "host", info.Host, "image", info.Image, "error", err, "response", dumpResponse(resp)) + return + } + c.logger.Error("origin manifest response 4xx", "host", info.Host, "image", info.Image, "error", err, "response", dumpResponse(resp)) + } else if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusInternalServerError { + if c.fallbackServeCachedManifest(rw, r, info) { + c.logger.Error("origin manifest response 5xx, but hit caches", "host", info.Host, "image", info.Image, "error", err, "response", dumpResponse(resp)) + return + } + c.logger.Error("origin manifest response 5xx", "host", info.Host, "image", info.Image, "error", err, "response", dumpResponse(resp)) + } + + resp.Header.Del("Docker-Ratelimit-Source") + + header := rw.Header() + for k, v := range resp.Header { + key := textproto.CanonicalMIMEHeaderKey(k) + header[key] = v + } + + rw.WriteHeader(resp.StatusCode) + + if r.Method == http.MethodHead { + return + } + + if resp.StatusCode >= http.StatusOK || resp.StatusCode < http.StatusMultipleChoices { + body, err := io.ReadAll(resp.Body) + if err != nil { + c.errorResponse(rw, r, err) + return + } + + _, _, err = c.cache.PutManifestContent(context.Background(), info.Host, info.Image, info.Manifests, body) + if err != nil { + c.errorResponse(rw, r, err) + return + } + rw.Write(body) + } else { + io.Copy(rw, resp.Body) + } +} + +func (c *Gateway) tryFirstServeCachedManifest(rw http.ResponseWriter, r *http.Request, info *PathInfo) bool { + isHash := strings.HasPrefix(info.Manifests, "sha256:") + + if !isHash && c.manifestCacheDuration > 0 { + last, ok := c.manifestCache.Load(manifestCacheKey(info)) + if !ok { + return false + } + + if time.Since(last) > c.manifestCacheDuration { + return false + } + } + + return c.serveCachedManifest(rw, r, info) +} + +func (c *Gateway) fallbackServeCachedManifest(rw http.ResponseWriter, r *http.Request, info *PathInfo) bool { + isHash := strings.HasPrefix(info.Manifests, "sha256:") + if isHash { + return false + } + + return c.serveCachedManifest(rw, r, info) +} + +func (c *Gateway) serveCachedManifest(rw http.ResponseWriter, r *http.Request, info *PathInfo) bool { + ctx := r.Context() + + content, digest, mediaType, err := c.cache.GetManifestContent(ctx, info.Host, info.Image, info.Manifests) + if err != nil { + c.logger.Error("Manifest cache missed", "error", err) + return false + } + + c.logger.Info("Manifest blob cache hit", "digest", digest) + rw.Header().Set("Docker-Content-Digest", digest) + rw.Header().Set("Content-Type", mediaType) + rw.Header().Set("Content-Length", strconv.FormatInt(int64(len(content)), 10)) + if r.Method != http.MethodHead { + rw.Write(content) + } + + if c.manifestCacheDuration > 0 { + c.manifestCache.Store(manifestCacheKey(info), time.Now()) + } + return true +} + +type cacheKey struct { + Host string + Image string + Digest string +} + +func manifestCacheKey(info *PathInfo) cacheKey { + return cacheKey{ + Host: info.Host, + Image: info.Image, + Digest: info.Manifests, + } +} diff --git a/gateway/utils.go b/gateway/utils.go new file mode 100644 index 0000000..7a250bd --- /dev/null +++ b/gateway/utils.go @@ -0,0 +1,167 @@ +package gateway + +import ( + "fmt" + "io" + "net/http" + "strings" +) + +func addPrefixToImageForPagination(oldLink string, host string) string { + linkAndRel := strings.SplitN(oldLink, ";", 2) + if len(linkAndRel) != 2 { + return oldLink + } + linkURL := strings.SplitN(strings.Trim(linkAndRel[0], "<>"), "/v2/", 2) + if len(linkURL) != 2 { + return oldLink + } + mirrorPath := prefix + host + "/" + linkURL[1] + return fmt.Sprintf("<%s>;%s", mirrorPath, linkAndRel[1]) +} + +type PathInfo struct { + Host string + Image string + + TagsList bool + Manifests string + Blobs string +} + +func (p PathInfo) Path() (string, error) { + if p.TagsList { + return prefix + p.Image + "/tags/list", nil + } + if p.Manifests != "" { + return prefix + p.Image + "/manifests/" + p.Manifests, nil + } + if p.Blobs != "" { + return prefix + p.Image + "/blobs/" + p.Blobs, nil + } + return "", fmt.Errorf("unknow kind %#v", p) +} + +func parseOriginPathInfo(path string, defaultRegistry string) (*PathInfo, bool) { + path = strings.TrimPrefix(path, prefix) + i := strings.IndexByte(path, '/') + if i <= 0 { + return nil, false + } + host := path[:i] + tail := path[i+1:] + + var tails = []string{} + var image = "" + + if !isDomainName(host) || !strings.Contains(host, ".") { + // disable while non default registry seted. + if defaultRegistry == "" { + return nil, false + } + // if host is not a domain name, it is a image. + tails = strings.Split(tail, "/") + if len(tails) < 2 { + // should be more then 2 parts. like /manifests/latest + return nil, false + } + image = strings.Join(tails[:len(tails)-2], "/") + if image == "" { + // the url looks like /v2/[busybox]/manifests/latest. + image = host + } else { + // the url looks like /v2/[pytorch/pytorch/...]/[manifests/latest]. + image = host + "/" + image + } + host = defaultRegistry + } else { + + tails = strings.Split(tail, "/") + if len(tails) < 3 { + return nil, false + } + image = strings.Join(tails[:len(tails)-2], "/") + if image == "" { + return nil, false + } + } + + info := &PathInfo{ + Host: host, + Image: image, + } + switch tails[len(tails)-2] { + case "tags": + info.TagsList = tails[len(tails)-1] == "list" + case "manifests": + info.Manifests = tails[len(tails)-1] + case "blobs": + info.Blobs = tails[len(tails)-1] + if len(info.Blobs) != 7+64 { + return nil, false + } + } + return info, true +} + +// isDomainName checks if a string is a presentation-format domain name +// (currently restricted to hostname-compatible "preferred name" LDH labels and +// SRV-like "underscore labels"; see golang.org/issue/12421). +func isDomainName(s string) bool { + // See RFC 1035, RFC 3696. + // Presentation format has dots before every label except the first, and the + // terminal empty label is optional here because we assume fully-qualified + // (absolute) input. We must therefore reserve space for the first and last + // labels' length octets in wire format, where they are necessary and the + // maximum total length is 255. + // So our _effective_ maximum is 253, but 254 is not rejected if the last + // character is a dot. + l := len(s) + if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { + return false + } + + last := byte('.') + nonNumeric := false // true once we've seen a letter or hyphen + partlen := 0 + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': + nonNumeric = true + partlen++ + case '0' <= c && c <= '9': + // fine + partlen++ + case c == '-': + // Byte before dash cannot be dot. + if last == '.' { + return false + } + partlen++ + nonNumeric = true + case c == '.': + // Byte before dot cannot be dot, dash. + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + if last == '-' || partlen > 63 { + return false + } + + return nonNumeric +} + +func dumpResponse(resp *http.Response) string { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return fmt.Sprintf("%d %d %q", resp.StatusCode, resp.ContentLength, string(body)) +} diff --git a/gateway/utils_test.go b/gateway/utils_test.go new file mode 100644 index 0000000..acddf9d --- /dev/null +++ b/gateway/utils_test.go @@ -0,0 +1,126 @@ +package gateway + +import ( + "reflect" + "testing" +) + +func TestParseOriginPathInfo(t *testing.T) { + + testDefaultRegistry := "non_docker.io" + + type args struct { + path string + } + tests := []struct { + name string + args args + defaultRegistry string + want *PathInfo + wantOk bool + }{ + { + args: args{ + path: "/v2/busybox/manifests/1", + }, + defaultRegistry: testDefaultRegistry, + want: &PathInfo{ + Host: testDefaultRegistry, + Image: "busybox", + Manifests: "1", + }, + wantOk: true, + }, + { + args: args{ + path: "/v2/pytorch/pytorch/manifests/1", + }, + defaultRegistry: testDefaultRegistry, + want: &PathInfo{ + Host: testDefaultRegistry, + Image: "pytorch/pytorch", + Manifests: "1", + }, + wantOk: true, + }, + { + args: args{ + path: "/v2/v2/manifests/latest", + }, + defaultRegistry: testDefaultRegistry, + want: &PathInfo{ + Host: testDefaultRegistry, + Image: "v2", + Manifests: "latest", + }, + wantOk: true, + }, + { + args: args{ + path: "/v2/docker.io/busybox/manifests/1", + }, + want: &PathInfo{ + Host: "docker.io", + Image: "busybox", + Manifests: "1", + }, + wantOk: true, + }, + { + args: args{ + path: "/v2/docker.io/library/busybox/manifests/1", + }, + want: &PathInfo{ + Host: "docker.io", + Image: "library/busybox", + Manifests: "1", + }, + wantOk: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, gotOk := parseOriginPathInfo(tt.args.path, tt.defaultRegistry) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseOriginPathInfo() got = %v, want %v", got, tt.want) + } + if gotOk != tt.wantOk { + t.Errorf("ParseOriginPathInfo() gotOk = %v, want %v", gotOk, tt.wantOk) + } + }) + } +} + +func Test_addPrefixToImageForPagination(t *testing.T) { + type args struct { + oldLink string + host string + } + tests := []struct { + name string + args args + want string + }{ + { + args: args{ + oldLink: "; ref=other", + host: "prefix", + }, + want: "; ref=other", + }, + { + args: args{ + oldLink: "; ref=other", + host: "prefix", + }, + want: "; ref=other", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := addPrefixToImageForPagination(tt.args.oldLink, tt.args.host); got != tt.want { + t.Errorf("addPrefixToImageForPagination() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/token/encoding.go b/token/encoding.go index 77d42a9..f61baa2 100644 --- a/token/encoding.go +++ b/token/encoding.go @@ -35,14 +35,23 @@ type Token struct { Account string `json:"account,omitempty"` Image string `json:"image,omitempty"` + IP string `json:"ip,omitempty"` + Attribute `json:"attribute,omitempty"` } type Attribute struct { - NoRateLimit bool `json:"no_rate_limit,omitempty"` + NoRateLimit bool `json:"no_rate_limit,omitempty"` + RateLimitPerSecond uint64 `json:"rate_limit_per_second,omitempty"` + NoAllowlist bool `json:"no_allowlist,omitempty"` NoBlock bool `json:"no_block,omitempty"` AllowTagsList bool `json:"allow_tags_list,omitempty"` + + BlobsURL string `json:"blobs_url,omitempty"` + + Block bool `json:"block,omitempty"` + BlockMessage string `json:"block_message,omitempty"` } func (p *Encoder) Encode(t Token) (code string, err error) { diff --git a/token/generator.go b/token/generator.go index 562764b..88f76e7 100644 --- a/token/generator.go +++ b/token/generator.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "encoding/json" "log/slog" + "net" "net/http" "net/url" "strings" @@ -13,14 +14,14 @@ import ( ) type Generator struct { - authFunc func(r *http.Request, userinfo *url.Userinfo) (Attribute, bool) + authFunc func(r *http.Request, userinfo *url.Userinfo, t *Token) (Attribute, bool) logger *slog.Logger tokenEncoder *Encoder } func NewGenerator( tokenEncoder *Encoder, - authFunc func(r *http.Request, userinfo *url.Userinfo) (Attribute, bool), + authFunc func(r *http.Request, userinfo *url.Userinfo, t *Token) (Attribute, bool), logger *slog.Logger, ) *Generator { return &Generator{ @@ -63,6 +64,14 @@ func (g *Generator) ServeHTTP(rw http.ResponseWriter, r *http.Request) { }) } +func getIP(str string) string { + host, _, err := net.SplitHostPort(str) + if err == nil && host != "" { + return host + } + return str +} + func (g *Generator) getToken(r *http.Request) (*Token, error) { query := r.URL.Query() account := query.Get("account") @@ -73,6 +82,7 @@ func (g *Generator) getToken(r *http.Request) (*Token, error) { Service: service, Scope: scope, Account: account, + IP: getIP(r.RemoteAddr), } if scope != "" { @@ -95,7 +105,7 @@ func (g *Generator) getToken(r *http.Request) (*Token, error) { authorization := r.Header.Get("Authorization") if authorization == "" { - attribute, login := g.authFunc(r, nil) + attribute, login := g.authFunc(r, nil, &t) if !login { return nil, errcode.ErrorCodeDenied } @@ -125,7 +135,7 @@ func (g *Generator) getToken(r *http.Request) (*Token, error) { u = url.User(user) } - attribute, login := g.authFunc(r, u) + attribute, login := g.authFunc(r, u, &t) if !login { g.logger.Error("Login failed user and password", "user", u.Username()) return nil, errcode.ErrorCodeDenied