Skip to content

Commit 446f163

Browse files
authored
Merge pull request #541 from bhandras/htlc-v3-interal-key
swap: refactor htlc construction to allow passing of internal keys
2 parents 35e0120 + 049b17f commit 446f163

16 files changed

+306
-369
lines changed

client.go

+24-40
Original file line numberDiff line numberDiff line change
@@ -213,29 +213,24 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
213213
SwapHash: swp.Hash,
214214
LastUpdate: swp.LastUpdateTime(),
215215
}
216-
scriptVersion := GetHtlcScriptVersion(
217-
swp.Contract.ProtocolVersion,
218-
)
219-
220-
outputType := swap.HtlcP2WSH
221-
if scriptVersion == swap.HtlcV3 {
222-
outputType = swap.HtlcP2TR
223-
}
224216

225-
htlc, err := swap.NewHtlc(
226-
scriptVersion,
227-
swp.Contract.CltvExpiry, swp.Contract.SenderKey,
228-
swp.Contract.ReceiverKey, swp.Hash,
229-
outputType, s.lndServices.ChainParams,
217+
htlc, err := GetHtlc(
218+
swp.Hash, &swp.Contract.SwapContract,
219+
s.lndServices.ChainParams,
230220
)
231221
if err != nil {
232222
return nil, err
233223
}
234224

235-
if outputType == swap.HtlcP2TR {
236-
swapInfo.HtlcAddressP2TR = htlc.Address
237-
} else {
225+
switch htlc.OutputType {
226+
case swap.HtlcP2WSH:
238227
swapInfo.HtlcAddressP2WSH = htlc.Address
228+
229+
case swap.HtlcP2TR:
230+
swapInfo.HtlcAddressP2TR = htlc.Address
231+
232+
default:
233+
return nil, swap.ErrInvalidOutputType
239234
}
240235

241236
swaps = append(swaps, swapInfo)
@@ -250,34 +245,23 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
250245
LastUpdate: swp.LastUpdateTime(),
251246
}
252247

253-
scriptVersion := GetHtlcScriptVersion(
254-
swp.Contract.SwapContract.ProtocolVersion,
248+
htlc, err := GetHtlc(
249+
swp.Hash, &swp.Contract.SwapContract,
250+
s.lndServices.ChainParams,
255251
)
252+
if err != nil {
253+
return nil, err
254+
}
256255

257-
if scriptVersion == swap.HtlcV3 {
258-
htlcP2TR, err := swap.NewHtlc(
259-
swap.HtlcV3, swp.Contract.CltvExpiry,
260-
swp.Contract.SenderKey, swp.Contract.ReceiverKey,
261-
swp.Hash, swap.HtlcP2TR,
262-
s.lndServices.ChainParams,
263-
)
264-
if err != nil {
265-
return nil, err
266-
}
256+
switch htlc.OutputType {
257+
case swap.HtlcP2WSH:
258+
swapInfo.HtlcAddressP2WSH = htlc.Address
267259

268-
swapInfo.HtlcAddressP2TR = htlcP2TR.Address
269-
} else {
270-
htlcP2WSH, err := swap.NewHtlc(
271-
swap.HtlcV2, swp.Contract.CltvExpiry,
272-
swp.Contract.SenderKey, swp.Contract.ReceiverKey,
273-
swp.Hash, swap.HtlcP2WSH,
274-
s.lndServices.ChainParams,
275-
)
276-
if err != nil {
277-
return nil, err
278-
}
260+
case swap.HtlcP2TR:
261+
swapInfo.HtlcAddressP2TR = htlc.Address
279262

280-
swapInfo.HtlcAddressP2WSH = htlcP2WSH.Address
263+
default:
264+
return nil, swap.ErrInvalidOutputType
281265
}
282266

283267
swaps = append(swaps, swapInfo)

client_test.go

+27-27
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package loop
22

33
import (
4-
"bytes"
54
"context"
65
"crypto/sha256"
76
"errors"
@@ -57,9 +56,7 @@ func TestLoopOutSuccess(t *testing.T) {
5756

5857
// Initiate loop out.
5958
info, err := ctx.swapClient.LoopOut(context.Background(), &req)
60-
if err != nil {
61-
t.Fatal(err)
62-
}
59+
require.NoError(t, err)
6360

6461
ctx.assertStored()
6562
ctx.assertStatus(loopdb.StateInitiated)
@@ -84,9 +81,7 @@ func TestLoopOutFailOffchain(t *testing.T) {
8481
ctx := createClientTestContext(t, nil)
8582

8683
_, err := ctx.swapClient.LoopOut(context.Background(), testRequest)
87-
if err != nil {
88-
t.Fatal(err)
89-
}
84+
require.NoError(t, err)
9085

9186
ctx.assertStored()
9287
ctx.assertStatus(loopdb.StateInitiated)
@@ -208,14 +203,10 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed,
208203
amt := btcutil.Amount(50000)
209204

210205
swapPayReq, err := getInvoice(hash, amt, swapInvoiceDesc)
211-
if err != nil {
212-
t.Fatal(err)
213-
}
206+
require.NoError(t, err)
214207

215208
prePayReq, err := getInvoice(hash, 100, prepayInvoiceDesc)
216-
if err != nil {
217-
t.Fatal(err)
218-
}
209+
require.NoError(t, err)
219210

220211
_, senderPubKey := test.CreateKey(1)
221212
var senderKey [33]byte
@@ -284,16 +275,26 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed,
284275

285276
// Assert that the loopout htlc equals to the expected one.
286277
scriptVersion := GetHtlcScriptVersion(protocolVersion)
278+
var htlc *swap.Htlc
287279

288-
outputType := swap.HtlcP2TR
289-
if scriptVersion != swap.HtlcV3 {
290-
outputType = swap.HtlcP2WSH
280+
switch scriptVersion {
281+
case swap.HtlcV2:
282+
htlc, err = swap.NewHtlcV2(
283+
pendingSwap.Contract.CltvExpiry, senderKey,
284+
receiverKey, hash, &chaincfg.TestNet3Params,
285+
)
286+
287+
case swap.HtlcV3:
288+
htlc, err = swap.NewHtlcV3(
289+
pendingSwap.Contract.CltvExpiry, senderKey,
290+
receiverKey, senderKey, receiverKey, hash,
291+
&chaincfg.TestNet3Params,
292+
)
293+
294+
default:
295+
t.Fatalf(swap.ErrInvalidScriptVersion.Error())
291296
}
292297

293-
htlc, err := swap.NewHtlc(
294-
scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey,
295-
receiverKey, hash, outputType, &chaincfg.TestNet3Params,
296-
)
297298
require.NoError(t, err)
298299
require.Equal(t, htlc.PkScript, confIntent.PkScript)
299300

@@ -363,10 +364,11 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash,
363364
// Expect client on-chain sweep of HTLC.
364365
sweepTx := ctx.ReceiveTx()
365366

366-
if !bytes.Equal(sweepTx.TxIn[0].PreviousOutPoint.Hash[:],
367-
htlcOutpoint.Hash[:]) {
368-
ctx.T.Fatalf("client not sweeping from htlc tx")
369-
}
367+
require.Equal(
368+
ctx.T, htlcOutpoint.Hash[:],
369+
sweepTx.TxIn[0].PreviousOutPoint.Hash[:],
370+
"client not sweeping from htlc tx",
371+
)
370372

371373
var preImageIndex int
372374
switch scriptVersion {
@@ -380,9 +382,7 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash,
380382
// Check preimage.
381383
clientPreImage := sweepTx.TxIn[0].Witness[preImageIndex]
382384
clientPreImageHash := sha256.Sum256(clientPreImage)
383-
if clientPreImageHash != hash {
384-
ctx.T.Fatalf("incorrect preimage")
385-
}
385+
require.Equal(ctx.T, hash, lntypes.Hash(clientPreImageHash))
386386

387387
// Since we successfully published our sweep, we expect the preimage to
388388
// have been pushed to our mock server.

loopd/swapclient_server_test.go

+10-16
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,13 @@ func TestValidateConfTarget(t *testing.T) {
130130
test.confTarget, defaultConf,
131131
)
132132

133-
haveErr := err != nil
134-
if haveErr != test.expectErr {
135-
t.Fatalf("expected err: %v, got: %v",
136-
test.expectErr, err)
133+
if test.expectErr {
134+
require.Error(t, err)
135+
} else {
136+
require.NoError(t, err)
137137
}
138138

139-
if target != test.expectedTarget {
140-
t.Fatalf("expected: %v, got: %v",
141-
test.expectedTarget, target)
142-
}
139+
require.Equal(t, test.expectedTarget, target)
143140
})
144141
}
145142
}
@@ -199,16 +196,13 @@ func TestValidateLoopInRequest(t *testing.T) {
199196
test.confTarget, external,
200197
)
201198

202-
haveErr := err != nil
203-
if haveErr != test.expectErr {
204-
t.Fatalf("expected err: %v, got: %v",
205-
test.expectErr, err)
199+
if test.expectErr {
200+
require.Error(t, err)
201+
} else {
202+
require.NoError(t, err)
206203
}
207204

208-
if conf != test.expectedTarget {
209-
t.Fatalf("expected: %v, got: %v",
210-
test.expectedTarget, conf)
211-
}
205+
require.Equal(t, test.expectedTarget, conf)
212206
})
213207
}
214208
}

loopd/view.go

+8-27
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"github.com/lightninglabs/lndclient"
88
"github.com/lightninglabs/loop"
99
"github.com/lightninglabs/loop/loopdb"
10-
"github.com/lightninglabs/loop/swap"
1110
)
1211

1312
// view prints all swaps currently in the database.
@@ -49,24 +48,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error {
4948
}
5049

5150
for _, s := range swaps {
52-
scriptVersion := loop.GetHtlcScriptVersion(
53-
s.Contract.ProtocolVersion,
54-
)
55-
56-
var outputType swap.HtlcOutputType
57-
switch scriptVersion {
58-
case swap.HtlcV2:
59-
outputType = swap.HtlcP2WSH
60-
61-
case swap.HtlcV3:
62-
outputType = swap.HtlcP2TR
63-
}
64-
htlc, err := swap.NewHtlc(
65-
loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion),
66-
s.Contract.CltvExpiry,
67-
s.Contract.SenderKey,
68-
s.Contract.ReceiverKey,
69-
s.Hash, outputType, chainParams,
51+
htlc, err := loop.GetHtlc(
52+
s.Hash, &s.Contract.SwapContract, chainParams,
7053
)
7154
if err != nil {
7255
return err
@@ -77,7 +60,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error {
7760
s.Contract.InitiationTime, s.Contract.InitiationHeight,
7861
)
7962
fmt.Printf(" Preimage: %v\n", s.Contract.Preimage)
80-
fmt.Printf(" Htlc address: %v\n", htlc.Address)
63+
fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType,
64+
htlc.Address)
8165

8266
fmt.Printf(" Uncharge channels: %v\n",
8367
s.Contract.OutgoingChanSet)
@@ -113,12 +97,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error {
11397
}
11498

11599
for _, s := range swaps {
116-
htlc, err := swap.NewHtlc(
117-
loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion),
118-
s.Contract.CltvExpiry,
119-
s.Contract.SenderKey,
120-
s.Contract.ReceiverKey,
121-
s.Hash, swap.HtlcP2WSH, chainParams,
100+
htlc, err := loop.GetHtlc(
101+
s.Hash, &s.Contract.SwapContract, chainParams,
122102
)
123103
if err != nil {
124104
return err
@@ -129,7 +109,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error {
129109
s.Contract.InitiationTime, s.Contract.InitiationHeight,
130110
)
131111
fmt.Printf(" Preimage: %v\n", s.Contract.Preimage)
132-
fmt.Printf(" Htlc address: %v\n", htlc.Address)
112+
fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType,
113+
htlc.Address)
133114
fmt.Printf(" Amt: %v, Expiry: %v\n",
134115
s.Contract.AmountRequested, s.Contract.CltvExpiry,
135116
)

loopin.go

+15-15
Original file line numberDiff line numberDiff line change
@@ -404,26 +404,26 @@ func validateLoopInContract(lnd *lndclient.LndServices,
404404
// initHtlcs creates and updates the native and nested segwit htlcs
405405
// of the loopInSwap.
406406
func (s *loopInSwap) initHtlcs() error {
407-
if IsTaprootSwap(&s.SwapContract) {
408-
htlcP2TR, err := s.swapKit.getHtlc(swap.HtlcP2TR)
409-
if err != nil {
410-
return err
411-
}
407+
htlc, err := GetHtlc(
408+
s.hash, &s.SwapContract, s.swapKit.lnd.ChainParams,
409+
)
410+
if err != nil {
411+
return err
412+
}
412413

413-
s.swapKit.log.Infof("Htlc address (P2TR): %v", htlcP2TR.Address)
414-
s.htlcP2TR = htlcP2TR
414+
switch htlc.OutputType {
415+
case swap.HtlcP2WSH:
416+
s.htlcP2WSH = htlc
415417

416-
return nil
417-
}
418+
case swap.HtlcP2TR:
419+
s.htlcP2TR = htlc
418420

419-
htlcP2WSH, err := s.swapKit.getHtlc(swap.HtlcP2WSH)
420-
if err != nil {
421-
return err
421+
default:
422+
return fmt.Errorf("invalid output type")
422423
}
423424

424-
// Log htlc addresses for debugging.
425-
s.swapKit.log.Infof("Htlc address (P2WSH): %v", htlcP2WSH.Address)
426-
s.htlcP2WSH = htlcP2WSH
425+
s.swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType,
426+
htlc.Address)
427427

428428
return nil
429429
}

0 commit comments

Comments
 (0)