Skip to content

Commit 86db43a

Browse files
committed
loopdb: store protocol version alongside with swaps
This commit adds the protocol version to each stored swap. This will be used to ensure that when swaps are resumed after a restart, they're correctly handled given any breaking protocol changes.
1 parent a41b7c8 commit 86db43a

7 files changed

+273
-13
lines changed

loopdb/codec.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,42 @@
11
package loopdb
22

3+
import (
4+
"fmt"
5+
)
6+
37
// itob returns an 8-byte big endian representation of v.
48
func itob(v uint64) []byte {
59
b := make([]byte, 8)
610
byteOrder.PutUint64(b, v)
711
return b
812
}
13+
14+
// UnmarshalProtocolVersion attempts to unmarshal a byte slice to a
15+
// ProtocolVersion value. If the unmarshal fails, ProtocolVersionUnrecorded is
16+
// returned along with an error.
17+
func UnmarshalProtocolVersion(b []byte) (ProtocolVersion, error) {
18+
if b == nil {
19+
return ProtocolVersionUnrecorded, nil
20+
}
21+
22+
if len(b) != 4 {
23+
return ProtocolVersionUnrecorded,
24+
fmt.Errorf("invalid size: %v", len(b))
25+
}
26+
27+
version := ProtocolVersion(byteOrder.Uint32(b))
28+
if !version.Valid() {
29+
return ProtocolVersionUnrecorded,
30+
fmt.Errorf("invalid protocol version: %v", version)
31+
}
32+
33+
return version, nil
34+
}
35+
36+
// MarshalProtocolVersion marshals a ProtocolVersion value to a byte slice.
37+
func MarshalProtocolVersion(version ProtocolVersion) []byte {
38+
var versionBytes [4]byte
39+
byteOrder.PutUint32(versionBytes[:], uint32(version))
40+
41+
return versionBytes[:]
42+
}

loopdb/codec_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package loopdb
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
// TestProtocolVersionMarshalUnMarshal tests that marshalling and unmarshalling
10+
// looprpc.ProtocolVersion works correctly.
11+
func TestProtocolVersionMarshalUnMarshal(t *testing.T) {
12+
t.Parallel()
13+
14+
testVersions := [...]ProtocolVersion{
15+
ProtocolVersionLegacy,
16+
ProtocolVersionMultiLoopOut,
17+
ProtocolVersionSegwitLoopIn,
18+
ProtocolVersionPreimagePush,
19+
ProtocolVersionUserExpiryLoopOut,
20+
}
21+
22+
bogusVersion := []byte{0xFF, 0xFF, 0xFF, 0xFF}
23+
invalidSlice := []byte{0xFF, 0xFF, 0xFF}
24+
25+
for i := 0; i < len(testVersions); i++ {
26+
testVersion := testVersions[i]
27+
28+
// Test that unmarshal(marshal(v)) == v.
29+
version, err := UnmarshalProtocolVersion(
30+
MarshalProtocolVersion(testVersion),
31+
)
32+
require.NoError(t, err)
33+
require.Equal(t, testVersion, version)
34+
35+
// Test that unmarshalling a nil slice returns the default
36+
// version along with no error.
37+
version, err = UnmarshalProtocolVersion(nil)
38+
require.NoError(t, err)
39+
require.Equal(t, ProtocolVersionUnrecorded, version)
40+
41+
// Test that unmarshalling an unknown version returns the
42+
// default version along with an error.
43+
version, err = UnmarshalProtocolVersion(bogusVersion)
44+
require.Error(t, err, "expected invalid version")
45+
require.Equal(t, ProtocolVersionUnrecorded, version)
46+
47+
// Test that unmarshalling an invalid slice returns the
48+
// default version along with an error.
49+
version, err = UnmarshalProtocolVersion(invalidSlice)
50+
require.Error(t, err, "expected invalid size")
51+
require.Equal(t, ProtocolVersionUnrecorded, version)
52+
}
53+
}

loopdb/loop.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ type SwapContract struct {
4646

4747
// Label contains an optional label for the swap.
4848
Label string
49+
50+
// ProtocolVersion stores the protocol version when the swap was
51+
// created.
52+
ProtocolVersion ProtocolVersion
4953
}
5054

5155
// Loop contains fields shared between LoopIn and LoopOut

loopdb/protocol_version.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package loopdb
2+
3+
import (
4+
"math"
5+
6+
"github.com/lightninglabs/loop/looprpc"
7+
)
8+
9+
// ProtocolVersion represents the protocol version (declared on rpc level) that
10+
// the client declared to us.
11+
type ProtocolVersion uint32
12+
13+
const (
14+
// ProtocolVersionLegacy indicates that the client is a legacy version
15+
// that did not report its protocol version.
16+
ProtocolVersionLegacy ProtocolVersion = 0
17+
18+
// ProtocolVersionMultiLoopOut indicates that the client supports multi
19+
// loop out.
20+
ProtocolVersionMultiLoopOut ProtocolVersion = 1
21+
22+
// ProtocolVersionSegwitLoopIn indicates that the client supports segwit
23+
// loop in.
24+
ProtocolVersionSegwitLoopIn ProtocolVersion = 2
25+
26+
// ProtocolVersionPreimagePush indicates that the client will push loop
27+
// out preimages to the sever to speed up claim.
28+
ProtocolVersionPreimagePush ProtocolVersion = 3
29+
30+
// ProtocolVersionUserExpiryLoopOut indicates that the client will
31+
// propose a cltv expiry height for loop out.
32+
ProtocolVersionUserExpiryLoopOut ProtocolVersion = 4
33+
34+
// ProtocolVersionUnrecorded is set for swaps were created before we
35+
// started saving protocol version with swaps.
36+
ProtocolVersionUnrecorded ProtocolVersion = math.MaxUint32
37+
38+
// CurrentRpcProtocolVersion defines the version of the RPC protocol
39+
// that is currently supported by the loop client.
40+
CurrentRPCProtocolVersion = looprpc.ProtocolVersion_USER_EXPIRY_LOOP_OUT
41+
42+
// CurrentInteranlProtocolVersionInternal defines the RPC current
43+
// protocol in the internal representation.
44+
CurrentInternalProtocolVersion = ProtocolVersion(CurrentRPCProtocolVersion)
45+
)
46+
47+
// Valid returns true if the value of the ProtocolVersion is valid.
48+
func (p ProtocolVersion) Valid() bool {
49+
return p <= CurrentInternalProtocolVersion
50+
}
51+
52+
// String returns the string representation of a protocol version.
53+
func (p ProtocolVersion) String() string {
54+
switch p {
55+
case ProtocolVersionUnrecorded:
56+
return "Unrecorded"
57+
58+
case ProtocolVersionLegacy:
59+
return "Legacy"
60+
61+
case ProtocolVersionMultiLoopOut:
62+
return "Multi Loop Out"
63+
64+
case ProtocolVersionSegwitLoopIn:
65+
return "Segwit Loop In"
66+
67+
case ProtocolVersionPreimagePush:
68+
return "Preimage Push"
69+
70+
case ProtocolVersionUserExpiryLoopOut:
71+
return "User Expiry Loop Out"
72+
73+
default:
74+
return "Unknown"
75+
}
76+
}

loopdb/protocol_version_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package loopdb
2+
3+
import (
4+
"testing"
5+
6+
"github.com/lightninglabs/loop/looprpc"
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
// TestProtocolVersionSanity tests that protocol versions are sane, meaning
11+
// we always keep our stored protocol version in sync with the RPC protocol
12+
// version except for the unrecorded version.
13+
func TestProtocolVersionSanity(t *testing.T) {
14+
t.Parallel()
15+
16+
versions := [...]ProtocolVersion{
17+
ProtocolVersionLegacy,
18+
ProtocolVersionMultiLoopOut,
19+
ProtocolVersionSegwitLoopIn,
20+
ProtocolVersionPreimagePush,
21+
ProtocolVersionUserExpiryLoopOut,
22+
}
23+
24+
rpcVersions := [...]looprpc.ProtocolVersion{
25+
looprpc.ProtocolVersion_LEGACY,
26+
looprpc.ProtocolVersion_MULTI_LOOP_OUT,
27+
looprpc.ProtocolVersion_NATIVE_SEGWIT_LOOP_IN,
28+
looprpc.ProtocolVersion_PREIMAGE_PUSH_LOOP_OUT,
29+
looprpc.ProtocolVersion_USER_EXPIRY_LOOP_OUT,
30+
}
31+
32+
require.Equal(t, len(versions), len(rpcVersions))
33+
for i, version := range versions {
34+
require.Equal(t, uint32(version), uint32(rpcVersions[i]))
35+
}
36+
37+
// Finally test that the current version contants are up to date
38+
require.Equal(t,
39+
CurrentInternalProtocolVersion,
40+
versions[len(versions)-1],
41+
)
42+
43+
require.Equal(t,
44+
uint32(CurrentInternalProtocolVersion),
45+
uint32(CurrentRPCProtocolVersion),
46+
)
47+
}

loopdb/store.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,15 @@ var (
6969
// value: string label
7070
labelKey = []byte("label")
7171

72+
// protocolVersionKey is used to optionally store the protocol version
73+
// for the serialized swap contract. It is nested within the sub-bucket
74+
// for each active swap.
75+
//
76+
// path: loopInBucket/loopOutBucket -> swapBucket[hash] -> protocolVersionKey
77+
//
78+
// value: protocol version as specified in server.proto
79+
protocolVersionKey = []byte("protocol-version")
80+
7281
// outgoingChanSetKey is the key that stores a list of channel ids that
7382
// restrict the loop out swap payment.
7483
//
@@ -276,6 +285,18 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) {
276285
return err
277286
}
278287

288+
// Try to unmarshal the protocol version for the swap.
289+
// If the protocol version is not stored (which is
290+
// the case for old clients), we'll assume the
291+
// ProtocolVersionUnrecorded instead.
292+
contract.ProtocolVersion, err =
293+
UnmarshalProtocolVersion(
294+
swapBucket.Get(protocolVersionKey),
295+
)
296+
if err != nil {
297+
return err
298+
}
299+
279300
loop := LoopOut{
280301
Loop: Loop{
281302
Events: updates,
@@ -401,6 +422,18 @@ func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) {
401422
return err
402423
}
403424

425+
// Try to unmarshal the protocol version for the swap.
426+
// If the protocol version is not stored (which is
427+
// the case for old clients), we'll assume the
428+
// ProtocolVersionUnrecorded instead.
429+
contract.ProtocolVersion, err =
430+
UnmarshalProtocolVersion(
431+
swapBucket.Get(protocolVersionKey),
432+
)
433+
if err != nil {
434+
return err
435+
}
436+
404437
loop := LoopIn{
405438
Loop: Loop{
406439
Events: updates,
@@ -512,6 +545,14 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash,
512545
return err
513546
}
514547

548+
// Store the current protocol version.
549+
err = swapBucket.Put(protocolVersionKey,
550+
MarshalProtocolVersion(swap.ProtocolVersion),
551+
)
552+
if err != nil {
553+
return err
554+
}
555+
515556
// Finally, we'll create an empty updates bucket for this swap
516557
// to track any future updates to the swap itself.
517558
_, err = swapBucket.CreateBucket(updatesBucketKey)
@@ -550,6 +591,14 @@ func (s *boltSwapStore) CreateLoopIn(hash lntypes.Hash,
550591
return err
551592
}
552593

594+
// Store the current protocol version.
595+
err = swapBucket.Put(protocolVersionKey,
596+
MarshalProtocolVersion(swap.ProtocolVersion),
597+
)
598+
if err != nil {
599+
return err
600+
}
601+
553602
// Write label to disk if we have one.
554603
if err := putLabel(swapBucket, swap.Label); err != nil {
555604
return err

0 commit comments

Comments
 (0)