From cd897d8c6a40e743921a927dffc44def485c8755 Mon Sep 17 00:00:00 2001 From: Shiming Zhang Date: Thu, 16 Jan 2025 18:04:16 +0800 Subject: [PATCH] Add queue --- cmd/crproxy/cluster/auth/auth.go | 2 +- cmd/crproxy/cluster/cluster.go | 5 + cmd/crproxy/cluster/queue/queue.go | 114 +++++ cmd/crproxy/cluster/runner/runner.go | 144 +++++++ go.mod | 1 + go.sum | 2 + manager/manager.go | 2 +- queue/client/client.go | 400 ++++++++++++++++++ queue/client/error.go | 6 + queue/controller/error.go | 6 + queue/controller/message.go | 597 +++++++++++++++++++++++++++ queue/dao/db.go | 37 ++ queue/dao/messgae.go | 295 +++++++++++++ queue/model/message.go | 49 +++ queue/model/utils.go | 45 ++ queue/queue.go | 99 +++++ queue/service/message.go | 147 +++++++ runner/runner.go | 243 +++++++++++ sync/sync.go | 22 +- 19 files changed, 2212 insertions(+), 4 deletions(-) create mode 100644 cmd/crproxy/cluster/queue/queue.go create mode 100644 cmd/crproxy/cluster/runner/runner.go create mode 100644 queue/client/client.go create mode 100644 queue/client/error.go create mode 100644 queue/controller/error.go create mode 100644 queue/controller/message.go create mode 100644 queue/dao/db.go create mode 100644 queue/dao/messgae.go create mode 100644 queue/model/message.go create mode 100644 queue/model/utils.go create mode 100644 queue/queue.go create mode 100644 queue/service/message.go create mode 100644 runner/runner.go diff --git a/cmd/crproxy/cluster/auth/auth.go b/cmd/crproxy/cluster/auth/auth.go index 484b6a4..03c03da 100644 --- a/cmd/crproxy/cluster/auth/auth.go +++ b/cmd/crproxy/cluster/auth/auth.go @@ -211,7 +211,7 @@ func runE(ctx context.Context, flags *flagpole) error { } handler = handlers.CORS( - handlers.AllowedMethods([]string{http.MethodHead, http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete}), + handlers.AllowedMethods([]string{http.MethodHead, http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodPut, http.MethodDelete}), handlers.AllowedHeaders([]string{"Authorization", "Accept", "Content-Type", "Origin"}), handlers.AllowedOrigins([]string{"*"}), )(handler) diff --git a/cmd/crproxy/cluster/cluster.go b/cmd/crproxy/cluster/cluster.go index daeab36..5879241 100644 --- a/cmd/crproxy/cluster/cluster.go +++ b/cmd/crproxy/cluster/cluster.go @@ -6,6 +6,8 @@ import ( "github.com/daocloud/crproxy/cmd/crproxy/cluster/agent" "github.com/daocloud/crproxy/cmd/crproxy/cluster/auth" "github.com/daocloud/crproxy/cmd/crproxy/cluster/gateway" + "github.com/daocloud/crproxy/cmd/crproxy/cluster/queue" + "github.com/daocloud/crproxy/cmd/crproxy/cluster/runner" ) func NewCommand() *cobra.Command { @@ -20,5 +22,8 @@ func NewCommand() *cobra.Command { cmd.AddCommand(agent.NewCommand()) cmd.AddCommand(gateway.NewCommand()) cmd.AddCommand(auth.NewCommand()) + + cmd.AddCommand(queue.NewCommand()) + cmd.AddCommand(runner.NewCommand()) return cmd } diff --git a/cmd/crproxy/cluster/queue/queue.go b/cmd/crproxy/cluster/queue/queue.go new file mode 100644 index 0000000..cc3d265 --- /dev/null +++ b/cmd/crproxy/cluster/queue/queue.go @@ -0,0 +1,114 @@ +package queue + +import ( + "context" + "database/sql" + "fmt" + "log/slog" + "net/http" + "os" + + _ "github.com/go-sql-driver/mysql" + + "github.com/daocloud/crproxy/internal/server" + "github.com/daocloud/crproxy/queue" + "github.com/emicklei/go-restful/v3" + "github.com/gorilla/handlers" + "github.com/spf13/cobra" +) + +type flagpole struct { + Behind bool + Address string + AcmeHosts []string + AcmeCacheDir string + CertFile string + PrivateKeyFile string + + TokenPublicKeyFile string + + SimpleAuthUserpass map[string]string + + AdminToken string + + DBURL string +} + +func NewCommand() *cobra.Command { + flags := &flagpole{ + Address: ":18010", + } + + cmd := &cobra.Command{ + Use: "queue", + Short: "Queue", + RunE: func(cmd *cobra.Command, args []string) error { + return runE(cmd.Context(), flags) + }, + } + + cmd.Flags().BoolVar(&flags.Behind, "behind", flags.Behind, "Behind") + cmd.Flags().StringVar(&flags.Address, "address", flags.Address, "Address") + cmd.Flags().StringSliceVar(&flags.AcmeHosts, "acme-hosts", flags.AcmeHosts, "Acme hosts") + cmd.Flags().StringVar(&flags.AcmeCacheDir, "acme-cache-dir", flags.AcmeCacheDir, "Acme cache dir") + cmd.Flags().StringVar(&flags.CertFile, "cert-file", flags.CertFile, "Cert file") + cmd.Flags().StringVar(&flags.PrivateKeyFile, "private-key-file", flags.PrivateKeyFile, "Private key file") + + cmd.Flags().StringVar(&flags.TokenPublicKeyFile, "token-public-key-file", "", "public key file") + + cmd.Flags().StringVar(&flags.AdminToken, "admin-token", flags.AdminToken, "Admin token") + + cmd.Flags().StringVar(&flags.DBURL, "db-url", flags.DBURL, "Database URL") + + return cmd +} + +func runE(ctx context.Context, flags *flagpole) error { + logger := slog.New(slog.NewJSONHandler(os.Stderr, nil)) + + container := restful.NewContainer() + + var mgr *queue.QueueManager + if flags.DBURL != "" { + dburl := flags.DBURL + db, err := sql.Open("mysql", dburl) + if err != nil { + return fmt.Errorf("failed to connect to database: %w", err) + } + defer db.Close() + + if err = db.Ping(); err != nil { + return fmt.Errorf("failed to ping database: %w", err) + } + + logger.Info("Connected to DB") + + mgr = queue.NewQueueManager(flags.AdminToken, db) + + mgr.Register(container) + + mgr.InitTable(ctx) + + mgr.Schedule(ctx, logger) + } + + var handler http.Handler = container + + handler = handlers.LoggingHandler(os.Stderr, handler) + + if flags.Behind { + handler = handlers.ProxyHeaders(handler) + } + + handler = handlers.CORS( + handlers.AllowedMethods([]string{http.MethodHead, http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodPut, http.MethodDelete}), + handlers.AllowedHeaders([]string{"Authorization", "Accept", "Content-Type", "Origin"}), + handlers.AllowedOrigins([]string{"*"}), + )(handler) + + err := server.Run(ctx, flags.Address, handler, flags.AcmeHosts, flags.AcmeCacheDir, flags.CertFile, flags.PrivateKeyFile) + if err != nil { + return fmt.Errorf("failed to run server: %w", err) + } + return nil +} diff --git a/cmd/crproxy/cluster/runner/runner.go b/cmd/crproxy/cluster/runner/runner.go new file mode 100644 index 0000000..c7a1f6e --- /dev/null +++ b/cmd/crproxy/cluster/runner/runner.go @@ -0,0 +1,144 @@ +package runner + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/http" + "os" + "time" + + "github.com/daocloud/crproxy/cache" + "github.com/daocloud/crproxy/runner" + "github.com/daocloud/crproxy/storage" + csync "github.com/daocloud/crproxy/sync" + "github.com/daocloud/crproxy/transport" + "github.com/docker/distribution/manifest/manifestlist" + "github.com/spf13/cobra" +) + +type flagpole struct { + QueueURL string + + AdminToken string + + StorageURL []string + Deep bool + Quick bool + Platform []string + Userpass []string + + Duration time.Duration +} + +func NewCommand() *cobra.Command { + flags := &flagpole{ + Platform: []string{ + "linux/amd64", + "linux/arm64", + }, + } + + cmd := &cobra.Command{ + Use: "runner", + Short: "Runner", + RunE: func(cmd *cobra.Command, args []string) error { + return runE(cmd.Context(), flags) + }, + } + + cmd.Flags().StringVar(&flags.AdminToken, "admin-token", flags.AdminToken, "Admin token") + + cmd.Flags().StringVar(&flags.QueueURL, "queue-url", flags.QueueURL, "Queue URL") + + cmd.Flags().StringArrayVar(&flags.StorageURL, "storage-url", flags.StorageURL, "Storage driver url") + cmd.Flags().BoolVar(&flags.Deep, "deep", flags.Deep, "Deep sync with blob") + cmd.Flags().BoolVar(&flags.Quick, "quick", flags.Quick, "Quick sync with tags") + cmd.Flags().StringSliceVar(&flags.Platform, "platform", flags.Platform, "Platform") + cmd.Flags().StringArrayVarP(&flags.Userpass, "user", "u", flags.Userpass, "host and username and password -u user:pwd@host") + + cmd.Flags().DurationVar(&flags.Duration, "duration", flags.Duration, "Duration of the running") + + return cmd +} + +func runE(ctx context.Context, flags *flagpole) error { + logger := slog.New(slog.NewJSONHandler(os.Stderr, nil)) + + opts := []csync.Option{} + + var caches []*cache.Cache + for _, s := range flags.StorageURL { + sd, err := storage.NewStorage(s) + if err != nil { + return fmt.Errorf("create storage driver failed: %w", err) + } + + cache, err := cache.NewCache(cache.WithStorageDriver(sd)) + if err != nil { + return fmt.Errorf("create cache failed: %w", err) + } + + caches = append(caches, cache) + } + + transportOpts := []transport.Option{ + transport.WithLogger(logger), + } + + if len(flags.Userpass) != 0 { + transportOpts = append(transportOpts, transport.WithUserAndPass(flags.Userpass)) + } + + tp, err := transport.NewTransport(transportOpts...) + if err != nil { + return fmt.Errorf("create transport failed: %w", err) + } + + opts = append(opts, + csync.WithCaches(caches...), + csync.WithDeep(flags.Deep), + csync.WithQuick(flags.Quick), + csync.WithTransport(tp), + csync.WithLogger(logger), + csync.WithFilterPlatform(filterPlatform(flags.Platform)), + ) + + sm, err := csync.NewSyncManager(opts...) + if err != nil { + return fmt.Errorf("create sync manager failed: %w", err) + } + + runner, err := runner.NewRunner(http.DefaultClient, flags.QueueURL, flags.AdminToken, sm) + if err != nil { + return err + } + + if flags.Duration > 0 { + ctx, _ = context.WithTimeout(ctx, flags.Duration) + } + + err = runner.Run(ctx, logger) + if err != nil { + if !errors.Is(err, context.DeadlineExceeded) { + return err + } + } + return nil +} + +func filterPlatform(ps []string) func(pf manifestlist.PlatformSpec) bool { + platforms := map[string]struct{}{} + for _, p := range ps { + platforms[p] = struct{}{} + } + return func(pf manifestlist.PlatformSpec) bool { + p := fmt.Sprintf("%s/%s", pf.OS, pf.Architecture) + + if _, ok := platforms[p]; ok { + return true + } + return false + } +} diff --git a/go.mod b/go.mod index 9ac6e52..b8d1ff9 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/go-openapi/spec v0.20.9 github.com/go-sql-driver/mysql v1.8.1 github.com/google/go-containerregistry v0.20.2 + github.com/google/uuid v1.6.0 github.com/gorilla/handlers v1.5.2 github.com/huaweicloud/huaweicloud-sdk-go-obs v3.24.6+incompatible github.com/opencontainers/go-digest v1.0.0 diff --git a/go.sum b/go.sum index 31a1a46..2f2807f 100644 --- a/go.sum +++ b/go.sum @@ -64,6 +64,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/go-containerregistry v0.20.2 h1:B1wPJ1SN/S7pB+ZAimcciVD+r+yV/l/DSArMxlbwseo= github.com/google/go-containerregistry v0.20.2/go.mod h1:z38EKdKh4h7IP2gSfUUqEvalZBqs6AoLeWfUy34nQC8= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= diff --git a/manager/manager.go b/manager/manager.go index c8522e4..91dd8ad 100644 --- a/manager/manager.go +++ b/manager/manager.go @@ -91,7 +91,7 @@ func (m *Manager) Register(container *restful.Container) { PostBuildSwaggerObjectHandler: func(s *spec.Swagger) { s.Info = &spec.Info{} s.Info.Title = "CRProxy Manager" - s.Schemes = []string{"http", "https"} + s.Schemes = []string{"https", "http"} s.SecurityDefinitions = spec.SecurityDefinitions{ "BearerHeader": { SecuritySchemeProps: spec.SecuritySchemeProps{ diff --git a/queue/client/client.go b/queue/client/client.go new file mode 100644 index 0000000..09fb697 --- /dev/null +++ b/queue/client/client.go @@ -0,0 +1,400 @@ +package client + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "time" + + "github.com/daocloud/crproxy/queue/model" +) + +type MessageRequest struct { + Content string `json:"content"` + Priority int `json:"priority"` +} + +type MessageResponse struct { + MessageID int64 `json:"id"` + Content string `json:"content"` + Priority int `json:"priority"` + Status model.MessageStatus `json:"status"` + Data model.MessageAttr `json:"data,omitempty"` + LastHeartbeat time.Time `json:"last_heartbeat"` +} + +type ConsumeRequest struct { + Lease string `json:"lease"` +} + +type HeartbeatRequest struct { + Data model.MessageAttr `json:"data"` + Lease string `json:"lease"` +} + +type CompletedRequest struct { + Lease string `json:"lease"` +} + +type FailedRequest struct { + Lease string `json:"lease"` + Data model.MessageAttr `json:"data"` +} + +type CancelRequest struct { + Lease string `json:"lease"` +} + +type MessageClient struct { + httpClient *http.Client + baseURL string + adminToken string +} + +func NewMessageClient(httpClient *http.Client, baseURL string, adminToken string) *MessageClient { + return &MessageClient{ + httpClient: httpClient, + baseURL: baseURL, + adminToken: adminToken, + } +} + +func (c *MessageClient) Create(ctx context.Context, content string, priority int) (MessageResponse, error) { + messageRequest := MessageRequest{Content: content, Priority: priority} + body, err := json.Marshal(messageRequest) + if err != nil { + return MessageResponse{}, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.baseURL+"/messages", bytes.NewBuffer(body)) + if err != nil { + return MessageResponse{}, err + } + req.Header.Set("Content-Type", "application/json") + if c.adminToken != "" { + req.Header.Set("Authorization", "Bearer "+c.adminToken) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return MessageResponse{}, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + return MessageResponse{}, handleErrorResponse(resp) + } + + var messageResponse MessageResponse + if err := json.NewDecoder(resp.Body).Decode(&messageResponse); err != nil { + return MessageResponse{}, err + } + + return messageResponse, nil +} + +func (c *MessageClient) List(ctx context.Context) ([]MessageResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/messages", nil) + if err != nil { + return nil, err + } + if c.adminToken != "" { + req.Header.Set("Authorization", "Bearer "+c.adminToken) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(resp) + } + + var messages []MessageResponse + if err := json.NewDecoder(resp.Body).Decode(&messages); err != nil { + return nil, err + } + + return messages, nil +} + +func (c *MessageClient) WatchList(ctx context.Context) (chan MessageResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/messages?watch=1", nil) + if err != nil { + return nil, err + } + req.Header.Set("watch", "1") + if c.adminToken != "" { + req.Header.Set("Authorization", "Bearer "+c.adminToken) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + return nil, handleErrorResponse(resp) + } + + if resp.Header.Get("Content-Type") != "text/event-stream" { + defer resp.Body.Close() + return nil, handleErrorResponse(resp) + } + + messageChannel := make(chan MessageResponse) + + go func() { + defer resp.Body.Close() + defer close(messageChannel) + decoder := json.NewDecoder(resp.Body) + for { + var message MessageResponse + err := decoder.Decode(&message) + if err != nil { + return + } + + messageChannel <- message + } + }() + + return messageChannel, nil +} + +func (c *MessageClient) Get(ctx context.Context, messageID int64) (MessageResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/messages/"+strconv.FormatInt(messageID, 10), nil) + if err != nil { + return MessageResponse{}, err + } + if c.adminToken != "" { + req.Header.Set("Authorization", "Bearer "+c.adminToken) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return MessageResponse{}, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return MessageResponse{}, handleErrorResponse(resp) + } + + var messageResponse MessageResponse + if err := json.NewDecoder(resp.Body).Decode(&messageResponse); err != nil { + return MessageResponse{}, err + } + + return messageResponse, nil +} + +func (c *MessageClient) Watch(ctx context.Context, messageID int64) (chan MessageResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/messages/"+strconv.FormatInt(messageID, 10)+"?watch=1", nil) + if err != nil { + return nil, err + } + if c.adminToken != "" { + req.Header.Set("Authorization", "Bearer "+c.adminToken) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + return nil, handleErrorResponse(resp) + } + + if resp.Header.Get("Content-Type") != "text/event-stream" { + defer resp.Body.Close() + return nil, handleErrorResponse(resp) + } + + messageChannel := make(chan MessageResponse) + + go func() { + defer resp.Body.Close() + defer close(messageChannel) + decoder := json.NewDecoder(resp.Body) + for { + var message MessageResponse + err := decoder.Decode(&message) + if err != nil { + return + } + messageChannel <- message + } + }() + + return messageChannel, nil +} + +func (c *MessageClient) Consume(ctx context.Context, messageID int64, lease string) (MessageResponse, error) { + completedRequest := CompletedRequest{Lease: lease} + body, err := json.Marshal(completedRequest) + if err != nil { + return MessageResponse{}, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/messages/"+strconv.FormatInt(messageID, 10)+"/consume", bytes.NewBuffer(body)) + if err != nil { + return MessageResponse{}, err + } + req.Header.Set("Content-Type", "application/json") + if c.adminToken != "" { + req.Header.Set("Authorization", "Bearer "+c.adminToken) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return MessageResponse{}, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return MessageResponse{}, handleErrorResponse(resp) + } + + var messageResponse MessageResponse + if err := json.NewDecoder(resp.Body).Decode(&messageResponse); err != nil { + return MessageResponse{}, err + } + + return messageResponse, nil +} + +func (c *MessageClient) Heartbeat(ctx context.Context, messageID int64, heartbeatRequest HeartbeatRequest) error { + body, err := json.Marshal(heartbeatRequest) + if err != nil { + return err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, c.baseURL+"/messages/"+strconv.FormatInt(messageID, 10)+"/heartbeat", bytes.NewBuffer(body)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + if c.adminToken != "" { + req.Header.Set("Authorization", "Bearer "+c.adminToken) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + return handleErrorResponse(resp) + } + + return nil +} + +func (c *MessageClient) Complete(ctx context.Context, messageID int64, completedRequest CompletedRequest) error { + body, err := json.Marshal(completedRequest) + if err != nil { + return err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, c.baseURL+"/messages/"+strconv.FormatInt(messageID, 10)+"/complete", bytes.NewBuffer(body)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + if c.adminToken != "" { + req.Header.Set("Authorization", "Bearer "+c.adminToken) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + return handleErrorResponse(resp) + } + + return nil +} + +func (c *MessageClient) Failed(ctx context.Context, messageID int64, failedRequest FailedRequest) error { + body, err := json.Marshal(failedRequest) + if err != nil { + return err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, c.baseURL+"/messages/"+strconv.FormatInt(messageID, 10)+"/failed", bytes.NewBuffer(body)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + if c.adminToken != "" { + req.Header.Set("Authorization", "Bearer "+c.adminToken) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + return handleErrorResponse(resp) + } + + return nil +} + +func (c *MessageClient) Cancel(ctx context.Context, messageID int64, cancelRequest CancelRequest) error { + body, err := json.Marshal(cancelRequest) + if err != nil { + return err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, c.baseURL+"/messages/"+strconv.FormatInt(messageID, 10)+"/cancel", bytes.NewBuffer(body)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + if c.adminToken != "" { + req.Header.Set("Authorization", "Bearer "+c.adminToken) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + return handleErrorResponse(resp) + } + + return nil +} + +func handleErrorResponse(resp *http.Response) error { + body, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + var errResponse Error + + err = json.Unmarshal(body, &errResponse) + if err != nil { + return fmt.Errorf("Error %s", body) + } + return fmt.Errorf("error: %s, message: %s", errResponse.Code, errResponse.Message) +} diff --git a/queue/client/error.go b/queue/client/error.go new file mode 100644 index 0000000..e7950ec --- /dev/null +++ b/queue/client/error.go @@ -0,0 +1,6 @@ +package client + +type Error struct { + Code string `json:"code"` + Message string `json:"message"` +} diff --git a/queue/controller/error.go b/queue/controller/error.go new file mode 100644 index 0000000..daf5469 --- /dev/null +++ b/queue/controller/error.go @@ -0,0 +1,6 @@ +package controller + +type Error struct { + Code string `json:"code"` + Message string `json:"message"` +} diff --git a/queue/controller/message.go b/queue/controller/message.go new file mode 100644 index 0000000..e62b694 --- /dev/null +++ b/queue/controller/message.go @@ -0,0 +1,597 @@ +package controller + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "log/slog" + "net/http" + "strconv" + "sync" + "time" + + "github.com/daocloud/crproxy/queue/model" + "github.com/daocloud/crproxy/queue/service" + "github.com/emicklei/go-restful/v3" +) + +type MessageRequest struct { + Content string `json:"content"` + Priority int `json:"priority"` +} + +type MessageResponse struct { + MessageID int64 `json:"id"` + Content string `json:"content"` + Priority int `json:"priority"` + Status model.MessageStatus `json:"status"` + Data model.MessageAttr `json:"data,omitempty"` + LastHeartbeat time.Time `json:"last_heartbeat"` +} + +type ConsumeRequest struct { + Lease string `json:"lease"` +} + +type HeartbeatRequest struct { + Data model.MessageAttr `json:"data"` + Lease string `json:"lease"` +} + +type CompletedRequest struct { + Lease string `json:"lease"` +} + +type FailedRequest struct { + Lease string `json:"lease"` + Data model.MessageAttr `json:"data"` +} + +type CancelRequest struct { + Lease string `json:"lease"` +} + +type MessageController struct { + messageService *service.MessageService + + watchChannelsMut sync.Mutex + watchChannels map[int64][]chan MessageResponse + + watchListChannelsMut sync.Mutex + watchListChannels []chan MessageResponse +} + +func (mc *MessageController) newWatchChannel(messageID int64) chan MessageResponse { + ch := make(chan MessageResponse, 1) + mc.watchChannelsMut.Lock() + defer mc.watchChannelsMut.Unlock() + mc.watchChannels[messageID] = append(mc.watchChannels[messageID], ch) + return ch +} + +func (mc *MessageController) cancelWatchChannel(messageID int64, ch chan MessageResponse) { + mc.watchChannelsMut.Lock() + defer mc.watchChannelsMut.Unlock() + channels := mc.watchChannels[messageID] + for i, channel := range channels { + if channel == ch { + mc.watchChannels[messageID] = append(channels[:i], channels[i+1:]...) + break + } + } + + if len(mc.watchChannels[messageID]) == 0 { + delete(mc.watchChannels, messageID) + } +} + +func (mc *MessageController) updateWatchChannel(messageID int64, mr MessageResponse) { + mc.watchChannelsMut.Lock() + defer mc.watchChannelsMut.Unlock() + for _, ch := range mc.watchChannels[messageID] { + select { + case ch <- mr: + default: + } + } +} + +func (mc *MessageController) newWatchListChannel() chan MessageResponse { + ch := make(chan MessageResponse, 1) + mc.watchListChannelsMut.Lock() + defer mc.watchListChannelsMut.Unlock() + mc.watchListChannels = append(mc.watchListChannels, ch) + return ch +} + +func (mc *MessageController) cancelWatchListChannel(ch chan MessageResponse) { + mc.watchListChannelsMut.Lock() + defer mc.watchListChannelsMut.Unlock() + for i, channel := range mc.watchListChannels { + if channel == ch { + mc.watchListChannels = append(mc.watchListChannels[:i], mc.watchListChannels[i+1:]...) + break + } + } +} + +func (mc *MessageController) updateWatchListChannels(mr MessageResponse) { + mc.watchListChannelsMut.Lock() + defer mc.watchListChannelsMut.Unlock() + for _, ch := range mc.watchListChannels { + select { + case ch <- mr: + default: + } + } +} + +func NewMessageController(messageService *service.MessageService) *MessageController { + return &MessageController{messageService: messageService, watchChannels: map[int64][]chan MessageResponse{}} +} + +func (mc *MessageController) RegisterRoutes(ws *restful.WebService) { + ws.Route(ws.PUT("/messages").To(mc.Create). + Doc("Try create a new message."). + Operation("createMessage"). + Produces(restful.MIME_JSON). + Consumes(restful.MIME_JSON). + Reads(MessageRequest{}). + Writes(MessageResponse{}). + Returns(http.StatusCreated, "Message created successfully.", MessageResponse{}). + Returns(http.StatusBadRequest, "Invalid request format.", Error{})) + + ws.Route(ws.GET("/messages").To(mc.List). + Doc("List all messages."). + Operation("listMessages"). + Param(ws.QueryParameter("watch", "Watch the message for updates").DataType("boolean")). + Produces(restful.MIME_JSON). + Writes([]MessageResponse{}). + Returns(http.StatusOK, "Messages retrieved successfully.", []MessageResponse{}). + Returns(http.StatusInternalServerError, "Failed to retrieve messages.", Error{}). + Returns(http.StatusNoContent, "No messages available.", Error{})) + + ws.Route(ws.GET("/messages/{message_id}").To(mc.Get). + Doc("Retrieve a message by ID."). + Operation("getMessage"). + Param(ws.PathParameter("message_id", "message ID").DataType("integer")). + Param(ws.QueryParameter("watch", "Watch the message for updates").DataType("boolean")). + Produces(restful.MIME_JSON). + Writes(MessageResponse{}). + Returns(http.StatusOK, "Message found.", MessageResponse{}). + Returns(http.StatusNotFound, "Message not found.", Error{}). + Returns(http.StatusBadRequest, "Invalid request format.", Error{})) + + ws.Route(ws.POST("/messages/{message_id}/consume").To(mc.Consume). + Doc("Consume a message by ID."). + Operation("consume"). + Produces(restful.MIME_JSON). + Consumes(restful.MIME_JSON). + Param(ws.PathParameter("message_id", "message ID").DataType("integer")). + Reads(CompletedRequest{}). + Writes(MessageResponse{}). + Returns(http.StatusOK, "Message consumed successfully.", MessageResponse{}). + Returns(http.StatusNotFound, "Message not found.", Error{}). + Returns(http.StatusBadRequest, "Invalid request format.", Error{})) + + ws.Route(ws.PATCH("/messages/{message_id}/heartbeat").To(mc.Heartbeat). + Doc("Set heartbeat for a message by ID."). + Operation("heartbeat"). + Produces(restful.MIME_JSON). + Consumes(restful.MIME_JSON). + Param(ws.PathParameter("message_id", "message ID").DataType("integer")). + Reads(HeartbeatRequest{}). + Writes(Error{}). + Returns(http.StatusNoContent, "Heartbeat updated successfully.", nil). + Returns(http.StatusNotFound, "Message not found.", Error{}). + Returns(http.StatusBadRequest, "Invalid request format.", Error{})) + + ws.Route(ws.PATCH("/messages/{message_id}/complete").To(mc.Completed). + Doc("Set a message as completed by ID."). + Operation("setCompleted"). + Produces(restful.MIME_JSON). + Consumes(restful.MIME_JSON). + Param(ws.PathParameter("message_id", "message ID").DataType("integer")). + Reads(CompletedRequest{}). + Writes(Error{}). + Returns(http.StatusNoContent, "Message failed successfully.", nil). + Returns(http.StatusBadRequest, "Invalid request format.", Error{})) + + ws.Route(ws.PATCH("/messages/{message_id}/failed").To(mc.Failed). + Doc("Set a message as failed by ID."). + Operation("setFailed"). + Produces(restful.MIME_JSON). + Consumes(restful.MIME_JSON). + Param(ws.PathParameter("message_id", "message ID").DataType("integer")). + Reads(FailedRequest{}). + Writes(Error{}). + Returns(http.StatusNoContent, "Message failed successfully.", nil). + Returns(http.StatusNotFound, "Message not found.", Error{})) + + ws.Route(ws.PATCH("/messages/{message_id}/cancel").To(mc.Cancel). + Doc("Cancel a message by ID."). + Operation("cancel"). + Produces(restful.MIME_JSON). + Consumes(restful.MIME_JSON). + Param(ws.PathParameter("message_id", "message ID").DataType("integer")). + Reads(CancelRequest{}). + Writes(Error{}). + Returns(http.StatusNoContent, "Message canceled successfully.", nil). + Returns(http.StatusNotFound, "Message not found.", Error{}). + Returns(http.StatusBadRequest, "Invalid request format.", Error{})) +} + +func (mc *MessageController) Schedule(ctx context.Context, logger *slog.Logger) { + ticker := time.NewTicker(1 * time.Minute) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + staleList, err := mc.messageService.GetStale(ctx, time.Now().Add(-time.Minute)) + if err != nil { + logger.Error("ReleaseStale", "error", err) + } else { + for _, item := range staleList { + err := mc.messageService.ResetToPending(ctx, item.MessageID) + if err != nil { + logger.Error("ResetToPending", "error", err) + } else { + data := MessageResponse{MessageID: item.MessageID, Content: item.Content, Priority: item.Priority, Status: item.Status, Data: item.Data, LastHeartbeat: item.LastHeartbeat} + mc.updateWatchListChannels(data) + mc.updateWatchChannel(item.MessageID, data) + } + } + } + + cleanList, err := mc.messageService.GetCompletedAndFailed(ctx, time.Now().Add(-time.Hour)) + if err != nil { + logger.Error("DeleteCompletedAndFailed", "error", err) + } else { + for _, item := range cleanList { + err := mc.messageService.DeleteByID(ctx, item.MessageID) + if err != nil { + logger.Error("DeleteByID", "error", err) + } + } + } + } + } +} + +func (mc *MessageController) Create(req *restful.Request, resp *restful.Response) { + var messageRequest MessageRequest + if err := req.ReadEntity(&messageRequest); err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "MessageRequestError", Message: "Failed to read message request: " + err.Error()}) + return + } + + message, err := mc.messageService.GetByContent(req.Request.Context(), messageRequest.Content) + if err == nil { + data := MessageResponse{ + MessageID: message.MessageID, + Content: message.Content, + Priority: message.Priority, + Status: message.Status, + Data: message.Data, + LastHeartbeat: message.LastHeartbeat, + } + + if message.Status == model.StatusPending && messageRequest.Priority > message.Priority { + if err := mc.messageService.UpdatePriorityByID(req.Request.Context(), message.MessageID, messageRequest.Priority); err != nil { + resp.WriteHeaderAndEntity(http.StatusInternalServerError, Error{Code: "MessageUpdateError", Message: "Failed to update message priority: " + err.Error()}) + return + } + data.Priority = messageRequest.Priority + } + + mc.updateWatchListChannels(data) + resp.WriteHeaderAndEntity(http.StatusOK, data) + return + } + + if !errors.Is(err, sql.ErrNoRows) { + resp.WriteHeaderAndEntity(http.StatusInternalServerError, Error{Code: "MessageRetrievalError", Message: "Failed to retrieve message: " + err.Error()}) + return + } + + newMessage := model.Message{ + Content: messageRequest.Content, + Priority: messageRequest.Priority, + } + messageID, err := mc.messageService.Create(req.Request.Context(), newMessage) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusInternalServerError, Error{Code: "MessageCreationError", Message: "Failed to create message: " + err.Error()}) + return + } + + data := MessageResponse{ + MessageID: messageID, + Content: messageRequest.Content, + Priority: messageRequest.Priority, + } + + mc.updateWatchListChannels(data) + resp.WriteHeaderAndEntity(http.StatusCreated, data) +} + +func (mc *MessageController) List(req *restful.Request, resp *restful.Response) { + messages, err := mc.messageService.List(req.Request.Context()) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusInternalServerError, Error{Code: "MessageListError", Message: "Failed to retrieve messages: " + err.Error()}) + return + } + var messageResponses = make([]MessageResponse, 0, len(messages)) + for _, message := range messages { + messageResponses = append(messageResponses, MessageResponse{ + MessageID: message.MessageID, + Content: message.Content, + Priority: message.Priority, + Status: message.Status, + Data: message.Data, + LastHeartbeat: message.LastHeartbeat, + }) + } + + watch, _ := strconv.ParseBool(req.QueryParameter("watch")) + if !watch { + resp.WriteHeaderAndEntity(http.StatusOK, messageResponses) + return + } + + resp.Header().Set("Transfer-Encoding", "chunked") + resp.Header().Set("X-Accel-Buffering", "no") + resp.Header().Set("Content-Type", "text/event-stream") + resp.Header().Set("Cache-Control", "no-cache") + resp.Header().Set("Connection", "keep-alive") + resp.WriteHeader(http.StatusOK) + + watchCh := mc.newWatchListChannel() + defer mc.cancelWatchListChannel(watchCh) + + encoder := json.NewEncoder(resp.ResponseWriter) + + for _, d := range messageResponses { + encoder.Encode(d) + } + resp.Flush() + + messageResponses = nil + + ctx := req.Request.Context() + for { + select { + case <-ctx.Done(): + return + case data, ok := <-watchCh: + if !ok { + return + } + + encoder.Encode(data) + resp.Flush() + } + } +} + +func (mc *MessageController) Get(req *restful.Request, resp *restful.Response) { + messageIDStr := req.PathParameter("message_id") + messageID, err := strconv.ParseInt(messageIDStr, 10, 64) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "InvalidIDError", Message: "Invalid message ID: " + err.Error()}) + return + } + + message, err := mc.messageService.GetByID(req.Request.Context(), messageID) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusNotFound, Error{Code: "MessageNotFoundError", Message: "Message not found: " + err.Error()}) + return + } + + if message.Status != model.StatusProcessing && message.Status != model.StatusPending { + resp.WriteHeaderAndEntity(http.StatusOK, MessageResponse{MessageID: message.MessageID, Content: message.Content, Priority: message.Priority, Status: message.Status, Data: message.Data, LastHeartbeat: message.LastHeartbeat}) + return + } + + watch, _ := strconv.ParseBool(req.QueryParameter("watch")) + if !watch { + resp.WriteHeaderAndEntity(http.StatusOK, MessageResponse{MessageID: message.MessageID, Content: message.Content, Priority: message.Priority, Status: message.Status, Data: message.Data, LastHeartbeat: message.LastHeartbeat}) + return + } + + resp.Header().Set("Transfer-Encoding", "chunked") + resp.Header().Set("X-Accel-Buffering", "no") + resp.Header().Set("Content-Type", "text/event-stream") + resp.Header().Set("Cache-Control", "no-cache") + resp.Header().Set("Connection", "keep-alive") + resp.WriteHeader(http.StatusOK) + + watchCh := mc.newWatchChannel(messageID) + defer mc.cancelWatchChannel(messageID, watchCh) + + mc.updateWatchChannel(messageID, MessageResponse{MessageID: message.MessageID, Content: message.Content, Priority: message.Priority, Status: message.Status, Data: message.Data, LastHeartbeat: message.LastHeartbeat}) + + encoder := json.NewEncoder(resp.ResponseWriter) + + ctx := req.Request.Context() + for { + select { + case <-ctx.Done(): + return + case data, ok := <-watchCh: + if !ok { + return + } + + encoder.Encode(data) + resp.Flush() + + if data.Status != model.StatusProcessing && data.Status != model.StatusPending { + return + } + } + } +} + +func (mc *MessageController) Consume(req *restful.Request, resp *restful.Response) { + messageIDStr := req.PathParameter("message_id") + messageID, err := strconv.ParseInt(messageIDStr, 10, 64) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "InvalidIDError", Message: "Invalid message ID: " + err.Error()}) + return + } + + var completedRequest CompletedRequest + if err := req.ReadEntity(&completedRequest); err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "CompletedRequestError", Message: "Failed to read completed request: " + err.Error()}) + return + } + + message, err := mc.messageService.Consume(req.Request.Context(), messageID, completedRequest.Lease) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusNotAcceptable, Error{Code: "MessageNotAcceptableError", Message: "Message not found: " + err.Error()}) + return + } + + data := MessageResponse{MessageID: message.MessageID, Content: message.Content, Priority: message.Priority, Status: message.Status, Data: message.Data, LastHeartbeat: message.LastHeartbeat} + mc.updateWatchChannel(messageID, data) + mc.updateWatchListChannels(data) + + resp.WriteHeaderAndEntity(http.StatusOK, data) +} + +func (mc *MessageController) Heartbeat(req *restful.Request, resp *restful.Response) { + messageIDStr := req.PathParameter("message_id") + messageID, err := strconv.ParseInt(messageIDStr, 10, 64) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "InvalidIDError", Message: "Invalid message ID: " + err.Error()}) + return + } + + var heartbeatRequest HeartbeatRequest + if err := req.ReadEntity(&heartbeatRequest); err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "HeartbeatRequestError", Message: "Failed to read heartbeat request: " + err.Error()}) + return + } + + now := time.Now() + + if err := mc.messageService.Heartbeat(req.Request.Context(), messageID, now, heartbeatRequest.Data, heartbeatRequest.Lease); err != nil { + resp.WriteHeaderAndEntity(http.StatusNotAcceptable, Error{Code: "MessageNotAcceptableError", Message: "Message not found: " + err.Error()}) + return + } + + curr, err := mc.messageService.GetByID(context.Background(), messageID) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusNotFound, Error{Code: "MessageNotFoundError", Message: "Message not found after heartbeat: " + err.Error()}) + return + } + + data := MessageResponse{MessageID: curr.MessageID, Content: curr.Content, Priority: curr.Priority, Status: curr.Status, Data: curr.Data, LastHeartbeat: curr.LastHeartbeat} + + mc.updateWatchChannel(messageID, data) + mc.updateWatchListChannels(data) + + resp.WriteHeader(http.StatusNoContent) +} + +func (mc *MessageController) Completed(req *restful.Request, resp *restful.Response) { + messageIDStr := req.PathParameter("message_id") + messageID, err := strconv.ParseInt(messageIDStr, 10, 64) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "InvalidIDError", Message: "Invalid message ID: " + err.Error()}) + return + } + + var completedRequest CompletedRequest + if err := req.ReadEntity(&completedRequest); err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "CompletedRequestError", Message: "Failed to read completed request: " + err.Error()}) + return + } + + if err := mc.messageService.Completed(req.Request.Context(), messageID, completedRequest.Lease); err != nil { + resp.WriteHeaderAndEntity(http.StatusNotAcceptable, Error{Code: "MessageNotAcceptabledError", Message: "Message not found: " + err.Error()}) + return + } + + curr, err := mc.messageService.GetByID(context.Background(), messageID) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusNotFound, Error{Code: "MessageNotFoundError", Message: "Message not found after completion: " + err.Error()}) + return + } + + data := MessageResponse{MessageID: curr.MessageID, Content: curr.Content, Priority: curr.Priority, Status: curr.Status, Data: curr.Data, LastHeartbeat: curr.LastHeartbeat} + + mc.updateWatchChannel(messageID, data) + mc.updateWatchListChannels(data) + + resp.WriteHeader(http.StatusNoContent) +} + +func (mc *MessageController) Failed(req *restful.Request, resp *restful.Response) { + messageIDStr := req.PathParameter("message_id") + messageID, err := strconv.ParseInt(messageIDStr, 10, 64) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "InvalidIDError", Message: "Invalid message ID: " + err.Error()}) + return + } + + var failedRequest FailedRequest + if err := req.ReadEntity(&failedRequest); err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "FailedRequestError", Message: "Failed to read failed request: " + err.Error()}) + return + } + + if err := mc.messageService.Failed(req.Request.Context(), messageID, failedRequest.Lease, failedRequest.Data); err != nil { + resp.WriteHeaderAndEntity(http.StatusNotAcceptable, Error{Code: "MessageNotAcceptableError", Message: "Message not found: " + err.Error()}) + return + } + + curr, err := mc.messageService.GetByID(context.Background(), messageID) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusNotFound, Error{Code: "MessageNotFoundError", Message: "Message not found after failure: " + err.Error()}) + return + } + + data := MessageResponse{MessageID: curr.MessageID, Content: curr.Content, Priority: curr.Priority, Status: curr.Status, Data: curr.Data, LastHeartbeat: curr.LastHeartbeat} + + mc.updateWatchChannel(messageID, data) + mc.updateWatchListChannels(data) + + resp.WriteHeader(http.StatusNoContent) +} + +func (mc *MessageController) Cancel(req *restful.Request, resp *restful.Response) { + messageIDStr := req.PathParameter("message_id") + messageID, err := strconv.ParseInt(messageIDStr, 10, 64) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "InvalidIDError", Message: "Invalid message ID: " + err.Error()}) + return + } + + var cancelRequest CancelRequest + if err := req.ReadEntity(&cancelRequest); err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "FailedRequestError", Message: "Failed to read failed request: " + err.Error()}) + return + } + + if err := mc.messageService.Cancel(req.Request.Context(), messageID, cancelRequest.Lease); err != nil { + resp.WriteHeaderAndEntity(http.StatusNotAcceptable, Error{Code: "MessageNotAcceptableError", Message: "Message not found: " + err.Error()}) + return + } + + curr, err := mc.messageService.GetByID(context.Background(), messageID) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusNotFound, Error{Code: "MessageNotFoundError", Message: "Message not found after failure: " + err.Error()}) + return + } + + data := MessageResponse{MessageID: curr.MessageID, Content: curr.Content, Priority: curr.Priority, Status: curr.Status, Data: curr.Data, LastHeartbeat: curr.LastHeartbeat} + + mc.updateWatchChannel(messageID, data) + mc.updateWatchListChannels(data) + + resp.WriteHeader(http.StatusNoContent) +} diff --git a/queue/dao/db.go b/queue/dao/db.go new file mode 100644 index 0000000..4386622 --- /dev/null +++ b/queue/dao/db.go @@ -0,0 +1,37 @@ +package dao + +import ( + "context" + "database/sql" +) + +type dbCtxKey struct{} + +// contextKey is a key type for storing database context values. +var contextKey = dbCtxKey{} + +// WithDB returns a new context with the given database connection. +func WithDB(ctx context.Context, db DB) context.Context { + return context.WithValue(ctx, contextKey, db) +} + +// GetDB retrieves the database connection from the context. +func GetDB(ctx context.Context) DB { + db := ctx.Value(contextKey) + if db == nil { + return nil + } + d, _ := db.(DB) + return d +} + +var ( + _ DB = (*sql.Tx)(nil) + _ DB = (*sql.DB)(nil) +) + +type DB interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} diff --git a/queue/dao/messgae.go b/queue/dao/messgae.go new file mode 100644 index 0000000..b5d3593 --- /dev/null +++ b/queue/dao/messgae.go @@ -0,0 +1,295 @@ +package dao + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/daocloud/crproxy/queue/model" +) + +type Message struct{} + +func NewMessage() *Message { + return &Message{} +} + +const messageTableSQL = ` +CREATE TABLE IF NOT EXISTS messages ( + id SERIAL PRIMARY KEY, + content TEXT NOT NULL, + lease VARCHAR(36) NOT NULL, + priority INT DEFAULT 0, + status INT DEFAULT 0, + data JSON NOT NULL, + last_heartbeat TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + create_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + delete_at TIMESTAMP +) ENGINE=InnoDB AUTO_INCREMENT=10000 CHARSET=utf8mb4; +` + +func (l *Message) InitTable(ctx context.Context) error { + db := GetDB(ctx) + _, err := db.ExecContext(ctx, messageTableSQL) + if err != nil { + return fmt.Errorf("failed to create messages table: %w", err) + } + return nil +} + +const createMessageSQL = ` +INSERT INTO messages (content, lease, priority, status, data) VALUES (?, ?, ?, ?, ?) +` + +func (m *Message) Create(ctx context.Context, message model.Message) (int64, error) { + db := GetDB(ctx) + result, err := db.ExecContext(ctx, createMessageSQL, message.Content, message.Lease, message.Priority, model.StatusPending, "{}") + if err != nil { + return 0, fmt.Errorf("failed to create message: %w", err) + } + return result.LastInsertId() +} + +const getMessageByContentSQL = ` +SELECT id, content, lease, priority, status, data, last_heartbeat FROM messages WHERE content = ? AND delete_at IS NULL +` + +func (m *Message) GetByContent(ctx context.Context, content string) (model.Message, error) { + db := GetDB(ctx) + var message model.Message + err := db.QueryRowContext(ctx, getMessageByContentSQL, content).Scan(&message.MessageID, &message.Content, &message.Lease, &message.Priority, &message.Status, &message.Data, &message.LastHeartbeat) + if err != nil { + if err == sql.ErrNoRows { + return model.Message{}, fmt.Errorf("message not found: %w", err) + } + return model.Message{}, fmt.Errorf("failed to get message: %w", err) + } + return message, nil +} + +const getMessageByIDSQL = ` +SELECT id, content, lease, priority, status, data, last_heartbeat FROM messages WHERE id = ? AND delete_at IS NULL +` + +func (m *Message) GetByID(ctx context.Context, id int64) (model.Message, error) { + db := GetDB(ctx) + var message model.Message + err := db.QueryRowContext(ctx, getMessageByIDSQL, id).Scan(&message.MessageID, &message.Content, &message.Lease, &message.Priority, &message.Status, &message.Data, &message.LastHeartbeat) + if err != nil { + if err == sql.ErrNoRows { + return model.Message{}, fmt.Errorf("message not found: %w", err) + } + return model.Message{}, fmt.Errorf("failed to get message: %w", err) + } + return message, nil +} + +const updateMessageByIDSQL = ` +UPDATE messages SET content = ?, lease = ?, priority = ?, status = ?, data = ? WHERE id = ? AND delete_at IS NULL +` + +func (m *Message) UpdateByID(ctx context.Context, id int64, message model.Message) error { + db := GetDB(ctx) + _, err := db.ExecContext(ctx, updateMessageByIDSQL, message.Content, message.Lease, message.Priority, message.Status, message.Data, id) + if err != nil { + return fmt.Errorf("failed to update message: %w", err) + } + return nil +} + +const updateMessagePriorityByIDSQL = ` +UPDATE messages SET priority = ? WHERE id = ? AND priority < ? AND delete_at IS NULL +` + +func (m *Message) UpdatePriorityByID(ctx context.Context, id int64, priority int) error { + db := GetDB(ctx) + _, err := db.ExecContext(ctx, updateMessagePriorityByIDSQL, priority, id, priority) + if err != nil { + return fmt.Errorf("failed to update message priority: %w", err) + } + return nil +} + +const deleteMessageByIDSQL = ` +UPDATE messages SET delete_at = NOW() WHERE id = ? AND delete_at IS NULL +` + +func (m *Message) DeleteByID(ctx context.Context, id int64) error { + db := GetDB(ctx) + _, err := db.ExecContext(ctx, deleteMessageByIDSQL, id) + if err != nil { + return fmt.Errorf("failed to delete message: %w", err) + } + return nil +} + +const getMessagesSQL = ` +SELECT id, content, priority, status, data, last_heartbeat FROM messages WHERE delete_at IS NULL +` + +func (m *Message) List(ctx context.Context) ([]model.Message, error) { + db := GetDB(ctx) + rows, err := db.QueryContext(ctx, getMessagesSQL) + if err != nil { + return nil, fmt.Errorf("failed to list messages: %w", err) + } + defer rows.Close() + + var messages []model.Message + for rows.Next() { + var message model.Message + if err := rows.Scan(&message.MessageID, &message.Content, &message.Priority, &message.Status, &message.Data, &message.LastHeartbeat); err != nil { + return nil, fmt.Errorf("failed to scan message: %w", err) + } + messages = append(messages, message) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error occurred during rows iteration: %w", err) + } + + return messages, nil +} + +const getCompletedAndFailedMessagesSQL = ` +SELECT id, content, priority, status, data, last_heartbeat +FROM messages +WHERE (status = ? OR status = ?) +AND last_heartbeat < ? AND delete_at IS NULL +` + +func (m *Message) GetCompletedAndFailed(ctx context.Context, threshold time.Time) ([]model.Message, error) { + db := GetDB(ctx) + rows, err := db.QueryContext(ctx, getCompletedAndFailedMessagesSQL, model.StatusCompleted, model.StatusFailed, threshold) + if err != nil { + return nil, fmt.Errorf("failed to get completed and failed messages: %w", err) + } + defer rows.Close() + + var messages []model.Message + for rows.Next() { + var message model.Message + if err := rows.Scan(&message.MessageID, &message.Content, &message.Priority, &message.Status, &message.Data, &message.LastHeartbeat); err != nil { + return nil, fmt.Errorf("failed to scan message: %w", err) + } + messages = append(messages, message) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error occurred during rows iteration: %w", err) + } + + return messages, nil +} + +const getStaleMessagesSQL = ` +SELECT id, content, priority, status, data, last_heartbeat +FROM messages +WHERE status = ? +AND last_heartbeat < ? AND delete_at IS NULL +` + +func (m *Message) GetStale(ctx context.Context, threshold time.Time) ([]model.Message, error) { + db := GetDB(ctx) + rows, err := db.QueryContext(ctx, getStaleMessagesSQL, model.StatusProcessing, threshold) + if err != nil { + return nil, fmt.Errorf("failed to get stale messages: %w", err) + } + defer rows.Close() + + var messages []model.Message + for rows.Next() { + var message model.Message + if err := rows.Scan(&message.MessageID, &message.Content, &message.Priority, &message.Status, &message.Data, &message.LastHeartbeat); err != nil { + return nil, fmt.Errorf("failed to scan message: %w", err) + } + messages = append(messages, message) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error occurred during rows iteration: %w", err) + } + + return messages, nil +} + +const setStatusAndLeaseSQL = ` +UPDATE messages SET status = ?, lease = ? WHERE id = ? AND status = ? AND lease = ? AND delete_at IS NULL +` + +func (m *Message) SetStatusAndLease(ctx context.Context, id int64, status model.MessageStatus, lease string) (int64, error) { + db := GetDB(ctx) + results, err := db.ExecContext(ctx, setStatusAndLeaseSQL, status, lease, id, model.StatusPending, "") + if err != nil { + return 0, fmt.Errorf("failed to set status and lease: %w", err) + } + + return results.RowsAffected() +} + +const setHeartbeatAndDataSQL = ` +UPDATE messages SET last_heartbeat = ?, data = ? WHERE id = ? AND lease = ? AND status = ? AND delete_at IS NULL +` + +func (m *Message) SetHeartbeatAndData(ctx context.Context, id int64, lastHeartbeat time.Time, data model.MessageAttr, lease string) (int64, error) { + db := GetDB(ctx) + results, err := db.ExecContext(ctx, setHeartbeatAndDataSQL, lastHeartbeat, data, id, lease, model.StatusProcessing) + if err != nil { + return 0, fmt.Errorf("failed to set heartbeat and data: %w", err) + } + return results.RowsAffected() +} + +const setCompletedSQL = ` +UPDATE messages SET status = ?, lease = ? WHERE id = ? AND lease = ? AND status = ? AND delete_at IS NULL +` + +func (m *Message) SetCompleted(ctx context.Context, id int64, lease string) (int64, error) { + db := GetDB(ctx) + results, err := db.ExecContext(ctx, setCompletedSQL, model.StatusCompleted, "", id, lease, model.StatusProcessing) + if err != nil { + return 0, fmt.Errorf("failed to set completed: %w", err) + } + return results.RowsAffected() +} + +const setFailedSQL = ` +UPDATE messages SET status = ?, lease = ?, data = ? WHERE id = ? AND lease = ? AND status = ? AND delete_at IS NULL +` + +func (m *Message) SetFailed(ctx context.Context, id int64, lease string, data model.MessageAttr) (int64, error) { + db := GetDB(ctx) + results, err := db.ExecContext(ctx, setFailedSQL, model.StatusFailed, "", data, id, lease, model.StatusProcessing) + if err != nil { + return 0, fmt.Errorf("failed to set status, lease, and data: %w", err) + } + return results.RowsAffected() +} + +const cancelSQL = ` +UPDATE messages SET lease = ?, status = ? WHERE id = ? AND lease = ? AND status = ? AND delete_at IS NULL +` + +func (m *Message) Cancel(ctx context.Context, id int64, lease string) (int64, error) { + db := GetDB(ctx) + results, err := db.ExecContext(ctx, cancelSQL, "", model.StatusPending, id, lease, model.StatusProcessing) + if err != nil { + return 0, fmt.Errorf("failed to set status, lease, and data: %w", err) + } + return results.RowsAffected() +} + +const resetToPendingSQL = ` +UPDATE messages SET status = ?, lease = ? WHERE id = ? AND status = ? AND delete_at IS NULL +` + +func (m *Message) ResetToPending(ctx context.Context, id int64) (int64, error) { + db := GetDB(ctx) + results, err := db.ExecContext(ctx, resetToPendingSQL, model.StatusPending, "", id, model.StatusProcessing) + if err != nil { + return 0, fmt.Errorf("failed to reset status to pending: %w", err) + } + return results.RowsAffected() +} diff --git a/queue/model/message.go b/queue/model/message.go new file mode 100644 index 0000000..3b8c246 --- /dev/null +++ b/queue/model/message.go @@ -0,0 +1,49 @@ +package model + +import ( + "database/sql/driver" + "time" +) + +type MessageStatus uint64 + +const ( + StatusPending MessageStatus = 0 + StatusProcessing MessageStatus = 10 + StatusCompleted MessageStatus = 20 + StatusFailed MessageStatus = 30 +) + +type Message struct { + MessageID int64 + Content string + Lease string + Priority int + Status MessageStatus + Data MessageAttr + LastHeartbeat time.Time +} + +type MessageAttr struct { + Error string `json:"error,omitempty"` + Blobs []Blob `json:"blobs,omitempty"` +} + +type Blob struct { + Digest string `json:"digest"` + Progress int64 `json:"progress"` + Size int64 `json:"size"` + Error string `json:"error,omitempty"` +} + +func (n *MessageAttr) Scan(value any) error { + if value == nil { + return nil + } + *n = unmarshal[MessageAttr](asString(value)) + return nil +} + +func (n MessageAttr) Value() (driver.Value, error) { + return marshal(n), nil +} diff --git a/queue/model/utils.go b/queue/model/utils.go new file mode 100644 index 0000000..a4444b2 --- /dev/null +++ b/queue/model/utils.go @@ -0,0 +1,45 @@ +package model + +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" +) + +func marshal[T any](t T) string { + d, err := json.Marshal(t) + if err != nil { + return fmt.Sprintf(`{"error":%q}`, err.Error()) + } + return string(d) +} + +func unmarshal[T any](s string) T { + var t T + json.Unmarshal([]byte(s), &t) + return t +} + +func asString(src any) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} diff --git a/queue/queue.go b/queue/queue.go new file mode 100644 index 0000000..c8411d5 --- /dev/null +++ b/queue/queue.go @@ -0,0 +1,99 @@ +package queue + +import ( + "context" + "database/sql" + "log/slog" + "net/http" + "strings" + + "github.com/daocloud/crproxy/queue/controller" + "github.com/daocloud/crproxy/queue/dao" + "github.com/daocloud/crproxy/queue/service" + restfulspec "github.com/emicklei/go-restful-openapi/v2" + "github.com/emicklei/go-restful/v3" + "github.com/go-openapi/spec" +) + +type QueueManager struct { + adminToken string + db *sql.DB + + MessageDAO *dao.Message + + MessageService *service.MessageService + + MessageController *controller.MessageController +} + +func NewQueueManager(adminToken string, db *sql.DB) *QueueManager { + m := &QueueManager{ + adminToken: adminToken, + db: db, + } + return m +} + +func (m *QueueManager) InitTable(ctx context.Context) { + ctx = dao.WithDB(ctx, m.db) + m.MessageDAO.InitTable(ctx) +} + +func (m *QueueManager) Register(container *restful.Container) { + m.MessageDAO = dao.NewMessage() + + m.MessageService = service.NewMessageService(m.db, m.MessageDAO) + m.MessageController = controller.NewMessageController(m.MessageService) + + ws := new(restful.WebService) + ws.Path("/apis/v1/") + + if m.adminToken != "" { + ws.Filter(func(req *restful.Request, resp *restful.Response, fc *restful.FilterChain) { + auth := req.HeaderParameter("Authorization") + if !strings.HasPrefix(auth, "Bearer ") { + resp.WriteErrorString(http.StatusForbidden, "Forbidden") + return + } + + if auth[7:] != m.adminToken { + resp.WriteErrorString(http.StatusForbidden, "Forbidden") + return + } + + fc.ProcessFilter(req, resp) + }) + } + m.MessageController.RegisterRoutes(ws) + + container.Add(ws) + + config := restfulspec.Config{ + WebServices: []*restful.WebService{ws}, + APIPath: "/swagger.json", + PostBuildSwaggerObjectHandler: func(s *spec.Swagger) { + s.Info = &spec.Info{} + s.Info.Title = "CRProxy Queue Manager" + s.Schemes = []string{"https", "http"} + s.SecurityDefinitions = spec.SecurityDefinitions{ + "BearerHeader": { + SecuritySchemeProps: spec.SecuritySchemeProps{ + Description: `Enter the token with the "Bearer token"`, + Type: "apiKey", + In: "header", + Name: "Authorization", + }, + }, + } + s.Security = []map[string][]string{ + {"BearerHeader": []string{}}, + } + }, + } + + container.Add(restfulspec.NewOpenAPIService(config)) +} + +func (m *QueueManager) Schedule(ctx context.Context, logger *slog.Logger) { + go m.MessageController.Schedule(ctx, logger) +} diff --git a/queue/service/message.go b/queue/service/message.go new file mode 100644 index 0000000..6e7487f --- /dev/null +++ b/queue/service/message.go @@ -0,0 +1,147 @@ +package service + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/daocloud/crproxy/queue/dao" + "github.com/daocloud/crproxy/queue/model" +) + +type MessageService struct { + db *sql.DB + messageDao *dao.Message +} + +func NewMessageService(db *sql.DB, messageDao *dao.Message) *MessageService { + return &MessageService{ + db: db, + messageDao: messageDao, + } +} + +func (s *MessageService) Create(ctx context.Context, message model.Message) (int64, error) { + ctx = dao.WithDB(ctx, s.db) + return s.messageDao.Create(ctx, message) +} + +func (s *MessageService) GetByID(ctx context.Context, id int64) (model.Message, error) { + ctx = dao.WithDB(ctx, s.db) + return s.messageDao.GetByID(ctx, id) +} + +func (s *MessageService) GetByContent(ctx context.Context, content string) (model.Message, error) { + ctx = dao.WithDB(ctx, s.db) + return s.messageDao.GetByContent(ctx, content) +} + +func (s *MessageService) UpdateByID(ctx context.Context, id int64, message model.Message) error { + ctx = dao.WithDB(ctx, s.db) + return s.messageDao.UpdateByID(ctx, id, message) +} + +func (s *MessageService) UpdatePriorityByID(ctx context.Context, id int64, priority int) error { + ctx = dao.WithDB(ctx, s.db) + return s.messageDao.UpdatePriorityByID(ctx, id, priority) +} + +func (s *MessageService) DeleteByID(ctx context.Context, id int64) error { + ctx = dao.WithDB(ctx, s.db) + return s.messageDao.DeleteByID(ctx, id) +} + +func (s *MessageService) List(ctx context.Context) ([]model.Message, error) { + ctx = dao.WithDB(ctx, s.db) + return s.messageDao.List(ctx) +} + +func (s *MessageService) Consume(ctx context.Context, id int64, lease string) (model.Message, error) { + ctx = dao.WithDB(ctx, s.db) + + rowsAffected, err := s.messageDao.SetStatusAndLease(ctx, id, model.StatusProcessing, lease) + if err != nil { + return model.Message{}, err + } + + if rowsAffected == 0 { + return model.Message{}, fmt.Errorf("no rows affected when consuming message with id %d", id) + } + + message, err := s.messageDao.GetByID(ctx, id) + if err != nil { + return model.Message{}, err + } + return message, nil +} + +func (s *MessageService) Heartbeat(ctx context.Context, id int64, lastHeartbeat time.Time, data model.MessageAttr, lease string) error { + ctx = dao.WithDB(ctx, s.db) + rowsAffected, err := s.messageDao.SetHeartbeatAndData(ctx, id, lastHeartbeat, data, lease) + if err != nil { + return err + } + if rowsAffected == 0 { + return fmt.Errorf("no rows affected when updating heartbeat for message with id %d", id) + } + return nil +} + +func (s *MessageService) Completed(ctx context.Context, id int64, lease string) error { + ctx = dao.WithDB(ctx, s.db) + rowsAffected, err := s.messageDao.SetCompleted(ctx, id, lease) + if err != nil { + return err + } + if rowsAffected == 0 { + return fmt.Errorf("no rows affected when completing message with id %d", id) + } + return nil +} + +func (s *MessageService) Failed(ctx context.Context, id int64, lease string, data model.MessageAttr) error { + ctx = dao.WithDB(ctx, s.db) + rowsAffected, err := s.messageDao.SetFailed(ctx, id, lease, data) + if err != nil { + return err + } + if rowsAffected == 0 { + return fmt.Errorf("no rows affected when failing message with id %d", id) + } + return nil +} + +func (s *MessageService) Cancel(ctx context.Context, id int64, lease string) error { + ctx = dao.WithDB(ctx, s.db) + rowsAffected, err := s.messageDao.Cancel(ctx, id, lease) + if err != nil { + return err + } + if rowsAffected == 0 { + return fmt.Errorf("no rows affected when failing message with id %d", id) + } + return nil +} + +func (s *MessageService) GetCompletedAndFailed(ctx context.Context, threshold time.Time) ([]model.Message, error) { + ctx = dao.WithDB(ctx, s.db) + return s.messageDao.GetCompletedAndFailed(ctx, threshold) +} + +func (s *MessageService) GetStale(ctx context.Context, threshold time.Time) ([]model.Message, error) { + ctx = dao.WithDB(ctx, s.db) + return s.messageDao.GetStale(ctx, threshold) +} + +func (s *MessageService) ResetToPending(ctx context.Context, id int64) error { + ctx = dao.WithDB(ctx, s.db) + rowsAffected, err := s.messageDao.ResetToPending(ctx, id) + if err != nil { + return err + } + if rowsAffected == 0 { + return fmt.Errorf("no rows affected when resetting message with id %d to pending", id) + } + return nil +} diff --git a/runner/runner.go b/runner/runner.go new file mode 100644 index 0000000..c7d23fe --- /dev/null +++ b/runner/runner.go @@ -0,0 +1,243 @@ +package runner + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "log/slog" + "net/http" + "os" + "sort" + "sync" + "time" + + "github.com/daocloud/crproxy/queue/client" + "github.com/daocloud/crproxy/queue/model" + csync "github.com/daocloud/crproxy/sync" +) + +type Runner struct { + client *client.MessageClient + syncManager *csync.SyncManager + lease string + pendingMut sync.Mutex + pending map[int64]client.MessageResponse + syncCh chan struct{} +} + +func identity() (string, error) { + hostname, err := os.Hostname() + if err != nil { + return "", fmt.Errorf("unable to get hostname: %w", err) + } + hnHex := hex.EncodeToString([]byte(hostname)) + return fmt.Sprintf("%s-%d", hnHex[:16], time.Now().Unix()), nil +} + +func NewRunner(httpClient *http.Client, baseURL string, adminToken string, syncManager *csync.SyncManager) (*Runner, error) { + id, err := identity() + if err != nil { + return nil, err + } + cli := client.NewMessageClient(httpClient, baseURL, adminToken) + + return &Runner{ + client: cli, + lease: id, + syncManager: syncManager, + pending: make(map[int64]client.MessageResponse), + syncCh: make(chan struct{}), + }, nil +} + +func (r *Runner) Run(ctx context.Context, logger *slog.Logger) error { + logger.Info("lease", "id", r.lease) + go r.watch(ctx, logger) + + r.sync(ctx, r.lease, logger) + return ctx.Err() +} + +func (r *Runner) watch(ctx context.Context, logger *slog.Logger) { + for ctx.Err() == nil { + if err := r.runWatch(ctx); err != nil { + logger.Error("watch", "error", err) + } + select { + case <-time.After(time.Second): + case <-ctx.Done(): + return + } + } +} + +func (r *Runner) sync(ctx context.Context, id string, logger *slog.Logger) { + for i := 0; ctx.Err() == nil; i++ { + if err := r.runOnceSync(context.Background(), id, logger); err != nil { + if err != errWait { + logger.Error("sync", "error", err) + } else { + select { + case <-r.syncCh: + case <-ctx.Done(): + return + } + continue + } + } + select { + case <-time.After(time.Second): + case <-ctx.Done(): + return + } + } +} + +func (r *Runner) runWatch(ctx context.Context) error { + ch, err := r.client.WatchList(ctx) + if err != nil { + return err + } + + r.pendingMut.Lock() + clear(r.pending) + r.pendingMut.Unlock() + + for msg := range ch { + r.pendingMut.Lock() + if msg.Status == model.StatusPending { + r.pending[msg.MessageID] = msg + } else { + delete(r.pending, msg.MessageID) + } + r.pendingMut.Unlock() + + if msg.Status == model.StatusPending { + select { + case r.syncCh <- struct{}{}: + default: + } + } + } + return nil +} + +func (r *Runner) getPending() []client.MessageResponse { + r.pendingMut.Lock() + defer r.pendingMut.Unlock() + + var pendingMessages []client.MessageResponse + for _, msg := range r.pending { + if msg.Status == model.StatusPending { + pendingMessages = append(pendingMessages, msg) + } + } + + sort.Slice(pendingMessages, func(i, j int) bool { + return pendingMessages[i].Priority > pendingMessages[j].Priority + }) + + return pendingMessages +} + +var errWait = fmt.Errorf("no message received and no errors occurred") + +func (r *Runner) runOnceSync(ctx context.Context, id string, logger *slog.Logger) error { + var ( + err error + errs []error + resp client.MessageResponse + ) + + pending := r.getPending() + if len(pending) == 0 { + return errWait + } + + for _, msg := range pending { + resp, err = r.client.Consume(ctx, msg.MessageID, id) + if err != nil { + errs = append(errs, err) + } else { + break + } + } + + if resp.MessageID == 0 || resp.Content == "" { + if len(errs) == 0 { + return errWait + } + return errors.Join(errs...) + } + + var bmMut sync.Mutex + var bm []model.Blob + + var errCh = make(chan error, 1) + + go func() { + errCh <- r.syncManager.ImageWithCallback(ctx, resp.Content, func(blob string, progress, size int64) { + bmMut.Lock() + defer bmMut.Unlock() + for i, m := range bm { + if m.Digest == blob { + bm[i].Progress = progress + bm[i].Size = size + return + } + } + bm = append(bm, model.Blob{ + Digest: blob, + Progress: progress, + Size: size, + }) + }) + }() + + ticker := time.NewTicker(10 * time.Second) + + for { + select { + case <-ticker.C: + bmMut.Lock() + nbm := append([]model.Blob{}, bm...) + bmMut.Unlock() + + err := r.client.Heartbeat(ctx, resp.MessageID, client.HeartbeatRequest{ + Lease: id, + Data: model.MessageAttr{ + Blobs: nbm, + }, + }) + + if err != nil { + logger.Error("Heartbeat", "error", err) + } + + case err := <-errCh: + if err == nil { + return r.client.Complete(ctx, resp.MessageID, client.CompletedRequest{ + Lease: id, + }) + } + + if errors.Is(err, context.Canceled) { + return r.client.Cancel(ctx, resp.MessageID, client.CancelRequest{ + Lease: id, + }) + } + + return r.client.Failed(ctx, resp.MessageID, client.FailedRequest{ + Lease: id, + Data: model.MessageAttr{ + Error: err.Error(), + }, + }) + case <-ctx.Done(): + return r.client.Cancel(ctx, resp.MessageID, client.CancelRequest{ + Lease: id, + }) + } + } +} diff --git a/sync/sync.go b/sync/sync.go index 0c04084..ff7ef8d 100644 --- a/sync/sync.go +++ b/sync/sync.go @@ -117,6 +117,10 @@ func NewSyncManager(opts ...Option) (*SyncManager, error) { } func (c *SyncManager) Image(ctx context.Context, image string) error { + return c.ImageWithCallback(ctx, image, nil) +} + +func (c *SyncManager) ImageWithCallback(ctx context.Context, image string, blobFunc func(blob string, progress int64, size int64)) error { var regexTag *regexp.Regexp ref, err := reference.Parse(image) if err != nil { @@ -172,13 +176,13 @@ func (c *SyncManager) Image(ctx context.Context, image string) error { c.uniqBlob.Add(dgst) blob := dgst.String() - + var gotSize int64 var subCaches []*cache.Cache for _, cache := range caches { stat, err := cache.StatBlob(ctx, blob) if err == nil { if size > 0 { - gotSize := stat.Size() + gotSize = stat.Size() if size == gotSize { continue } @@ -191,6 +195,9 @@ func (c *SyncManager) Image(ctx context.Context, image string) error { } if len(subCaches) == 0 { + if blobFunc != nil { + blobFunc(blob, gotSize, gotSize) + } c.logger.Info("skip blob by cache", "image", image, "digest", dgst) return nil } @@ -201,6 +208,10 @@ func (c *SyncManager) Image(ctx context.Context, image string) error { } defer f.Close() + if blobFunc != nil { + blobFunc(blob, 0, 0) + } + c.logger.Info("start sync blob", "image", image, "digest", dgst, "platform", pf) if len(subCaches) == 1 { @@ -208,6 +219,9 @@ func (c *SyncManager) Image(ctx context.Context, image string) error { if err != nil { return fmt.Errorf("put blob failed: %w", err) } + if blobFunc != nil { + blobFunc(blob, n, n) + } c.logger.Info("finish sync blob", "image", image, "digest", dgst, "platform", pf, "size", n) return nil } @@ -241,6 +255,10 @@ func (c *SyncManager) Image(ctx context.Context, image string) error { wg.Wait() + if blobFunc != nil { + blobFunc(blob, n, n) + } + c.logger.Info("finish sync blob", "image", image, "digest", dgst, "platform", pf, "size", n) return nil }