Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions api/handler/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package handler

import (
"encoding/csv"
"fmt"
"log/slog"
"net/http"
"strconv"
"time"

"github.com/gin-gonic/gin"
"opencsg.com/csghub-server/api/httpbase"
Expand All @@ -29,6 +31,11 @@ type ClusterHandler struct {
c component.ClusterComponent
}

const (
deployTimeLayout = "2006-01-02 15:04:05"
deployDateOnlyLayout = "2006-01-02"
)

// Getclusters godoc
// @Security ApiKey
// @Summary Get cluster list
Expand Down Expand Up @@ -146,6 +153,8 @@ func (h *ClusterHandler) GetDeploys(ctx *gin.Context) {
// @Produce text/csv
// @Param status query string false "status" default(all) Enums(all, running, stopped, deployfailed)
// @Param search query string false "search" default("")
// @Param start_time query string false "filter deploys created after or at this time"
// @Param end_time query string false "filter deploys created before or at this time"
// @Success 200 {string} string "CSV file"
// @Failure 400 {object} types.APIBadRequest "Bad request"
// @Failure 500 {object} types.APIInternalServerError "Internal server error"
Expand All @@ -165,6 +174,11 @@ func (h *ClusterHandler) GetDeploysReport(ctx *gin.Context) {
req.Status = []int{code.DeployFailed}
}
req.Query = ctx.Query("search")
if err := bindDeployDateRange(ctx, &req); err != nil {
slog.Error("Invalid date range for deploy report", slog.Any("error", err))
httpbase.BadRequest(ctx, err.Error())
return
}

filename := "deploys_report.csv"
ctx.Header("Content-Type", "text/csv; charset=utf-8")
Expand All @@ -189,7 +203,6 @@ func (h *ClusterHandler) GetDeploysReport(ctx *gin.Context) {
})
writer.Flush()

const timeLayout = "2006-01-02 15:04:05"
totalProcessed := 0

for {
Expand All @@ -207,7 +220,7 @@ func (h *ClusterHandler) GetDeploysReport(ctx *gin.Context) {
d.DeployName,
d.User.Username,
d.Resource,
d.CreateTime.Local().Format(timeLayout),
d.CreateTime.Local().Format(deployTimeLayout),
d.Status,
strconv.Itoa(d.TotalTimeInMin),
strconv.Itoa(d.TotalFeeInCents),
Expand Down Expand Up @@ -246,3 +259,43 @@ func (h *ClusterHandler) Update(ctx *gin.Context) {
}
httpbase.OK(ctx, result)
}

func bindDeployDateRange(ctx *gin.Context, req *types.DeployReq) error {
startTime := ctx.Query("start_time")
endTime := ctx.Query("end_time")
if startTime == "" && endTime == "" {
return nil
}
if startTime == "" || endTime == "" {
return fmt.Errorf("start_time and end_time must be provided together")
}
parsedStart, err := parseDeployQueryTime(startTime, false)
if err != nil {
return err
}
parsedEnd, err := parseDeployQueryTime(endTime, true)
if err != nil {
return err
}
req.StartTime = &parsedStart
req.EndTime = &parsedEnd
return nil
}

func parseDeployQueryTime(value string, isEnd bool) (time.Time, error) {
layouts := []string{deployTimeLayout, deployDateOnlyLayout}
for _, layout := range layouts {
parsed, err := time.ParseInLocation(layout, value, time.UTC)
if err != nil {
continue
}
if layout == deployDateOnlyLayout {
if isEnd {
parsed = parsed.Add(24*time.Hour - time.Nanosecond)
}
return parsed, nil
}
return parsed, nil
}
return time.Time{}, fmt.Errorf("invalid datetime format, use '%s' or '%s'", deployTimeLayout, deployDateOnlyLayout)
}
34 changes: 33 additions & 1 deletion api/handler/cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net/http"
"testing"
"time"

"github.com/gin-gonic/gin"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -98,12 +99,25 @@ func Test_GetDeploysReport(t *testing.T) {
},
}

start := "2024-01-01 00:00:00"
end := "2024-01-31"
expectedStart, err := time.ParseInLocation(time.DateTime, start, time.UTC)
require.NoError(t, err)
endDate, err := time.ParseInLocation("2006-01-02", end, time.UTC)
require.NoError(t, err)
expectedEnd := endDate.Add(24*time.Hour - time.Nanosecond)
tester.mocks.clusterComponent.EXPECT().
GetDeploys(context.Background(), mock.Anything).
Run(func(_ context.Context, req types.DeployReq) {
require.NotNil(t, req.StartTime)
require.True(t, req.StartTime.Equal(expectedStart))
require.NotNil(t, req.EndTime)
require.True(t, req.EndTime.Equal(expectedEnd))
}).
Once().
Return(rows, len(rows), nil)

tester.Execute()
tester.WithQuery("start_time", start).WithQuery("end_time", end).Execute()

// assert response headers and body
resp := tester.Response()
Expand All @@ -117,3 +131,21 @@ func Test_GetDeploysReport(t *testing.T) {
require.Contains(t, body, "alice")
require.Contains(t, body, "bob")
}

func Test_GetDeploysReport_InvalidDateRange(t *testing.T) {
tester := newClusterTester(t).withHandlerFunc(func(clusterHandler *ClusterHandler) gin.HandlerFunc {
return clusterHandler.GetDeploysReport
})

tester.WithQuery("start_time", "2024-01-01").Execute()
tester.ResponseEqSimple(t, http.StatusBadRequest, httpbase.R{Msg: "start_time and end_time must be provided together"})
}

func Test_GetDeploysReport_InvalidFormat(t *testing.T) {
tester := newClusterTester(t).withHandlerFunc(func(clusterHandler *ClusterHandler) gin.HandlerFunc {
return clusterHandler.GetDeploysReport
})

tester.WithQuery("start_time", "invalid").WithQuery("end_time", "2024-01-01").Execute()
tester.ResponseEqSimple(t, http.StatusBadRequest, httpbase.R{Msg: "invalid datetime format, use '2006-01-02 15:04:05' or '2006-01-02'"})
}
49 changes: 37 additions & 12 deletions api/router/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
middlewareCollection.Auth.NeedPhoneVerified = middleware.NeedPhoneVerified(config)
middlewareCollection.Repo.RepoExists = middleware.RepoExists(config)
middlewareCollection.License.Check = middleware.CheckLicense(config)

middlewareCollection.API.RateLimter = middleware.RateLimiter(config, middleware.WithTimeBucketRateLimter(config), middleware.WithIPCheck())
//add router for golang pprof
debugGroup := r.Group("/debug", middlewareCollection.Auth.NeedAPIKey)
pprof.RouteRegister(debugGroup, "pprof")
Expand Down Expand Up @@ -283,7 +283,7 @@
tokenGroup.PUT("/:app/:token_name", userProxyHandler.ProxyToApi("/api/v1/token/%s/%s", "app", "token_name"))
tokenGroup.DELETE("/:app/:token_name", userProxyHandler.ProxyToApi("/api/v1/token/%s/%s", "app", "token_name"))
// check token info
tokenGroup.GET("/:token_value", middlewareCollection.Auth.NeedAPIKey, userProxyHandler.ProxyToApi("/api/v1/token/%s", "token_value"))
tokenGroup.GET("/:token_value", userProxyHandler.ProxyToApi("/api/v1/token/%s", "token_value"))
}

sshKeyHandler, err := handler.NewSSHKeyHandler(config)
Expand All @@ -310,6 +310,7 @@
apiGroup.POST("/jwt/token", middlewareCollection.Auth.NeedAPIKey, userProxyHandler.Proxy)
apiGroup.GET("/jwt/:token", middlewareCollection.Auth.NeedAPIKey, userProxyHandler.ProxyToApi("/api/v1/jwt/%s", "token"))
apiGroup.GET("/users", userProxyHandler.Proxy)
apiGroup.GET("/users/stream-export", middlewareCollection.Auth.NeedAdmin, userProxyHandler.Proxy)

// callback
callbackCtrl, err := callback.NewGitCallbackHandler(config)
Expand Down Expand Up @@ -545,6 +546,19 @@
return nil, fmt.Errorf("error creating webhook routes: %w", err)
}

// agent
agentHandler, err := handler.NewAgentHandler(config)
if err != nil {
return nil, fmt.Errorf("error creating agent handler: %w", err)
}
createAgentRoutes(apiGroup, middlewareCollection, agentHandler)

Check failure on line 554 in api/router/api.go

View workflow job for this annotation

GitHub Actions / lint

undefined: createAgentRoutes (typecheck)

finetuneJobHandler, err := handler.NewFinetuneHandler(config)
if err != nil {
return nil, fmt.Errorf("error creating finetune job handler: %w", err)
}
createFinetuneRoutes(apiGroup, middlewareCollection, finetuneJobHandler)

return r, nil
}

Expand All @@ -559,6 +573,16 @@
}
}

func createFinetuneRoutes(apiGroup *gin.RouterGroup, middlewareCollection middleware.MiddlewareCollection, finetuneJobHandler *handler.FinetuneHandler) {
ftGroup := apiGroup.Group("/finetunes")
ftGroup.Use(middlewareCollection.Auth.NeedLogin)
{
ftGroup.POST("", finetuneJobHandler.RunFinetuneJob)
ftGroup.GET("/:id", finetuneJobHandler.GetFinetuneJob)
ftGroup.DELETE("/:id", finetuneJobHandler.DeleteFinetuneJob)
}
}

func createModelRoutes(config *config.Config,
apiGroup *gin.RouterGroup,
middlewareCollection middleware.MiddlewareCollection,
Expand All @@ -571,7 +595,7 @@
modelsGroup := apiGroup.Group("/models")
modelsGroup.Use(middleware.RepoType(types.ModelRepo), middlewareCollection.Repo.RepoExists)
{
modelsGroup.POST("", middlewareCollection.Auth.NeedPhoneVerified, modelHandler.Create)
modelsGroup.POST("", middlewareCollection.Auth.NeedPhoneVerified, middlewareCollection.API.RateLimter, modelHandler.Create)
modelsGroup.GET("", cache.Cache(memoryStore, time.Minute, middleware.CacheStrategyTrendingRepos()), modelHandler.Index)
modelsGroup.PUT("/:namespace/:name", middlewareCollection.Auth.NeedLogin, modelHandler.Update)
modelsGroup.DELETE("/:namespace/:name", middlewareCollection.Auth.NeedLogin, modelHandler.Delete)
Expand Down Expand Up @@ -719,7 +743,7 @@
// must login
datasetsGroup.Use(middleware.RepoType(types.DatasetRepo), middlewareCollection.Repo.RepoExists)
{
datasetsGroup.POST("", middlewareCollection.Auth.NeedPhoneVerified, dsHandler.Create)
datasetsGroup.POST("", middlewareCollection.Auth.NeedPhoneVerified, middlewareCollection.API.RateLimter, dsHandler.Create)
datasetsGroup.PUT("/:namespace/:name", middleware.MustLogin(), dsHandler.Update)
datasetsGroup.DELETE("/:namespace/:name", middleware.MustLogin(), dsHandler.Delete)
datasetsGroup.GET("/:namespace/:name", dsHandler.Show)
Expand Down Expand Up @@ -773,7 +797,7 @@
codesGroup := apiGroup.Group("/codes")
codesGroup.Use(middleware.RepoType(types.CodeRepo), middlewareCollection.Repo.RepoExists)
{
codesGroup.POST("", middlewareCollection.Auth.NeedPhoneVerified, codeHandler.Create)
codesGroup.POST("", middlewareCollection.Auth.NeedPhoneVerified, middlewareCollection.API.RateLimter, codeHandler.Create)
codesGroup.GET("", codeHandler.Index)
codesGroup.PUT("/:namespace/:name", middlewareCollection.Auth.NeedLogin, codeHandler.Update)
codesGroup.DELETE("/:namespace/:name", middlewareCollection.Auth.NeedLogin, codeHandler.Delete)
Expand Down Expand Up @@ -855,9 +879,9 @@
{
// list all spaces
spaces.GET("", spaceHandler.Index)
spaces.POST("", middlewareCollection.Auth.NeedPhoneVerified, spaceHandler.Create)
spaces.POST("", middlewareCollection.Auth.NeedPhoneVerified, middlewareCollection.API.RateLimter, spaceHandler.Create)
// show a user or org's space
spaces.GET("/:namespace/:name", middlewareCollection.Auth.NeedLogin, spaceHandler.Show)
spaces.GET("/:namespace/:name", middlewareCollection.Auth.NeedLogin, middlewareCollection.API.RateLimter, spaceHandler.Show)
spaces.PUT("/:namespace/:name", middlewareCollection.Auth.NeedLogin, spaceHandler.Update)
spaces.DELETE("/:namespace/:name", middlewareCollection.Auth.NeedLogin, spaceHandler.Delete)
// depoly and start running the space
Expand Down Expand Up @@ -979,6 +1003,7 @@
{
apiGroup.GET("/user/:username/run/:repo_type", middlewareCollection.Auth.UserMatch, userHandler.GetRunDeploys)
apiGroup.GET("/user/:username/finetune/instances", middlewareCollection.Auth.UserMatch, userHandler.GetFinetuneInstances)
apiGroup.GET("/user/:username/finetune/jobs", middlewareCollection.Auth.UserMatch, userHandler.GetUserFinetunes)
// User evaluations
apiGroup.GET("/user/:username/evaluations", middlewareCollection.Auth.UserMatch, userHandler.GetEvaluations)
// User notebooks
Expand Down Expand Up @@ -1114,14 +1139,14 @@
}

func createDiscussionRoutes(apiGroup *gin.RouterGroup, middlewareCollection middleware.MiddlewareCollection, discussionHandler *handler.DiscussionHandler) {
apiGroup.POST("/:repo_type/:namespace/:name/discussions", middlewareCollection.Auth.NeedPhoneVerified, discussionHandler.CreateRepoDiscussion)
apiGroup.POST("/:repo_type/:namespace/:name/discussions", middlewareCollection.Auth.NeedPhoneVerified, middlewareCollection.API.RateLimter, discussionHandler.CreateRepoDiscussion)
apiGroup.GET("/:repo_type/:namespace/:name/discussions", discussionHandler.ListRepoDiscussions)
apiGroup.GET("/discussions/:id", discussionHandler.ShowDiscussion)
apiGroup.PUT("/discussions/:id", middlewareCollection.Auth.NeedLogin, discussionHandler.UpdateDiscussion)
apiGroup.PUT("/discussions/:id", middlewareCollection.Auth.NeedLogin, middlewareCollection.API.RateLimter, discussionHandler.UpdateDiscussion)
apiGroup.DELETE("/discussions/:id", middlewareCollection.Auth.NeedLogin, discussionHandler.DeleteDiscussion)
apiGroup.POST("/discussions/:id/comments", middlewareCollection.Auth.NeedPhoneVerified, discussionHandler.CreateDiscussionComment)
apiGroup.GET("/discussions/:id/comments", discussionHandler.ListDiscussionComments)
apiGroup.PUT("/discussions/:id/comments/:comment_id", middlewareCollection.Auth.NeedLogin, discussionHandler.UpdateComment)
apiGroup.POST("/discussions/:id/comments", middlewareCollection.Auth.NeedPhoneVerified, middlewareCollection.API.RateLimter, discussionHandler.CreateDiscussionComment)
apiGroup.GET("/discussions/:id/comments", middlewareCollection.API.RateLimter, discussionHandler.ListDiscussionComments)
apiGroup.PUT("/discussions/:id/comments/:comment_id", middlewareCollection.Auth.NeedLogin, middlewareCollection.API.RateLimter, discussionHandler.UpdateComment)
apiGroup.DELETE("/discussions/:id/comments/:comment_id", middlewareCollection.Auth.NeedLogin, discussionHandler.DeleteComment)
}

Expand Down
Loading
Loading