Skip to content

Commit cd4bed8

Browse files
authored
feat: add ConnectionType(ctx) for called methods to use (#118)
1 parent 61205be commit cd4bed8

File tree

2 files changed

+54
-8
lines changed

2 files changed

+54
-8
lines changed

rpc_test.go

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ import (
2727
)
2828

2929
func init() {
30-
if err := logging.SetLogLevel("rpc", "DEBUG"); err != nil {
31-
panic(err)
30+
if _, exists := os.LookupEnv("GOLOG_LOG_LEVEL"); !exists {
31+
if err := logging.SetLogLevel("rpc", "DEBUG"); err != nil {
32+
panic(err)
33+
}
3234
}
3335

3436
debugTrace = true
@@ -497,15 +499,17 @@ func TestParallelRPC(t *testing.T) {
497499
type CtxHandler struct {
498500
lk sync.Mutex
499501

500-
cancelled bool
501-
i int
502+
cancelled bool
503+
i int
504+
connectionType ConnectionType
502505
}
503506

504507
func (h *CtxHandler) Test(ctx context.Context) {
505508
h.lk.Lock()
506509
defer h.lk.Unlock()
507510
timeout := time.After(300 * time.Millisecond)
508511
h.i++
512+
h.connectionType = GetConnectionType(ctx)
509513

510514
select {
511515
case <-timeout:
@@ -543,6 +547,9 @@ func TestCtx(t *testing.T) {
543547
if !serverHandler.cancelled {
544548
t.Error("expected cancellation on the server side")
545549
}
550+
if serverHandler.connectionType != ConnectionTypeWS {
551+
t.Error("wrong connection type")
552+
}
546553

547554
serverHandler.cancelled = false
548555

@@ -564,6 +571,9 @@ func TestCtx(t *testing.T) {
564571
if serverHandler.cancelled || serverHandler.i != 2 {
565572
t.Error("wrong serverHandler state")
566573
}
574+
if serverHandler.connectionType != ConnectionTypeWS {
575+
t.Error("wrong connection type")
576+
}
567577

568578
serverHandler.lk.Unlock()
569579
closer()
@@ -598,6 +608,9 @@ func TestCtxHttp(t *testing.T) {
598608
if !serverHandler.cancelled {
599609
t.Error("expected cancellation on the server side")
600610
}
611+
if serverHandler.connectionType != ConnectionTypeHTTP {
612+
t.Error("wrong connection type")
613+
}
601614

602615
serverHandler.cancelled = false
603616

@@ -619,6 +632,10 @@ func TestCtxHttp(t *testing.T) {
619632
if serverHandler.cancelled || serverHandler.i != 2 {
620633
t.Error("wrong serverHandler state")
621634
}
635+
// connection type should have switched to WS
636+
if serverHandler.connectionType != ConnectionTypeWS {
637+
t.Error("wrong connection type")
638+
}
622639

623640
serverHandler.lk.Unlock()
624641
closer()
@@ -1007,10 +1024,12 @@ func TestChanClientReceiveAll(t *testing.T) {
10071024
}
10081025

10091026
func TestControlChanDeadlock(t *testing.T) {
1010-
_ = logging.SetLogLevel("rpc", "error")
1011-
defer func() {
1012-
_ = logging.SetLogLevel("rpc", "debug")
1013-
}()
1027+
if _, exists := os.LookupEnv("GOLOG_LOG_LEVEL"); !exists {
1028+
_ = logging.SetLogLevel("rpc", "error")
1029+
defer func() {
1030+
_ = logging.SetLogLevel("rpc", "DEBUG")
1031+
}()
1032+
}
10141033

10151034
for r := 0; r < 20; r++ {
10161035
testControlChanDeadlock(t)

server.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,31 @@ const (
2020
rpcInvalidParams = -32602
2121
)
2222

23+
// ConnectionType indicates the type of connection, this is set in the context and can be retrieved
24+
// with GetConnectionType.
25+
type ConnectionType string
26+
27+
const (
28+
// ConnectionTypeUnknown indicates that the connection type cannot be determined, likely because
29+
// it hasn't passed through an RPCServer.
30+
ConnectionTypeUnknown ConnectionType = "unknown"
31+
// ConnectionTypeHTTP indicates that the connection is an HTTP connection.
32+
ConnectionTypeHTTP ConnectionType = "http"
33+
// ConnectionTypeWS indicates that the connection is a WebSockets connection.
34+
ConnectionTypeWS ConnectionType = "websockets"
35+
)
36+
37+
var connectionTypeCtxKey = &struct{ name string }{"jsonrpc-connection-type"}
38+
39+
// GetConnectionType returns the connection type of the request if it was set by an RPCServer.
40+
// A connection type of ConnectionTypeUnknown means that the connection type was not set.
41+
func GetConnectionType(ctx context.Context) ConnectionType {
42+
if v := ctx.Value(connectionTypeCtxKey); v != nil {
43+
return v.(ConnectionType)
44+
}
45+
return ConnectionTypeUnknown
46+
}
47+
2348
// RPCServer provides a jsonrpc 2.0 http server handler
2449
type RPCServer struct {
2550
*handler
@@ -97,10 +122,12 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
97122

98123
h := strings.ToLower(r.Header.Get("Connection"))
99124
if strings.Contains(h, "upgrade") {
125+
ctx = context.WithValue(ctx, connectionTypeCtxKey, ConnectionTypeWS)
100126
s.handleWS(ctx, w, r)
101127
return
102128
}
103129

130+
ctx = context.WithValue(ctx, connectionTypeCtxKey, ConnectionTypeHTTP)
104131
s.handleReader(ctx, r.Body, w, rpcError)
105132
}
106133

0 commit comments

Comments
 (0)