Skip to content

Implement credits payment method #830

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ dependencies = [
"aioredis==1.3.1",
"aiosqlite==0.19",
"alembic==1.13.1",
"aleph-message~=1.0.1",
# "aleph-message~=1.0.1",
"aleph-message @ git+https://github.com/aleph-im/aleph-message@andres-feature-implement_credits_payment",
"aleph-superfluid~=0.2.1",
"dbus-python==1.3.2",
"eth-account~=0.10",
Expand Down
71 changes: 49 additions & 22 deletions src/aleph/vm/orchestrator/payment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from collections.abc import Iterable
from decimal import Decimal
from typing import List

import aiohttp
from aleph_message.models import ItemHash, PaymentType
Expand Down Expand Up @@ -44,45 +45,56 @@ async def fetch_balance_of_address(address: str) -> Decimal:
return resp_data["balance"]


async def fetch_execution_flow_price(item_hash: ItemHash) -> Decimal:
"""Fetch the flow price of an execution from the reference API server."""
async def fetch_credit_balance_of_address(address: str) -> Decimal:
"""
Get the balance of the user from the PyAleph API.

API Endpoint:
GET /api/v0/addresses/{address}/balance

For more details, see the PyAleph API documentation:
https://github.com/aleph-im/pyaleph/blob/master/src/aleph/web/controllers/routes.py#L62
"""

async with aiohttp.ClientSession() as session:
url = f"{settings.API_SERVER}/api/v0/price/{item_hash}"
url = f"{settings.API_SERVER}/api/v0/addresses/{address}/credit_balance"
resp = await session.get(url)

# Consider the balance as null if the address is not found
if resp.status == 404:
return Decimal(0)

# Raise an error if the request failed
resp.raise_for_status()

resp_data = await resp.json()
required_flow: float = resp_data["required_tokens"]
payment_type: str | None = resp_data["payment_type"]

if payment_type is None:
msg = "Payment type must be specified in the message"
raise ValueError(msg)
elif payment_type != PaymentType.superfluid:
msg = f"Payment type {payment_type} is not supported"
raise ValueError(msg)

return Decimal(required_flow)
return resp_data["credits"]


async def fetch_execution_hold_price(item_hash: ItemHash) -> Decimal:
"""Fetch the hold price of an execution from the reference API server."""
async def fetch_execution_price(
item_hash: ItemHash, allowed_payments: List[PaymentType], payment_type_required: bool = True
) -> Decimal:
"""Fetch the credit price of an execution from the reference API server."""
async with aiohttp.ClientSession() as session:
url = f"{settings.API_SERVER}/api/v0/price/{item_hash}"
resp = await session.get(url)
# Raise an error if the request failed
resp.raise_for_status()

resp_data = await resp.json()
required_hold: float = resp_data["required_tokens"]
required_credits: float = resp_data["required_credits"] # Field not defined yet on API side.
payment_type: str | None = resp_data["payment_type"]

if payment_type not in (None, PaymentType.hold):
msg = f"Payment type {payment_type} is not supported"
if payment_type_required and payment_type is None:
msg = "Payment type must be specified in the message"
raise ValueError(msg)

return Decimal(required_hold)
if payment_type:
if payment_type not in allowed_payments:
msg = f"Payment type {payment_type} is not supported"
raise ValueError(msg)

return Decimal(required_credits)


class InvalidAddressError(ValueError):
Expand Down Expand Up @@ -133,11 +145,26 @@ async def get_stream(sender: str, receiver: str, chain: str) -> Decimal:

async def compute_required_balance(executions: Iterable[VmExecution]) -> Decimal:
"""Get the balance required for the resources of the user from the messages and the pricing aggregate."""
costs = await asyncio.gather(*(fetch_execution_hold_price(execution.vm_hash) for execution in executions))
costs = await asyncio.gather(
*(
fetch_execution_price(execution.vm_hash, [PaymentType.hold], payment_type_required=False)
for execution in executions
)
)
return sum(costs, Decimal(0))


async def compute_required_credit_balance(executions: Iterable[VmExecution]) -> Decimal:
"""Get the balance required for the resources of the user from the messages and the pricing aggregate."""
costs = await asyncio.gather(
*(fetch_execution_price(execution.vm_hash, [PaymentType.credit]) for execution in executions)
)
return sum(costs, Decimal(0))


async def compute_required_flow(executions: Iterable[VmExecution]) -> Decimal:
"""Compute the flow required for a collection of executions, typically all executions from a specific address"""
flows = await asyncio.gather(*(fetch_execution_flow_price(execution.vm_hash) for execution in executions))
flows = await asyncio.gather(
*(fetch_execution_price(execution.vm_hash, [PaymentType.superfluid]) for execution in executions)
)
return sum(flows, Decimal(0))
53 changes: 41 additions & 12 deletions src/aleph/vm/orchestrator/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
from .messages import get_message_status
from .payment import (
compute_required_balance,
compute_required_credit_balance,
compute_required_flow,
fetch_balance_of_address,
fetch_credit_balance_of_address,
get_stream,
)
from .pubsub import PubSub
Expand Down Expand Up @@ -187,44 +189,71 @@ async def check_payment(pool: VmPool):
pool.forget_vm(vm_hash)

# Check if the balance held in the wallet is sufficient holder tier resources (Not do it yet)
for sender, chains in pool.get_executions_by_sender(payment_type=PaymentType.hold).items():
for execution_address, chains in pool.get_executions_by_address(payment_type=PaymentType.hold).items():
for chain, executions in chains.items():
executions = [execution for execution in executions if execution.is_confidential]
if not executions:
continue
balance = await fetch_balance_of_address(sender)
balance = await fetch_balance_of_address(execution_address)

# Stop executions until the required balance is reached
required_balance = await compute_required_balance(executions)
logger.debug(f"Required balance for Sender {sender} executions: {required_balance}, {executions}")
logger.debug(
f"Required balance for Sender {execution_address} executions: {required_balance}, {executions}"
)
# Stop executions until the required balance is reached
while executions and balance < (required_balance + settings.PAYMENT_BUFFER):
last_execution = executions.pop(-1)
logger.debug(f"Stopping {last_execution} due to insufficient balance")
await pool.stop_vm(last_execution.vm_hash)
required_balance = await compute_required_balance(executions)

community_wallet = await get_community_wallet_address()
if not community_wallet:
logger.error("Monitor payment ERROR: No community wallet set. Cannot check community payment")

# Check if the credit balance held in the wallet is sufficient credit tier resources (Not do it yet)
for execution_address, chains in pool.get_executions_by_address(payment_type=PaymentType.credit).items():
for chain, executions in chains.items():
executions = [execution for execution in executions]
if not executions:
continue
balance = await fetch_credit_balance_of_address(execution_address)

# Stop executions until the required credits are reached
required_credits = await compute_required_credit_balance(executions)
logger.debug(
f"Required credit balance for Address {execution_address} executions: {required_credits}, {executions}"
)
# Stop executions until the required credits are reached
while executions and balance < (required_credits + settings.PAYMENT_BUFFER):
last_execution = executions.pop(-1)
logger.debug(f"Stopping {last_execution} due to insufficient credit balance")
await pool.stop_vm(last_execution.vm_hash)
required_credits = await compute_required_credit_balance(executions)

# Check if the balance held in the wallet is sufficient stream tier resources
for sender, chains in pool.get_executions_by_sender(payment_type=PaymentType.superfluid).items():
for execution_address, chains in pool.get_executions_by_address(payment_type=PaymentType.superfluid).items():
for chain, executions in chains.items():
try:
stream = await get_stream(sender=sender, receiver=settings.PAYMENT_RECEIVER_ADDRESS, chain=chain)
stream = await get_stream(
sender=execution_address, receiver=settings.PAYMENT_RECEIVER_ADDRESS, chain=chain
)

logger.debug(
f"Stream flow from {sender} to {settings.PAYMENT_RECEIVER_ADDRESS} = {stream} {chain.value}"
f"Stream flow from {execution_address} to {settings.PAYMENT_RECEIVER_ADDRESS} = {stream} {chain.value}"
)
except ValueError as error:
logger.error(f"Error found getting stream for chain {chain} and sender {sender}: {error}")
logger.error(f"Error found getting stream for chain {chain} and sender {execution_address}: {error}")
continue
try:
community_stream = await get_stream(sender=sender, receiver=community_wallet, chain=chain)
logger.debug(f"Stream flow from {sender} to {community_wallet} (community) : {stream} {chain}")
community_stream = await get_stream(sender=execution_address, receiver=community_wallet, chain=chain)
logger.debug(
f"Stream flow from {execution_address} to {community_wallet} (community) : {stream} {chain}"
)

except ValueError as error:
logger.error(f"Error found getting stream for chain {chain} and sender {sender}: {error}")
logger.error(f"Error found getting stream for chain {chain} and sender {execution_address}: {error}")
continue

while executions:
Expand All @@ -249,7 +278,7 @@ async def check_payment(pool: VmPool):
)
required_community_stream = format_cost(required_stream * COMMUNITY_STREAM_RATIO)
logger.debug(
f"Stream for senders {sender} {len(executions)} executions. CRN : {stream} / {required_crn_stream}."
f"Stream for senders {execution_address} {len(executions)} executions. CRN : {stream} / {required_crn_stream}."
f"Community: {community_stream} / {required_community_stream}"
)
# Can pay all executions
Expand All @@ -259,7 +288,7 @@ async def check_payment(pool: VmPool):
break
# Stop executions until the required stream is reached
last_execution = executions.pop(-1)
logger.info(f"Stopping {last_execution} of {sender} due to insufficient stream")
logger.info(f"Stopping {last_execution} of {execution_address} due to insufficient stream")
await pool.stop_vm(last_execution.vm_hash)


Expand Down
6 changes: 3 additions & 3 deletions src/aleph/vm/orchestrator/views/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from aleph.vm.orchestrator.payment import (
InvalidAddressError,
InvalidChainError,
fetch_execution_flow_price,
fetch_execution_price,
get_stream,
)
from aleph.vm.orchestrator.pubsub import PubSub
Expand Down Expand Up @@ -577,7 +577,7 @@ async def notify_allocation(request: web.Request):
if have_gpu:
logger.debug(f"GPU Instance {item_hash} not using PAYG")
user_balance = await payment.fetch_balance_of_address(message.sender)
hold_price = await payment.fetch_execution_hold_price(item_hash)
hold_price = await payment.fetch_execution_price(item_hash, [PaymentType.hold], False)
logger.debug(f"Address {message.sender} Balance: {user_balance}, Price: {hold_price}")
if hold_price > user_balance:
return web.HTTPPaymentRequired(
Expand Down Expand Up @@ -606,7 +606,7 @@ async def notify_allocation(request: web.Request):
if not active_flow:
raise web.HTTPPaymentRequired(reason="Empty payment stream for this instance")

required_flow: Decimal = await fetch_execution_flow_price(item_hash)
required_flow: Decimal = await fetch_execution_price(item_hash, [PaymentType.superfluid])
community_wallet = await get_community_wallet_address()
required_crn_stream: Decimal
required_community_stream: Decimal
Expand Down
12 changes: 6 additions & 6 deletions src/aleph/vm/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,9 @@ def get_available_gpus(self) -> list[GpuDevice]:
available_gpus.append(gpu)
return available_gpus

def get_executions_by_sender(self, payment_type: PaymentType) -> dict[str, dict[str, list[VmExecution]]]:
def get_executions_by_address(self, payment_type: PaymentType) -> dict[str, dict[str, list[VmExecution]]]:
"""Return all executions of the given type, grouped by sender and by chain."""
executions_by_sender: dict[str, dict[str, list[VmExecution]]] = {}
executions_by_address: dict[str, dict[str, list[VmExecution]]] = {}
for vm_hash, execution in self.executions.items():
if execution.vm_hash in (settings.CHECK_FASTAPI_VM_ID, settings.LEGACY_CHECK_FASTAPI_VM_ID):
# Ignore Diagnostic VM execution
Expand All @@ -399,11 +399,11 @@ def get_executions_by_sender(self, payment_type: PaymentType) -> dict[str, dict[
else Payment(chain=Chain.ETH, type=PaymentType.hold)
)
if execution_payment.type == payment_type:
sender = execution.message.address
address = execution.message.address
chain = execution_payment.chain
executions_by_sender.setdefault(sender, {})
executions_by_sender[sender].setdefault(chain, []).append(execution)
return executions_by_sender
executions_by_address.setdefault(address, {})
executions_by_address[address].setdefault(chain, []).append(execution)
return executions_by_address

def get_valid_reservation(self, resource) -> Reservation | None:
if resource in self.reservations and self.reservations[resource].is_expired():
Expand Down
8 changes: 4 additions & 4 deletions tests/supervisor/test_checkpayment.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ async def compute_required_flow(executions):

pool.executions = {hash: execution}

executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid)
executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid)
assert len(executions_by_sender) == 1
assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}}

Expand Down Expand Up @@ -136,7 +136,7 @@ async def compute_required_flow(executions):

pool.executions = {hash: execution}

executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid)
executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid)
assert len(executions_by_sender) == 1
assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}}

Expand Down Expand Up @@ -173,7 +173,7 @@ async def test_not_enough_flow(mocker, fake_instance_content):

pool.executions = {hash: execution}

executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid)
executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid)
assert len(executions_by_sender) == 1
assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}}

Expand Down Expand Up @@ -217,7 +217,7 @@ async def get_stream(sender, receiver, chain):

pool.executions = {hash: execution}

executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid)
executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid)
assert len(executions_by_sender) == 1
assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}}

Expand Down
Loading