Skip to content

Commit eb31574

Browse files
committed
Bring some auth utils here
1 parent fd26b27 commit eb31574

File tree

3 files changed

+131
-6
lines changed

3 files changed

+131
-6
lines changed

auth/auth.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
"reflect"
6+
7+
"golang.org/x/xerrors"
8+
)
9+
10+
type Permission string
11+
12+
type permKey int
13+
14+
var permCtxKey permKey
15+
16+
func WithPerm(ctx context.Context, perms []Permission) context.Context {
17+
return context.WithValue(ctx, permCtxKey, perms)
18+
}
19+
20+
func HasPerm(ctx context.Context, defaultPerms []Permission, perm Permission) bool {
21+
callerPerms, ok := ctx.Value(permCtxKey).([]Permission)
22+
if !ok {
23+
callerPerms = defaultPerms
24+
}
25+
26+
for _, callerPerm := range callerPerms {
27+
if callerPerm == perm {
28+
return true
29+
}
30+
}
31+
return false
32+
}
33+
34+
func PermissionedProxy(validPerms, defaultPerms []Permission, in interface{}, out interface{}) {
35+
rint := reflect.ValueOf(out).Elem()
36+
ra := reflect.ValueOf(in)
37+
38+
for f := 0; f < rint.NumField(); f++ {
39+
field := rint.Type().Field(f)
40+
requiredPerm := Permission(field.Tag.Get("perm"))
41+
if requiredPerm == "" {
42+
panic("missing 'perm' tag on " + field.Name) // ok
43+
}
44+
45+
// Validate perm tag
46+
ok := false
47+
for _, perm := range validPerms {
48+
if requiredPerm == perm {
49+
ok = true
50+
break
51+
}
52+
}
53+
if !ok {
54+
panic("unknown 'perm' tag on " + field.Name) // ok
55+
}
56+
57+
fn := ra.MethodByName(field.Name)
58+
59+
rint.Field(f).Set(reflect.MakeFunc(field.Type, func(args []reflect.Value) (results []reflect.Value) {
60+
ctx := args[0].Interface().(context.Context)
61+
if HasPerm(ctx, defaultPerms, requiredPerm) {
62+
return fn.Call(args)
63+
}
64+
65+
err := xerrors.Errorf("missing permission to invoke '%s' (need '%s')", field.Name, requiredPerm)
66+
rerr := reflect.ValueOf(&err).Elem()
67+
68+
if field.Type.NumOut() == 2 {
69+
return []reflect.Value{
70+
reflect.Zero(field.Type.Out(0)),
71+
rerr,
72+
}
73+
} else {
74+
return []reflect.Value{rerr}
75+
}
76+
}))
77+
78+
}
79+
}

auth/handler.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"strings"
7+
8+
logging "github.com/ipfs/go-log/v2"
9+
)
10+
11+
var log = logging.Logger("auth")
12+
13+
type Handler struct {
14+
Verify func(ctx context.Context, token string) ([]Permission, error)
15+
Next http.HandlerFunc
16+
}
17+
18+
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
19+
ctx := r.Context()
20+
21+
token := r.Header.Get("Authorization")
22+
if token == "" {
23+
token = r.FormValue("token")
24+
if token != "" {
25+
token = "Bearer " + token
26+
}
27+
}
28+
29+
if token != "" {
30+
if !strings.HasPrefix(token, "Bearer ") {
31+
log.Warn("missing Bearer prefix in auth header")
32+
w.WriteHeader(401)
33+
return
34+
}
35+
token = strings.TrimPrefix(token, "Bearer ")
36+
37+
allow, err := h.Verify(ctx, token)
38+
if err != nil {
39+
log.Warnf("JWT Verification failed: %s", err)
40+
w.WriteHeader(401)
41+
return
42+
}
43+
44+
ctx = WithPerm(ctx, allow)
45+
}
46+
47+
h.Next(w, r.WithContext(ctx))
48+
}

metrics/metrics.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package metrics
22

3-
43
import (
54
"go.opencensus.io/stats"
65
"go.opencensus.io/stats/view"
@@ -9,14 +8,14 @@ import (
98

109
// Global Tags
1110
var (
12-
RPCMethod, _ = tag.NewKey("method")
11+
RPCMethod, _ = tag.NewKey("method")
1312
)
1413

1514
// Measures
1615
var (
17-
RPCInvalidMethod = stats.Int64("rpc/invalid_method", "Total number of invalid RPC methods called", stats.UnitDimensionless)
18-
RPCRequestError = stats.Int64("rpc/request_error", "Total number of request errors handled", stats.UnitDimensionless)
19-
RPCResponseError = stats.Int64("rpc/response_error", "Total number of responses errors handled", stats.UnitDimensionless)
16+
RPCInvalidMethod = stats.Int64("rpc/invalid_method", "Total number of invalid RPC methods called", stats.UnitDimensionless)
17+
RPCRequestError = stats.Int64("rpc/request_error", "Total number of request errors handled", stats.UnitDimensionless)
18+
RPCResponseError = stats.Int64("rpc/response_error", "Total number of responses errors handled", stats.UnitDimensionless)
2019
)
2120

2221
var (
@@ -44,4 +43,3 @@ var DefaultViews = []*view.View{
4443
RPCRequestErrorView,
4544
RPCResponseErrorView,
4645
}
47-

0 commit comments

Comments
 (0)