Skip to content

multi: extract session ID from context #1045

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions firewall/privacy_mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,11 @@ func (p *PrivacyMapper) Intercept(ctx context.Context,
"interception request: %v", err)
}

sessionID, err := session.IDFromMacaroon(ri.Macaroon)
sessionID, err := ri.SessionID.UnwrapOrErr(
fmt.Errorf("no session ID found in request info"),
)
if err != nil {
return nil, fmt.Errorf("could not extract ID from macaroon")
return nil, err
}

log.Tracef("PrivacyMapper: Intercepting %v", ri)
Expand Down
9 changes: 9 additions & 0 deletions firewall/privacy_mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/rpcperms"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
"gopkg.in/macaroon-bakery.v2/bakery"
"gopkg.in/macaroon.v2"
Expand Down Expand Up @@ -907,6 +908,9 @@ func TestPrivacyMapper(t *testing.T) {
rawMsg, err := proto.Marshal(test.msg)
require.NoError(t, err)

md := make(metadata.MD)
session.AddToGRPCMetadata(md, sessionID)

interceptReq := &rpcperms.InterceptionRequest{
Type: test.msgType,
Macaroon: mac,
Expand All @@ -916,6 +920,7 @@ func TestPrivacyMapper(t *testing.T) {
ProtoTypeName: string(
proto.MessageName(test.msg),
),
CtxMetadataPairs: md,
}

mwReq, err := interceptReq.ToRPC(1, 2)
Expand Down Expand Up @@ -1006,6 +1011,9 @@ func TestPrivacyMapper(t *testing.T) {
amounts := make([]uint64, numSamples)
timestamps := make([]uint64, numSamples)

md := make(metadata.MD)
session.AddToGRPCMetadata(md, sessionID)

for i := 0; i < numSamples; i++ {
interceptReq := &rpcperms.InterceptionRequest{
Type: rpcperms.TypeResponse,
Expand All @@ -1016,6 +1024,7 @@ func TestPrivacyMapper(t *testing.T) {
ProtoTypeName: string(
proto.MessageName(msg),
),
CtxMetadataPairs: md,
}

mwReq, err := interceptReq.ToRPC(1, 2)
Expand Down
18 changes: 18 additions & 0 deletions firewall/request_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import (
"fmt"
"strings"

"github.com/lightninglabs/lightning-terminal/session"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/lnrpc"
"google.golang.org/grpc/metadata"
"gopkg.in/macaroon.v2"
)

Expand All @@ -25,6 +28,7 @@ const (
// RequestInfo stores the parsed representation of an incoming RPC middleware
// request.
type RequestInfo struct {
SessionID fn.Option[session.ID]
MsgID uint64
RequestID uint64
MWRequestType string
Expand Down Expand Up @@ -76,8 +80,22 @@ func NewInfoFromRequest(req *lnrpc.RPCMiddlewareRequest) (*RequestInfo, error) {
return nil, fmt.Errorf("invalid request type: %T", t)
}

md := make(metadata.MD)
for k, vs := range req.MetadataPairs {
for _, v := range vs.Values {
md.Append(k, v)
}
}

sessionID, err := session.FromGRPCMetadata(md)
if err != nil {
return nil, fmt.Errorf("error extracting session ID "+
"from request: %v", err)
}

ri.MsgID = req.MsgId
ri.RequestID = req.RequestId
ri.SessionID = sessionID

// If there is no macaroon in the request, then there is nothing left
// to parse.
Expand Down
1 change: 1 addition & 0 deletions firewall/request_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ func (r *RequestLogger) addNewAction(ctx context.Context, ri *RequestInfo,
}

actionReq := &firewalldb.AddActionReq{
SessionID: ri.SessionID,
MacaroonIdentifier: macaroonID,
RPCMethod: ri.URI,
}
Expand Down
6 changes: 4 additions & 2 deletions firewall/rule_enforcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,11 @@ func (r *RuleEnforcer) Intercept(ctx context.Context,
func (r *RuleEnforcer) handleRequest(ctx context.Context,
ri *RequestInfo) (proto.Message, error) {

sessionID, err := session.IDFromMacaroon(ri.Macaroon)
sessionID, err := ri.SessionID.UnwrapOrErr(
fmt.Errorf("no session ID found in request info"),
)
if err != nil {
return nil, fmt.Errorf("could not extract ID from macaroon")
return nil, err
}

rules, err := r.collectEnforcers(ctx, ri, sessionID)
Expand Down
57 changes: 57 additions & 0 deletions session/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package session

import (
"encoding/hex"
"fmt"

"github.com/lightningnetwork/lnd/fn"
"google.golang.org/grpc/metadata"
)

// contextKey is a struct that is used as a key for storing session IDs
// in a context. Using this unexported type prevents collisions with other
// context keys that may be used in the same context. However, this only
// applies if the context is passed around in the same binary and not if the
// value is converted to grpc metadata and sent over the wire. In that case,
// we need to use a string key to avoid collisions with other metadata keys.
type contextKey struct {
name string
}

// sessionIDCtxKey is the context key used to store the session ID in
// a context. The key is a string to avoid collisions with other context values
// that may also be included in grpc metadata which is why we add the 'lit'
// prefix.
var sessionIDCtxKey = contextKey{"lit_session_id"}

// FromGRPCMetadata extracts the session ID from the given gRPC metadata kv
// pairs if one is found.
func FromGRPCMetadata(md metadata.MD) (fn.Option[ID], error) {
val := md.Get(sessionIDCtxKey.name)
if len(val) == 0 {
return fn.None[ID](), nil
}

if len(val) != 1 {
return fn.None[ID](), fmt.Errorf("more than one session ID "+
"found in gRPC metadata: %v", val)
}

b, err := hex.DecodeString(val[0])
if err != nil {
return fn.None[ID](), err
}

sessID, err := IDFromBytes(b)
if err != nil {
return fn.None[ID](), err
}

return fn.Some(sessID), nil
}

// AddToGRPCMetadata adds the session ID to the given gRPC metadata kv pairs.
// The session ID is encoded as a hex string.
func AddToGRPCMetadata(md metadata.MD, id ID) {
md.Set(sessionIDCtxKey.name, hex.EncodeToString(id[:]))
}
5 changes: 3 additions & 2 deletions session/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ import (

type sessionID [33]byte

type GRPCServerCreator func(opts ...grpc.ServerOption) *grpc.Server
type GRPCServerCreator func(sessionID ID,
opts ...grpc.ServerOption) *grpc.Server

type mailboxSession struct {
server *grpc.Server
Expand Down Expand Up @@ -70,7 +71,7 @@ func (m *mailboxSession) start(session *Session,
}

noiseConn := mailbox.NewNoiseGrpcConn(keys)
m.server = serverCreator(grpc.Creds(noiseConn))
m.server = serverCreator(session.ID, grpc.Creds(noiseConn))

m.wg.Add(1)
go m.run(mailboxServer)
Expand Down
74 changes: 72 additions & 2 deletions session_rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/macaroons"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"gopkg.in/macaroon-bakery.v2/bakery"
"gopkg.in/macaroon-bakery.v2/bakery/checkers"
"gopkg.in/macaroon.v2"
Expand Down Expand Up @@ -77,10 +78,23 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer,
// actual mailbox server that spins up the Terminal Connect server
// interface.
server := session.NewServer(
func(opts ...grpc.ServerOption) *grpc.Server {
allOpts := append(cfg.grpcOptions, opts...)
func(id session.ID, opts ...grpc.ServerOption) *grpc.Server {
// Add the session ID injector interceptors first so
// that the session ID is available in the context of
// all interceptors that come after.
allOpts := []grpc.ServerOption{
addSessionIDToStreamCtx(id),
addSessionIDToUnaryCtx(id),
}

allOpts = append(allOpts, cfg.grpcOptions...)
allOpts = append(allOpts, opts...)

// Construct the gRPC server with the options.
grpcServer := grpc.NewServer(allOpts...)

// Register various grpc servers with the LNC session
// server.
cfg.registerGrpcServers(grpcServer)

return grpcServer
Expand All @@ -94,6 +108,62 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer,
}, nil
}

// wrappedServerStream is a wrapper around the grpc.ServerStream that allows us
// to set a custom context. This is needed since the stream handler function
// doesn't take a context as an argument, but rather has a Context method on the
// handler itself. So we use this custom wrapper to override this method.
type wrappedServerStream struct {
grpc.ServerStream
ctx context.Context
}

// Context returns the context of the stream.
//
// NOTE: This implements the grpc.ServerStream Context method.
func (w *wrappedServerStream) Context() context.Context {
return w.ctx
}

// addSessionIDToStreamCtx is a gRPC stream interceptor that adds the given
// session ID to the context of the stream. This allows us to access the
// session ID later on for any gRPC calls made through this stream.
func addSessionIDToStreamCtx(id session.ID) grpc.ServerOption {
return grpc.StreamInterceptor(func(srv any, ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler) error {

md, _ := metadata.FromIncomingContext(ss.Context())
mdCopy := md.Copy()
session.AddToGRPCMetadata(mdCopy, id)

// Wrap the original stream with our custom context.
wrapped := &wrappedServerStream{
ServerStream: ss,
ctx: metadata.NewIncomingContext(
ss.Context(), mdCopy,
),
}

return handler(srv, wrapped)
})
}

// addSessionIDToUnaryCtx is a gRPC unary interceptor that adds the given
// session ID to the context of the unary call. This allows us to access the
// session ID later on for any gRPC calls made through this context.
func addSessionIDToUnaryCtx(id session.ID) grpc.ServerOption {
return grpc.UnaryInterceptor(func(ctx context.Context, req any,
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (resp any, err error) {

md, _ := metadata.FromIncomingContext(ctx)
mdCopy := md.Copy()
session.AddToGRPCMetadata(mdCopy, id)

return handler(metadata.NewIncomingContext(ctx, mdCopy), req)
})
}

// start all the components necessary for the sessionRpcServer to start serving
// requests. This includes resuming all non-revoked sessions.
func (s *sessionRpcServer) start(ctx context.Context) error {
Expand Down
Loading