Skip to content

Commit c95a3db

Browse files
astrojarredcffls
andauthored
Merge change into already existing output when possible (#38)
* Add a copy method to TransactionOutput * Add change to TxOutput with matching address when possible * Update Tx Output addition method * Update outputs and fees of relevant tests * Lint the changes * Update the copy method of TransactionOutput to __copy__ * Add merge change as an optional parameter * Restore original tests before merge_change was default * Do not merge change if there are multiple change UTxOs * Lint new changes * Remove unused imports * Add unit tests and minor fixes for change merging Co-authored-by: Jarred <[email protected]>, Jerry <[email protected]>
1 parent 131816e commit c95a3db

File tree

3 files changed

+159
-10
lines changed

3 files changed

+159
-10
lines changed

Makefile

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ help:
2727
@python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)
2828

2929
cov: ## check code coverage
30-
poetry run coverage run --source pycardano -m pytest -n 4
31-
poetry run coverage report -m
30+
poetry run pytest -n 4 --cov pycardano
3231

3332
cov-html: cov ## check code coverage and generate an html report
3433
poetry run coverage html -d cov_html

pycardano/txbuilder.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,9 @@ def _calc_change(
312312
# when there is only ADA left, simply use remaining coin value as change
313313
if not change.multi_asset:
314314
if change.coin < min_lovelace(change, self.context):
315-
raise InsufficientUTxOBalanceException("Not enough ADA left for change")
315+
raise InsufficientUTxOBalanceException(
316+
f"Not enough ADA left for change: {change.coin} but needs {min_lovelace(change, self.context)}"
317+
)
316318
lovelace_change = change.coin
317319
change_output_arr.append(TransactionOutput(address, lovelace_change))
318320

@@ -347,17 +349,42 @@ def _calc_change(
347349
return change_output_arr
348350

349351
def _add_change_and_fee(
350-
self, change_address: Optional[Address]
352+
self,
353+
change_address: Optional[Address],
354+
merge_change: Optional[bool] = False,
351355
) -> TransactionBuilder:
352-
original_outputs = self.outputs[:]
356+
original_outputs = deepcopy(self.outputs)
357+
change_output_index = None
358+
359+
def _merge_changes(changes):
360+
if change_output_index is not None and len(changes) == 1:
361+
# Add the leftover change to the TransactionOutput containing the change address
362+
self._outputs[change_output_index].amount = (
363+
changes[0].amount + self._outputs[change_output_index].amount
364+
)
365+
# if we enforce that TransactionOutputs must use Values for `amount`, we can use += here
366+
367+
else:
368+
self._outputs += changes
353369

354370
if change_address:
371+
372+
if merge_change:
373+
374+
for idx, output in enumerate(original_outputs):
375+
376+
# Find any transaction outputs which already contain the change address
377+
if change_address == output.address:
378+
if change_output_index is None or output.lovelace == 0:
379+
change_output_index = idx
380+
355381
# Set fee to max
356382
self.fee = self._estimate_fee()
357383
changes = self._calc_change(
358384
self.fee, self.inputs, self.outputs, change_address, precise_fee=True
359385
)
360-
self._outputs += changes
386+
387+
_merge_changes(changes)
361388

362389
# With changes included, we can estimate the fee more precisely
363390
self.fee = self._estimate_fee()
@@ -367,7 +394,8 @@ def _add_change_and_fee(
367394
changes = self._calc_change(
368395
self.fee, self.inputs, self.outputs, change_address, precise_fee=True
369396
)
370-
self._outputs += changes
397+
398+
_merge_changes(changes)
371399

372400
return self
373401

@@ -649,12 +677,18 @@ def _estimate_fee(self):
649677

650678
return estimated_fee
651679

652-
def build(self, change_address: Optional[Address] = None) -> TransactionBody:
680+
def build(
681+
self,
682+
change_address: Optional[Address] = None,
683+
merge_change: Optional[bool] = False,
684+
) -> TransactionBody:
653685
"""Build a transaction body from all constraints set through the builder.
654686
655687
Args:
656688
change_address (Optional[Address]): Address to which changes will be returned. If not provided, the
657689
transaction body will likely be unbalanced (sum of inputs is greater than the sum of outputs).
690+
merge_change (Optional[bool]): If the change address match one of the transaction output, the change amount
691+
will be directly added to that transaction output, instead of being added as a separate output.
658692
659693
Returns:
660694
TransactionBody: A transaction body.
@@ -773,7 +807,7 @@ def build(self, change_address: Optional[Address] = None) -> TransactionBody:
773807

774808
self._set_redeemer_index()
775809

776-
self._add_change_and_fee(change_address)
810+
self._add_change_and_fee(change_address, merge_change=merge_change)
777811

778812
tx_body = self._build_tx_body()
779813

@@ -820,6 +854,7 @@ def build_and_sign(
820854
self,
821855
signing_keys: List[Union[SigningKey, ExtendedSigningKey]],
822856
change_address: Optional[Address] = None,
857+
merge_change: Optional[bool] = False,
823858
) -> Transaction:
824859
"""Build a transaction body from all constraints set through the builder and sign the transaction with
825860
provided signing keys.
@@ -829,12 +864,14 @@ def build_and_sign(
829864
sign the transaction.
830865
change_address (Optional[Address]): Address to which changes will be returned. If not provided, the
831866
transaction body will likely be unbalanced (sum of inputs is greater than the sum of outputs).
867+
merge_change (Optional[bool]): If the change address match one of the transaction output, the change amount
868+
will be directly added to that transaction output, instead of being added as a separate output.
832869
833870
Returns:
834871
Transaction: A signed transaction.
835872
"""
836873

837-
tx_body = self.build(change_address=change_address)
874+
tx_body = self.build(change_address=change_address, merge_change=merge_change)
838875
witness_set = self.build_witness_set()
839876
witness_set.vkey_witnesses = []
840877

test/pycardano/test_txbuilder.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,3 +703,116 @@ def test_tx_builder_withdrawal(chain_context):
703703
}
704704

705705
assert expected == tx_body.to_primitive()
706+
707+
708+
def test_tx_builder_no_output(chain_context):
709+
tx_builder = TransactionBuilder(chain_context)
710+
sender = "addr_test1vrm9x2zsux7va6w892g38tvchnzahvcd9tykqf3ygnmwtaqyfg52x"
711+
sender_address = Address.from_primitive(sender)
712+
713+
input_amount = 10000000
714+
715+
tx_in1 = TransactionInput.from_primitive([b"1" * 32, 3])
716+
tx_out1 = TransactionOutput.from_primitive([sender, input_amount])
717+
utxo1 = UTxO(tx_in1, tx_out1)
718+
719+
tx_builder.add_input(utxo1)
720+
721+
tx_body = tx_builder.build(change_address=sender_address, merge_change=True)
722+
723+
expected = {
724+
0: [[b"11111111111111111111111111111111", 3]],
725+
1: [
726+
[sender_address.to_primitive(), 9836215],
727+
],
728+
2: 163785,
729+
}
730+
731+
assert expected == tx_body.to_primitive()
732+
733+
734+
def test_tx_builder_merge_change_to_output(chain_context):
735+
tx_builder = TransactionBuilder(chain_context)
736+
sender = "addr_test1vrm9x2zsux7va6w892g38tvchnzahvcd9tykqf3ygnmwtaqyfg52x"
737+
sender_address = Address.from_primitive(sender)
738+
739+
input_amount = 10000000
740+
741+
tx_in1 = TransactionInput.from_primitive([b"1" * 32, 3])
742+
tx_out1 = TransactionOutput.from_primitive([sender, input_amount])
743+
utxo1 = UTxO(tx_in1, tx_out1)
744+
745+
tx_builder.add_input(utxo1)
746+
tx_builder.add_output(TransactionOutput.from_primitive([sender, 10000]))
747+
748+
tx_body = tx_builder.build(change_address=sender_address, merge_change=True)
749+
750+
expected = {
751+
0: [[b"11111111111111111111111111111111", 3]],
752+
1: [
753+
[sender_address.to_primitive(), 9836215],
754+
],
755+
2: 163785,
756+
}
757+
758+
assert expected == tx_body.to_primitive()
759+
760+
761+
def test_tx_builder_merge_change_to_output_2(chain_context):
762+
tx_builder = TransactionBuilder(chain_context)
763+
sender = "addr_test1vrm9x2zsux7va6w892g38tvchnzahvcd9tykqf3ygnmwtaqyfg52x"
764+
sender_address = Address.from_primitive(sender)
765+
receiver = "addr_test1vr2p8st5t5cxqglyjky7vk98k7jtfhdpvhl4e97cezuhn0cqcexl7"
766+
receiver_address = Address.from_primitive(receiver)
767+
768+
input_amount = 10000000
769+
770+
tx_in1 = TransactionInput.from_primitive([b"1" * 32, 3])
771+
tx_out1 = TransactionOutput.from_primitive([sender, input_amount])
772+
utxo1 = UTxO(tx_in1, tx_out1)
773+
774+
tx_builder.add_input(utxo1)
775+
tx_builder.add_output(TransactionOutput.from_primitive([sender, 10000]))
776+
tx_builder.add_output(TransactionOutput.from_primitive([receiver, 10000]))
777+
tx_builder.add_output(TransactionOutput.from_primitive([sender, 0]))
778+
779+
tx_body = tx_builder.build(change_address=sender_address, merge_change=True)
780+
781+
expected = {
782+
0: [[b"11111111111111111111111111111111", 3]],
783+
1: [
784+
[sender_address.to_primitive(), 10000],
785+
[receiver_address.to_primitive(), 10000],
786+
[sender_address.to_primitive(), 9813135],
787+
],
788+
2: 166865,
789+
}
790+
791+
assert expected == tx_body.to_primitive()
792+
793+
794+
def test_tx_builder_merge_change_to_zero_amount_output(chain_context):
795+
tx_builder = TransactionBuilder(chain_context)
796+
sender = "addr_test1vrm9x2zsux7va6w892g38tvchnzahvcd9tykqf3ygnmwtaqyfg52x"
797+
sender_address = Address.from_primitive(sender)
798+
799+
input_amount = 10000000
800+
801+
tx_in1 = TransactionInput.from_primitive([b"1" * 32, 3])
802+
tx_out1 = TransactionOutput.from_primitive([sender, input_amount])
803+
utxo1 = UTxO(tx_in1, tx_out1)
804+
805+
tx_builder.add_input(utxo1)
806+
tx_builder.add_output(TransactionOutput.from_primitive([sender, 0]))
807+
808+
tx_body = tx_builder.build(change_address=sender_address, merge_change=True)
809+
810+
expected = {
811+
0: [[b"11111111111111111111111111111111", 3]],
812+
1: [
813+
[sender_address.to_primitive(), 9836215],
814+
],
815+
2: 163785,
816+
}
817+
818+
assert expected == tx_body.to_primitive()

0 commit comments

Comments
 (0)