Skip to content

Commit

Permalink
Improve context header.
Browse files Browse the repository at this point in the history
  • Loading branch information
zensh committed Sep 7, 2023
1 parent eb4369e commit abd084b
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 31 deletions.
9 changes: 4 additions & 5 deletions src/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,10 @@ func NewApp() *gear.App {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

headers := http.Header{}
headers.Set("x-auth-user", util.JARVIS.String())
headers.Set("x-auth-app", util.JARVIS.String())
ctxHeader := util.ContextHTTPHeader(headers)
ctx = gear.CtxWith[util.ContextHTTPHeader](ctx, &ctxHeader)
h := http.Header{}
h.Set("x-auth-user", util.JARVIS.String())
h.Set("x-auth-app", util.JARVIS.String())
ctx = gear.CtxWith[util.CtxHeader](ctx, util.Ptr(util.CtxHeader(h)))
if err := blls.Jarvis.InitApp(ctx, app); err != nil {
return err
}
Expand Down
14 changes: 14 additions & 0 deletions src/api/router.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package api

import (
"net/http"

"github.com/teambition/gear"

"github.com/yiwen-ai/yiwen-api/src/bll"
Expand Down Expand Up @@ -45,6 +47,18 @@ func todo(ctx *gear.Context) error {
func newRouters(apis *APIs) []*gear.Router {

router := gear.NewRouter()
router.Use(func(ctx *gear.Context) error {
h := http.Header{}
// inject headers into context for base service
util.CopyHeader(h, ctx.Req.Header,
"x-real-ip",
"x-request-id",
)

ctx.WithContext(gear.CtxWith[util.CtxHeader](ctx.Context(), util.Ptr(util.CtxHeader(h))))
return nil
})

router.Get("/healthz", apis.Healthz.Get)

// 允许匿名访问
Expand Down
3 changes: 1 addition & 2 deletions src/api/scraping.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ func (a *Scraping) Convert(ctx *gear.Context) error {
return err
}

header := gear.CtxValue[util.ContextHTTPHeader](ctx)
http.Header(*header).Set(gear.HeaderContentType, mtype)
util.HeaderFromCtx(ctx).Set(gear.HeaderContentType, mtype)
output, err := a.blls.Webscraper.Convert(ctx, buf, mtype)
if err != nil {
return gear.ErrInternalServerError.From(err)
Expand Down
21 changes: 5 additions & 16 deletions src/middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package middleware

import (
"context"
"net/http"
"strconv"
"strings"

Expand Down Expand Up @@ -72,18 +71,13 @@ func (m AuthLevel) Auth(ctx *gear.Context) error {
log["uid"] = sess.UserID

ctx.Req.Header.Set("x-auth-user", sess.UserID.String())
ctxHeader := make(http.Header)
// inject auth headers into context for base service
util.CopyHeader(ctxHeader, ctx.Req.Header,
"x-real-ip",
"x-request-id",
util.CopyHeader(util.HeaderFromCtx(ctx), ctx.Req.Header,
"x-auth-user",
"x-language",
)

cctx := gear.CtxWith[Session](ctx.Context(), sess)
cheader := util.ContextHTTPHeader(ctxHeader)
ctx.WithContext(gear.CtxWith[util.ContextHTTPHeader](cctx, &cheader))
ctx.WithContext(gear.CtxWith[Session](ctx.Context(), sess))
return nil
}

Expand All @@ -98,20 +92,15 @@ func (m AuthLevel) Auth(ctx *gear.Context) error {
}
log["aud"] = sess.AppID

ctxHeader := make(http.Header)
// inject auth headers into context for base service
util.CopyHeader(ctxHeader, ctx.Req.Header,
"x-real-ip",
"x-request-id",
util.CopyHeader(util.HeaderFromCtx(ctx), ctx.Req.Header,
"x-auth-user",
"x-auth-user-rating",
"x-auth-app",
"x-language",
)

cctx := gear.CtxWith[Session](ctx.Context(), sess)
cheader := util.ContextHTTPHeader(ctxHeader)
ctx.WithContext(gear.CtxWith[util.ContextHTTPHeader](cctx, &cheader))
ctx.WithContext(gear.CtxWith[Session](ctx.Context(), sess))
return nil
}

Expand All @@ -120,7 +109,7 @@ func WithGlobalCtx(ctx *gear.Context) context.Context {

if sess := gear.CtxValue[Session](ctx); sess != nil {
gctx = gear.CtxWith[Session](gctx, sess)
gctx = gear.CtxWith[util.ContextHTTPHeader](gctx, gear.CtxValue[util.ContextHTTPHeader](ctx))
gctx = gear.CtxWith[util.CtxHeader](gctx, gear.CtxValue[util.CtxHeader](ctx))
}

return gctx
Expand Down
21 changes: 16 additions & 5 deletions src/util/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,18 @@ var HTTPClient = &http.Client{
Timeout: time.Second * 5,
}

type ContextHTTPHeader http.Header
type CtxHeader http.Header

func (ch CtxHeader) Header() http.Header {
return http.Header(ch)
}

func HeaderFromCtx(ctx context.Context) http.Header {
if ch := gear.CtxValue[CtxHeader](ctx); ch != nil {
return ch.Header()
}
return nil
}

func RequestJSON(ctx context.Context, cli *http.Client, method, api string, input, output any) error {
if ctx.Err() != nil {
Expand Down Expand Up @@ -94,8 +105,8 @@ func RequestJSON(ctx context.Context, cli *http.Client, method, api string, inpu
req.Header.Set(gear.HeaderContentType, gear.MIMEApplicationJSON)
}

if header := gear.CtxValue[ContextHTTPHeader](ctx); header != nil {
CopyHeader(req.Header, http.Header(*header))
if header := HeaderFromCtx(ctx); header != nil {
CopyHeader(req.Header, header)
}

rid := req.Header.Get(gear.HeaderXRequestID)
Expand Down Expand Up @@ -155,8 +166,8 @@ func RequestCBOR(ctx context.Context, cli *http.Client, method, api string, inpu
req.Header.Set(gear.HeaderContentType, gear.MIMEApplicationCBOR)
}

if header := gear.CtxValue[ContextHTTPHeader](ctx); header != nil {
CopyHeader(req.Header, http.Header(*header))
if header := HeaderFromCtx(ctx); header != nil {
CopyHeader(req.Header, header)
}

rid := req.Header.Get(gear.HeaderXRequestID)
Expand Down
6 changes: 3 additions & 3 deletions tests/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@ func TestAPI(t *testing.T) {

var targetUrl string = "https://datatracker.ietf.org/doc/html/rfc8949"

authheaders := util.ContextHTTPHeader{}
authheaders := util.CtxHeader{}
uuid := util.NewUUID()
fmt.Printf("UUID: %s\n", uuid.String())

http.Header(authheaders).Set("x-request-id", uuid.String())
http.Header(authheaders).Set("cookie", cookie)

ctx := gear.CtxWith[util.ContextHTTPHeader](context.Background(), &authheaders)
ctx := gear.CtxWith[util.CtxHeader](context.Background(), &authheaders)
sess, err := GetToken(ctx)
require.NoError(t, err)

http.Header(authheaders).Set("authorization", "Bearer "+sess.AccessToken)
ctx = gear.CtxWith[util.ContextHTTPHeader](context.Background(), &authheaders)
ctx = gear.CtxWith[util.CtxHeader](context.Background(), &authheaders)

myGroups, err := ListMyGroups(ctx)
require.NoError(t, err)
Expand Down

0 comments on commit abd084b

Please sign in to comment.