diff --git a/src/api/app.go b/src/api/app.go index 11a8c01..e233b5b 100644 --- a/src/api/app.go +++ b/src/api/app.go @@ -38,11 +38,10 @@ func NewApp() *gear.App { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - headers := http.Header{} - headers.Set("x-auth-user", util.JARVIS.String()) - headers.Set("x-auth-app", util.JARVIS.String()) - ctxHeader := util.ContextHTTPHeader(headers) - ctx = gear.CtxWith[util.ContextHTTPHeader](ctx, &ctxHeader) + h := http.Header{} + h.Set("x-auth-user", util.JARVIS.String()) + h.Set("x-auth-app", util.JARVIS.String()) + ctx = gear.CtxWith[util.CtxHeader](ctx, util.Ptr(util.CtxHeader(h))) if err := blls.Jarvis.InitApp(ctx, app); err != nil { return err } diff --git a/src/api/router.go b/src/api/router.go index ab6771c..a0075b3 100644 --- a/src/api/router.go +++ b/src/api/router.go @@ -1,6 +1,8 @@ package api import ( + "net/http" + "github.com/teambition/gear" "github.com/yiwen-ai/yiwen-api/src/bll" @@ -45,6 +47,18 @@ func todo(ctx *gear.Context) error { func newRouters(apis *APIs) []*gear.Router { router := gear.NewRouter() + router.Use(func(ctx *gear.Context) error { + h := http.Header{} + // inject headers into context for base service + util.CopyHeader(h, ctx.Req.Header, + "x-real-ip", + "x-request-id", + ) + + ctx.WithContext(gear.CtxWith[util.CtxHeader](ctx.Context(), util.Ptr(util.CtxHeader(h)))) + return nil + }) + router.Get("/healthz", apis.Healthz.Get) // 允许匿名访问 diff --git a/src/api/scraping.go b/src/api/scraping.go index 8fd125b..72bff21 100644 --- a/src/api/scraping.go +++ b/src/api/scraping.go @@ -65,8 +65,7 @@ func (a *Scraping) Convert(ctx *gear.Context) error { return err } - header := gear.CtxValue[util.ContextHTTPHeader](ctx) - http.Header(*header).Set(gear.HeaderContentType, mtype) + util.HeaderFromCtx(ctx).Set(gear.HeaderContentType, mtype) output, err := a.blls.Webscraper.Convert(ctx, buf, mtype) if err != nil { return gear.ErrInternalServerError.From(err) diff --git a/src/middleware/auth.go b/src/middleware/auth.go index 8ea8bc0..b57104f 100644 --- a/src/middleware/auth.go +++ b/src/middleware/auth.go @@ -2,7 +2,6 @@ package middleware import ( "context" - "net/http" "strconv" "strings" @@ -72,18 +71,13 @@ func (m AuthLevel) Auth(ctx *gear.Context) error { log["uid"] = sess.UserID ctx.Req.Header.Set("x-auth-user", sess.UserID.String()) - ctxHeader := make(http.Header) // inject auth headers into context for base service - util.CopyHeader(ctxHeader, ctx.Req.Header, - "x-real-ip", - "x-request-id", + util.CopyHeader(util.HeaderFromCtx(ctx), ctx.Req.Header, "x-auth-user", "x-language", ) - cctx := gear.CtxWith[Session](ctx.Context(), sess) - cheader := util.ContextHTTPHeader(ctxHeader) - ctx.WithContext(gear.CtxWith[util.ContextHTTPHeader](cctx, &cheader)) + ctx.WithContext(gear.CtxWith[Session](ctx.Context(), sess)) return nil } @@ -98,20 +92,15 @@ func (m AuthLevel) Auth(ctx *gear.Context) error { } log["aud"] = sess.AppID - ctxHeader := make(http.Header) // inject auth headers into context for base service - util.CopyHeader(ctxHeader, ctx.Req.Header, - "x-real-ip", - "x-request-id", + util.CopyHeader(util.HeaderFromCtx(ctx), ctx.Req.Header, "x-auth-user", "x-auth-user-rating", "x-auth-app", "x-language", ) - cctx := gear.CtxWith[Session](ctx.Context(), sess) - cheader := util.ContextHTTPHeader(ctxHeader) - ctx.WithContext(gear.CtxWith[util.ContextHTTPHeader](cctx, &cheader)) + ctx.WithContext(gear.CtxWith[Session](ctx.Context(), sess)) return nil } @@ -120,7 +109,7 @@ func WithGlobalCtx(ctx *gear.Context) context.Context { if sess := gear.CtxValue[Session](ctx); sess != nil { gctx = gear.CtxWith[Session](gctx, sess) - gctx = gear.CtxWith[util.ContextHTTPHeader](gctx, gear.CtxValue[util.ContextHTTPHeader](ctx)) + gctx = gear.CtxWith[util.CtxHeader](gctx, gear.CtxValue[util.CtxHeader](ctx)) } return gctx diff --git a/src/util/http.go b/src/util/http.go index 9fd5a77..4791285 100644 --- a/src/util/http.go +++ b/src/util/http.go @@ -64,7 +64,18 @@ var HTTPClient = &http.Client{ Timeout: time.Second * 5, } -type ContextHTTPHeader http.Header +type CtxHeader http.Header + +func (ch CtxHeader) Header() http.Header { + return http.Header(ch) +} + +func HeaderFromCtx(ctx context.Context) http.Header { + if ch := gear.CtxValue[CtxHeader](ctx); ch != nil { + return ch.Header() + } + return nil +} func RequestJSON(ctx context.Context, cli *http.Client, method, api string, input, output any) error { if ctx.Err() != nil { @@ -94,8 +105,8 @@ func RequestJSON(ctx context.Context, cli *http.Client, method, api string, inpu req.Header.Set(gear.HeaderContentType, gear.MIMEApplicationJSON) } - if header := gear.CtxValue[ContextHTTPHeader](ctx); header != nil { - CopyHeader(req.Header, http.Header(*header)) + if header := HeaderFromCtx(ctx); header != nil { + CopyHeader(req.Header, header) } rid := req.Header.Get(gear.HeaderXRequestID) @@ -155,8 +166,8 @@ func RequestCBOR(ctx context.Context, cli *http.Client, method, api string, inpu req.Header.Set(gear.HeaderContentType, gear.MIMEApplicationCBOR) } - if header := gear.CtxValue[ContextHTTPHeader](ctx); header != nil { - CopyHeader(req.Header, http.Header(*header)) + if header := HeaderFromCtx(ctx); header != nil { + CopyHeader(req.Header, header) } rid := req.Header.Get(gear.HeaderXRequestID) diff --git a/tests/api_test.go b/tests/api_test.go index d10ac84..d1ee194 100644 --- a/tests/api_test.go +++ b/tests/api_test.go @@ -23,19 +23,19 @@ func TestAPI(t *testing.T) { var targetUrl string = "https://datatracker.ietf.org/doc/html/rfc8949" - authheaders := util.ContextHTTPHeader{} + authheaders := util.CtxHeader{} uuid := util.NewUUID() fmt.Printf("UUID: %s\n", uuid.String()) http.Header(authheaders).Set("x-request-id", uuid.String()) http.Header(authheaders).Set("cookie", cookie) - ctx := gear.CtxWith[util.ContextHTTPHeader](context.Background(), &authheaders) + ctx := gear.CtxWith[util.CtxHeader](context.Background(), &authheaders) sess, err := GetToken(ctx) require.NoError(t, err) http.Header(authheaders).Set("authorization", "Bearer "+sess.AccessToken) - ctx = gear.CtxWith[util.ContextHTTPHeader](context.Background(), &authheaders) + ctx = gear.CtxWith[util.CtxHeader](context.Background(), &authheaders) myGroups, err := ListMyGroups(ctx) require.NoError(t, err)