Skip to content

Commit bdb4b77

Browse files
committed
swap: refactor htlc construction to allow passing of internal keys
This commit is a refactor of how we construct htlcs to make it possible to pass in internal keys for the sender and receiver when creating P2TR htlcs. Furthermore the commit also cleans up constructors to not pass in script versions and output types to make the code more readable.
1 parent 35e0120 commit bdb4b77

File tree

9 files changed

+207
-177
lines changed

9 files changed

+207
-177
lines changed

client.go

Lines changed: 24 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -213,29 +213,24 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
213213
SwapHash: swp.Hash,
214214
LastUpdate: swp.LastUpdateTime(),
215215
}
216-
scriptVersion := GetHtlcScriptVersion(
217-
swp.Contract.ProtocolVersion,
218-
)
219-
220-
outputType := swap.HtlcP2WSH
221-
if scriptVersion == swap.HtlcV3 {
222-
outputType = swap.HtlcP2TR
223-
}
224216

225-
htlc, err := swap.NewHtlc(
226-
scriptVersion,
227-
swp.Contract.CltvExpiry, swp.Contract.SenderKey,
228-
swp.Contract.ReceiverKey, swp.Hash,
229-
outputType, s.lndServices.ChainParams,
217+
htlc, err := GetHtlc(
218+
swp.Hash, &swp.Contract.SwapContract,
219+
s.lndServices.ChainParams,
230220
)
231221
if err != nil {
232222
return nil, err
233223
}
234224

235-
if outputType == swap.HtlcP2TR {
236-
swapInfo.HtlcAddressP2TR = htlc.Address
237-
} else {
225+
switch htlc.OutputType {
226+
case swap.HtlcP2WSH:
238227
swapInfo.HtlcAddressP2WSH = htlc.Address
228+
229+
case swap.HtlcP2TR:
230+
swapInfo.HtlcAddressP2TR = htlc.Address
231+
232+
default:
233+
return nil, swap.ErrInvalidOutputType
239234
}
240235

241236
swaps = append(swaps, swapInfo)
@@ -250,34 +245,23 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
250245
LastUpdate: swp.LastUpdateTime(),
251246
}
252247

253-
scriptVersion := GetHtlcScriptVersion(
254-
swp.Contract.SwapContract.ProtocolVersion,
248+
htlc, err := GetHtlc(
249+
swp.Hash, &swp.Contract.SwapContract,
250+
s.lndServices.ChainParams,
255251
)
252+
if err != nil {
253+
return nil, err
254+
}
256255

257-
if scriptVersion == swap.HtlcV3 {
258-
htlcP2TR, err := swap.NewHtlc(
259-
swap.HtlcV3, swp.Contract.CltvExpiry,
260-
swp.Contract.SenderKey, swp.Contract.ReceiverKey,
261-
swp.Hash, swap.HtlcP2TR,
262-
s.lndServices.ChainParams,
263-
)
264-
if err != nil {
265-
return nil, err
266-
}
256+
switch htlc.OutputType {
257+
case swap.HtlcP2WSH:
258+
swapInfo.HtlcAddressP2WSH = htlc.Address
267259

268-
swapInfo.HtlcAddressP2TR = htlcP2TR.Address
269-
} else {
270-
htlcP2WSH, err := swap.NewHtlc(
271-
swap.HtlcV2, swp.Contract.CltvExpiry,
272-
swp.Contract.SenderKey, swp.Contract.ReceiverKey,
273-
swp.Hash, swap.HtlcP2WSH,
274-
s.lndServices.ChainParams,
275-
)
276-
if err != nil {
277-
return nil, err
278-
}
260+
case swap.HtlcP2TR:
261+
swapInfo.HtlcAddressP2TR = htlc.Address
279262

280-
swapInfo.HtlcAddressP2WSH = htlcP2WSH.Address
263+
default:
264+
return nil, swap.ErrInvalidOutputType
281265
}
282266

283267
swaps = append(swaps, swapInfo)

client_test.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -284,16 +284,26 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed,
284284

285285
// Assert that the loopout htlc equals to the expected one.
286286
scriptVersion := GetHtlcScriptVersion(protocolVersion)
287+
var htlc *swap.Htlc
287288

288-
outputType := swap.HtlcP2TR
289-
if scriptVersion != swap.HtlcV3 {
290-
outputType = swap.HtlcP2WSH
289+
switch scriptVersion {
290+
case swap.HtlcV2:
291+
htlc, err = swap.NewHtlcV2(
292+
pendingSwap.Contract.CltvExpiry, senderKey,
293+
receiverKey, hash, &chaincfg.TestNet3Params,
294+
)
295+
296+
case swap.HtlcV3:
297+
htlc, err = swap.NewHtlcV3(
298+
pendingSwap.Contract.CltvExpiry, senderKey,
299+
receiverKey, senderKey, receiverKey, hash,
300+
&chaincfg.TestNet3Params,
301+
)
302+
303+
default:
304+
t.Fatalf(swap.ErrInvalidScriptVersion.Error())
291305
}
292306

293-
htlc, err := swap.NewHtlc(
294-
scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey,
295-
receiverKey, hash, outputType, &chaincfg.TestNet3Params,
296-
)
297307
require.NoError(t, err)
298308
require.Equal(t, htlc.PkScript, confIntent.PkScript)
299309

loopd/view.go

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"github.com/lightninglabs/lndclient"
88
"github.com/lightninglabs/loop"
99
"github.com/lightninglabs/loop/loopdb"
10-
"github.com/lightninglabs/loop/swap"
1110
)
1211

1312
// view prints all swaps currently in the database.
@@ -49,24 +48,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error {
4948
}
5049

5150
for _, s := range swaps {
52-
scriptVersion := loop.GetHtlcScriptVersion(
53-
s.Contract.ProtocolVersion,
54-
)
55-
56-
var outputType swap.HtlcOutputType
57-
switch scriptVersion {
58-
case swap.HtlcV2:
59-
outputType = swap.HtlcP2WSH
60-
61-
case swap.HtlcV3:
62-
outputType = swap.HtlcP2TR
63-
}
64-
htlc, err := swap.NewHtlc(
65-
loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion),
66-
s.Contract.CltvExpiry,
67-
s.Contract.SenderKey,
68-
s.Contract.ReceiverKey,
69-
s.Hash, outputType, chainParams,
51+
htlc, err := loop.GetHtlc(
52+
s.Hash, &s.Contract.SwapContract, chainParams,
7053
)
7154
if err != nil {
7255
return err
@@ -77,7 +60,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error {
7760
s.Contract.InitiationTime, s.Contract.InitiationHeight,
7861
)
7962
fmt.Printf(" Preimage: %v\n", s.Contract.Preimage)
80-
fmt.Printf(" Htlc address: %v\n", htlc.Address)
63+
fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType,
64+
htlc.Address)
8165

8266
fmt.Printf(" Uncharge channels: %v\n",
8367
s.Contract.OutgoingChanSet)
@@ -113,12 +97,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error {
11397
}
11498

11599
for _, s := range swaps {
116-
htlc, err := swap.NewHtlc(
117-
loop.GetHtlcScriptVersion(s.Contract.ProtocolVersion),
118-
s.Contract.CltvExpiry,
119-
s.Contract.SenderKey,
120-
s.Contract.ReceiverKey,
121-
s.Hash, swap.HtlcP2WSH, chainParams,
100+
htlc, err := loop.GetHtlc(
101+
s.Hash, &s.Contract.SwapContract, chainParams,
122102
)
123103
if err != nil {
124104
return err
@@ -129,7 +109,8 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error {
129109
s.Contract.InitiationTime, s.Contract.InitiationHeight,
130110
)
131111
fmt.Printf(" Preimage: %v\n", s.Contract.Preimage)
132-
fmt.Printf(" Htlc address: %v\n", htlc.Address)
112+
fmt.Printf(" Htlc address (%s): %v\n", htlc.OutputType,
113+
htlc.Address)
133114
fmt.Printf(" Amt: %v, Expiry: %v\n",
134115
s.Contract.AmountRequested, s.Contract.CltvExpiry,
135116
)

loopin.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -404,26 +404,26 @@ func validateLoopInContract(lnd *lndclient.LndServices,
404404
// initHtlcs creates and updates the native and nested segwit htlcs
405405
// of the loopInSwap.
406406
func (s *loopInSwap) initHtlcs() error {
407-
if IsTaprootSwap(&s.SwapContract) {
408-
htlcP2TR, err := s.swapKit.getHtlc(swap.HtlcP2TR)
409-
if err != nil {
410-
return err
411-
}
407+
htlc, err := GetHtlc(
408+
s.hash, &s.SwapContract, s.swapKit.lnd.ChainParams,
409+
)
410+
if err != nil {
411+
return err
412+
}
412413

413-
s.swapKit.log.Infof("Htlc address (P2TR): %v", htlcP2TR.Address)
414-
s.htlcP2TR = htlcP2TR
414+
switch htlc.OutputType {
415+
case swap.HtlcP2WSH:
416+
s.htlcP2WSH = htlc
415417

416-
return nil
417-
}
418+
case swap.HtlcP2TR:
419+
s.htlcP2TR = htlc
418420

419-
htlcP2WSH, err := s.swapKit.getHtlc(swap.HtlcP2WSH)
420-
if err != nil {
421-
return err
421+
default:
422+
return fmt.Errorf("invalid output type")
422423
}
423424

424-
// Log htlc addresses for debugging.
425-
s.swapKit.log.Infof("Htlc address (P2WSH): %v", htlcP2WSH.Address)
426-
s.htlcP2WSH = htlcP2WSH
425+
s.swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType,
426+
htlc.Address)
427427

428428
return nil
429429
}

loopin_test.go

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -455,21 +455,32 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool,
455455
pendSwap.Loop.Events[0].Cost = cost
456456
}
457457

458-
scriptVersion := GetHtlcScriptVersion(storedVersion)
458+
var (
459+
htlc *swap.Htlc
460+
err error
461+
)
459462

460-
outputType := swap.HtlcP2WSH
461-
if scriptVersion == swap.HtlcV3 {
462-
outputType = swap.HtlcP2TR
463+
switch GetHtlcScriptVersion(storedVersion) {
464+
case swap.HtlcV2:
465+
htlc, err = swap.NewHtlcV2(
466+
contract.CltvExpiry, contract.SenderKey,
467+
contract.ReceiverKey, testPreimage.Hash(),
468+
cfg.lnd.ChainParams,
469+
)
470+
471+
case swap.HtlcV3:
472+
htlc, err = swap.NewHtlcV3(
473+
contract.CltvExpiry, contract.SenderKey,
474+
contract.ReceiverKey, contract.SenderKey,
475+
contract.ReceiverKey, testPreimage.Hash(),
476+
cfg.lnd.ChainParams,
477+
)
478+
479+
default:
480+
t.Fatalf("unknown HTLC script version")
463481
}
464482

465-
htlc, err := swap.NewHtlc(
466-
scriptVersion, contract.CltvExpiry, contract.SenderKey,
467-
contract.ReceiverKey, testPreimage.Hash(), outputType,
468-
cfg.lnd.ChainParams,
469-
)
470-
if err != nil {
471-
t.Fatal(err)
472-
}
483+
require.NoError(t, err)
473484

474485
err = ctx.store.CreateLoopIn(testPreimage.Hash(), contract)
475486
if err != nil {

loopout.go

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -201,21 +201,17 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig,
201201

202202
swapKit.lastUpdateTime = initiationTime
203203

204-
scriptVersion := GetHtlcScriptVersion(loopdb.CurrentProtocolVersion())
205-
outputType := swap.HtlcP2TR
206-
if scriptVersion != swap.HtlcV3 {
207-
// Default to using P2WSH for legacy htlcs.
208-
outputType = swap.HtlcP2WSH
209-
}
210-
211204
// Create the htlc.
212-
htlc, err := swapKit.getHtlc(outputType)
205+
htlc, err := GetHtlc(
206+
swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams,
207+
)
213208
if err != nil {
214209
return nil, err
215210
}
216211

217212
// Log htlc address for debugging.
218-
swapKit.log.Infof("Htlc address: %v", htlc.Address)
213+
swapKit.log.Infof("Htlc address (%s): %v", htlc.OutputType,
214+
htlc.Address)
219215

220216
// Obtain the payment addr since we'll need it later for routing plugin
221217
// recommendation and possibly for cancel.
@@ -261,15 +257,10 @@ func resumeLoopOutSwap(reqContext context.Context, cfg *swapConfig,
261257
hash, swap.TypeOut, cfg, &pend.Contract.SwapContract,
262258
)
263259

264-
scriptVersion := GetHtlcScriptVersion(pend.Contract.ProtocolVersion)
265-
outputType := swap.HtlcP2TR
266-
if scriptVersion != swap.HtlcV3 {
267-
// Default to using P2WSH for legacy htlcs.
268-
outputType = swap.HtlcP2WSH
269-
}
270-
271260
// Create the htlc.
272-
htlc, err := swapKit.getHtlc(outputType)
261+
htlc, err := GetHtlc(
262+
swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams,
263+
)
273264
if err != nil {
274265
return nil, err
275266
}

swap.go

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"time"
66

7+
"github.com/btcsuite/btcd/chaincfg"
78
"github.com/lightninglabs/lndclient"
89
"github.com/lightninglabs/loop/loopdb"
910
"github.com/lightninglabs/loop/swap"
@@ -67,14 +68,28 @@ func IsTaprootSwap(swapContract *loopdb.SwapContract) bool {
6768
return GetHtlcScriptVersion(swapContract.ProtocolVersion) == swap.HtlcV3
6869
}
6970

70-
// getHtlc composes and returns the on-chain swap script.
71-
func (s *swapKit) getHtlc(outputType swap.HtlcOutputType) (*swap.Htlc, error) {
72-
return swap.NewHtlc(
73-
GetHtlcScriptVersion(s.contract.ProtocolVersion),
74-
s.contract.CltvExpiry, s.contract.SenderKey,
75-
s.contract.ReceiverKey, s.hash, outputType,
76-
s.swapConfig.lnd.ChainParams,
77-
)
71+
// GetHtlc composes and returns the on-chain swap script.
72+
func GetHtlc(hash lntypes.Hash, contract *loopdb.SwapContract,
73+
chainParams *chaincfg.Params) (*swap.Htlc, error) {
74+
75+
switch GetHtlcScriptVersion(contract.ProtocolVersion) {
76+
case swap.HtlcV2:
77+
return swap.NewHtlcV2(
78+
contract.CltvExpiry, contract.SenderKey,
79+
contract.ReceiverKey, hash,
80+
chainParams,
81+
)
82+
83+
case swap.HtlcV3:
84+
return swap.NewHtlcV3(
85+
contract.CltvExpiry, contract.SenderKey,
86+
contract.ReceiverKey, contract.SenderKey,
87+
contract.ReceiverKey, hash,
88+
chainParams,
89+
)
90+
}
91+
92+
return nil, swap.ErrInvalidScriptVersion
7893
}
7994

8095
// swapInfo constructs and returns a filled SwapInfo from

0 commit comments

Comments
 (0)