Skip to content

Add automatic snapshot feature for QEMU VMs #775

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions src/aleph/vm/controllers/qemu/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import logging
import qmp
from pydantic import BaseModel


logger = logging.getLogger(__name__)


class VmSevInfo(BaseModel):
enabled: bool
api_major: int
Expand Down Expand Up @@ -74,3 +78,46 @@ def query_status(self) -> None:
"""
# {'status': 'prelaunch', 'singlestep': False, 'running': False}
return self.qmp_client.command("query-status")

def create_snapshot(self, snapshot_name: str) -> bool:
"""
Create a VM snapshot using QMP. This will snapshot the VM's RAM state and disks.

:param snapshot_name: Name of the snapshot
:return: True if successful, False otherwise
"""
try:
logger.debug(f"Creating snapshot {snapshot_name} for VM {self.vm.vm_id}")
self.qmp_client.command("savevm", **{"name": snapshot_name})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.qmp_client.command("savevm", **{"name": snapshot_name})
self.qmp_client.command("savevm", name=snapshot_name)

return True
except Exception as e:
logger.error(f"Failed to create snapshot {snapshot_name} for VM {self.vm.vm_id}: {e}")
return False

def delete_snapshot(self, snapshot_name: str) -> bool:
"""
Delete a VM snapshot using QMP.

:param snapshot_name: Name of the snapshot to delete
:return: True if successful, False otherwise
"""
try:
logger.debug(f"Deleting snapshot {snapshot_name} for VM {self.vm.vm_id}")
self.qmp_client.command("delvm", **{"name": snapshot_name})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.qmp_client.command("delvm", **{"name": snapshot_name})
self.qmp_client.command("delvm", name=snapshot_name)

return True
except Exception as e:
logger.error(f"Failed to delete snapshot {snapshot_name} for VM {self.vm.vm_id}: {e}")
return False

def list_snapshots(self) -> list[str]:
"""
List all VM snapshots using QMP.

:return: List of snapshot names
"""
try:
snapshots = self.qmp_client.command("query-snapshots")
return [snapshot["name"] for snapshot in snapshots]
except Exception as e:
logger.error(f"Failed to list snapshots for VM {self.vm.vm_id}: {e}")
return []
52 changes: 51 additions & 1 deletion src/aleph/vm/controllers/qemu/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
AlephFirecrackerResources,
VmSetupError,
)
from aleph.vm.controllers.firecracker.snapshots import CompressedDiskVolumeSnapshot
from aleph.vm.controllers.interface import AlephVmControllerInterface
from aleph.vm.controllers.qemu.client import QemuVmClient
from aleph.vm.controllers.qemu.cloudinit import CloudInitMixin
from aleph.vm.network.firewall import teardown_nftables_for_vm
from aleph.vm.network.interfaces import TapInterface
Expand Down Expand Up @@ -109,9 +111,10 @@ class AlephQemuInstance(Generic[ConfigurationType], CloudInitMixin, AlephVmContr
vm_configuration: ConfigurationType | None
is_instance: bool
qemu_process: Process | None
support_snapshot = False
support_snapshot = True
persistent = True
controller_configuration: Configuration
active_snapshot_name: str | None = None

def __repr__(self):
return f"<AlephQemuInstance {self.vm_id}>"
Expand Down Expand Up @@ -272,3 +275,50 @@ async def teardown(self):
if self.tap_interface:
await self.tap_interface.delete()
await self.stop_guest_api()

async def create_snapshot(self) -> CompressedDiskVolumeSnapshot:
"""
Create a VM snapshot using QMP's native savevm functionality.

:return: CompressedDiskVolumeSnapshot object (placeholder since QEMU handles snapshots internally)
"""
logger.debug(f"Creating snapshot for VM {self.vm_id} ({self.vm_hash})")

# Generate a snapshot name
snapshot_name = f"auto-snapshot-{self.vm_hash}"

try:
with QemuVmClient(self) as client:
# Create new snapshot first for safety
success = client.create_snapshot(snapshot_name)
if not success:
msg = f"Failed to create snapshot {snapshot_name}"
raise ValueError(msg)

logger.debug(f"Successfully created snapshot {snapshot_name}")

# Get current snapshots
existing_snapshots = client.list_snapshots()

# Delete previous snapshots if they exist and are different from the new one
for existing_snapshot in existing_snapshots:
if existing_snapshot != snapshot_name and existing_snapshot.startswith("auto-snapshot-"):
logger.debug(f"Deleting previous snapshot {existing_snapshot}")
client.delete_snapshot(existing_snapshot)

self.active_snapshot_name = snapshot_name

# Return a placeholder snapshot object since QEMU handles snapshots internally
placeholder_path = Path(settings.EXECUTION_ROOT) / f"{self.vm_hash}-snapshot-info"
with open(placeholder_path, "w") as f:
f.write(f"QEMU snapshot {snapshot_name} created successfully")

return CompressedDiskVolumeSnapshot(
path=placeholder_path,
algorithm=settings.SNAPSHOT_COMPRESSION_ALGORITHM
)

except Exception as e:
msg = f"Failed to create snapshot for VM {self.vm_id}: {e}"
logger.error(msg)
raise ValueError(msg) from e
121 changes: 121 additions & 0 deletions src/aleph/vm/controllers/qemu/snapshot_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import asyncio
import logging
import threading
from time import sleep

from aleph_message.models import ItemHash
from schedule import Job, Scheduler

from aleph.vm.conf import settings
from aleph.vm.controllers.firecracker.snapshots import CompressedDiskVolumeSnapshot

logger = logging.getLogger(__name__)


def wrap_async_snapshot(vm):
asyncio.run(do_vm_snapshot(vm))


def run_threaded_snapshot(vm):
job_thread = threading.Thread(target=wrap_async_snapshot, args=(vm,))
job_thread.start()


async def do_vm_snapshot(vm) -> CompressedDiskVolumeSnapshot:
try:
logger.debug(f"Starting new snapshot for QEMU VM {vm.vm_hash}")
assert vm, "VM execution not set"

snapshot = await vm.create_snapshot()
logger.debug(f"New snapshot for QEMU VM {vm.vm_hash} created successfully")
return snapshot
except ValueError as error:
msg = "Failed to create QEMU VM snapshot"
raise ValueError(msg) from error


def infinite_run_scheduler_jobs(scheduler: Scheduler) -> None:
while True:
scheduler.run_pending()
sleep(1)


class QemuSnapshotExecution:
vm_hash: ItemHash
execution: any # AlephQemuInstance
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
execution: any # AlephQemuInstance
execution: Any # AlephQemuInstance

This should be Any from typing, I'm surprised this worked

frequency: int
_scheduler: Scheduler
_job: Job

def __init__(
self,
scheduler: Scheduler,
vm_hash: ItemHash,
execution,
frequency: int,
):
self.vm_hash = vm_hash
self.execution = execution
self.frequency = frequency
self._scheduler = scheduler

async def start(self) -> None:
logger.debug(f"Starting QEMU snapshots for VM {self.vm_hash} every {self.frequency} minutes")
job = self._scheduler.every(self.frequency).minutes.do(run_threaded_snapshot, self.execution)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
job = self._scheduler.every(self.frequency).minutes.do(run_threaded_snapshot, self.execution)
self._job = self._scheduler.every(self.frequency).minutes.do(run_threaded_snapshot, self.execution)

And remove the next line

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
job = self._scheduler.every(self.frequency).minutes.do(run_threaded_snapshot, self.execution)
self._job = self._scheduler.every(self.frequency).minutes.do(run_threaded_snapshot, self.execution)

And remove the next line

self._job = job

async def stop(self) -> None:
logger.debug(f"Stopping QEMU snapshots for VM {self.vm_hash}")
self._scheduler.cancel_job(self._job)


class QemuSnapshotManager:
"""
Manage QEMU VM snapshots.
"""

executions: dict[ItemHash, QemuSnapshotExecution]
_scheduler: Scheduler

def __init__(self):
self.executions = {}
self._scheduler = Scheduler()

def run_in_thread(self) -> None:
job_thread = threading.Thread(
target=infinite_run_scheduler_jobs,
args=[self._scheduler],
daemon=True,
name="QemuSnapshotManager",
)
job_thread.start()

async def start_for(self, vm, frequency: int | None = None) -> None:
if not vm.support_snapshot:
msg = "Snapshots are not supported for this VM type."
raise NotImplementedError(msg)

# Default to 10 minutes if not specified and settings value is 0
default_frequency = frequency or settings.SNAPSHOT_FREQUENCY or 10

vm_hash = vm.vm_hash
snapshot_execution = QemuSnapshotExecution(
scheduler=self._scheduler,
vm_hash=vm_hash,
execution=vm,
frequency=default_frequency,
)
self.executions[vm_hash] = snapshot_execution
await snapshot_execution.start()

async def stop_for(self, vm_hash: ItemHash) -> None:
try:
snapshot_execution = self.executions.pop(vm_hash)
except KeyError:
logger.warning("Could not find snapshot task for QEMU instance %s", vm_hash)
return

await snapshot_execution.stop()

async def stop_all(self) -> None:
await asyncio.gather(*(self.stop_for(vm_hash) for vm_hash in list(self.executions.keys())))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this function also do a job_thread.stop()? I'm no sure

3 changes: 2 additions & 1 deletion src/aleph/vm/controllers/qemu_confidential/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,12 @@ class AlephQemuConfidentialInstance(AlephQemuInstance):
vm_configuration: ConfigurationType | None
is_instance: bool
qemu_process: Process | None
support_snapshot = False
support_snapshot = True
persistent = True
_queue_cancellers: dict[asyncio.Queue, Callable] = {}
controller_configuration: Configuration
confidential_policy: int
active_snapshot_name: str | None = None

def __repr__(self):
return f"<AlephQemuInstance {self.vm_id}>"
Expand Down
11 changes: 9 additions & 2 deletions src/aleph/vm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from aleph.vm.controllers.firecracker.snapshot_manager import SnapshotManager
from aleph.vm.controllers.interface import AlephVmControllerInterface
from aleph.vm.controllers.qemu.instance import AlephQemuInstance, AlephQemuResources
from aleph.vm.controllers.qemu.snapshot_manager import QemuSnapshotManager
from aleph.vm.controllers.qemu_confidential.instance import (
AlephQemuConfidentialInstance,
AlephQemuConfidentialResources,
Expand Down Expand Up @@ -89,6 +90,7 @@ class VmExecution:
update_task: asyncio.Task | None = None

snapshot_manager: SnapshotManager | None
qemu_snapshot_manager: QemuSnapshotManager | None
systemd_manager: SystemDManager | None

persistent: bool = False
Expand Down Expand Up @@ -162,6 +164,7 @@ def __init__(
snapshot_manager: SnapshotManager | None,
systemd_manager: SystemDManager | None,
persistent: bool,
qemu_snapshot_manager: QemuSnapshotManager | None = None,
):
self.uuid = uuid.uuid1() # uuid1() includes the hardware address and timestamp
self.vm_hash = vm_hash
Expand All @@ -175,6 +178,7 @@ def __init__(
self.preparation_pending_lock = asyncio.Lock()
self.stop_pending_lock = asyncio.Lock()
self.snapshot_manager = snapshot_manager
self.qemu_snapshot_manager = qemu_snapshot_manager
self.systemd_manager = systemd_manager
self.persistent = persistent

Expand Down Expand Up @@ -379,8 +383,11 @@ async def stop(self) -> None:
self.cancel_expiration()
self.cancel_update()

if self.vm.support_snapshot and self.snapshot_manager:
await self.snapshot_manager.stop_for(self.vm_hash)
if self.vm.support_snapshot:
if isinstance(self.vm, AlephQemuInstance) and self.qemu_snapshot_manager:
await self.qemu_snapshot_manager.stop_for(self.vm_hash)
elif self.snapshot_manager:
await self.snapshot_manager.stop_for(self.vm_hash)
self.stop_event.set()

def start_watching_for_updates(self, pubsub: PubSub):
Expand Down
28 changes: 23 additions & 5 deletions src/aleph/vm/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from aleph.vm.conf import settings
from aleph.vm.controllers.firecracker.snapshot_manager import SnapshotManager
from aleph.vm.controllers.qemu.snapshot_manager import QemuSnapshotManager
from aleph.vm.controllers.qemu.instance import AlephQemuInstance
from aleph.vm.network.hostnetwork import Network, make_ipv6_allocator
from aleph.vm.orchestrator.metrics import get_execution_records
from aleph.vm.orchestrator.utils import update_aggregate_settings
Expand All @@ -43,6 +45,7 @@ class VmPool:
message_cache: dict[str, ExecutableMessage]
network: Network | None
snapshot_manager: SnapshotManager | None = None
qemu_snapshot_manager: QemuSnapshotManager | None = None
systemd_manager: SystemDManager
creation_lock: asyncio.Lock
gpus: List[GpuDevice] = []
Expand Down Expand Up @@ -73,15 +76,20 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
self.systemd_manager = SystemDManager()
if settings.SNAPSHOT_FREQUENCY > 0:
self.snapshot_manager = SnapshotManager()
self.qemu_snapshot_manager = QemuSnapshotManager()

def setup(self) -> None:
"""Set up the VM pool and the network."""
if self.network:
self.network.setup()

if self.snapshot_manager:
logger.debug("Initializing SnapshotManager ...")
logger.debug("Initializing SnapshotManager for Firecracker VMs...")
self.snapshot_manager.run_in_thread()

if self.qemu_snapshot_manager:
logger.debug("Initializing QemuSnapshotManager for QEMU VMs...")
self.qemu_snapshot_manager.run_in_thread()

if settings.ENABLE_GPU_SUPPORT:
# Refresh and get latest settings aggregate
Expand Down Expand Up @@ -116,6 +124,7 @@ async def create_a_vm(
snapshot_manager=self.snapshot_manager,
systemd_manager=self.systemd_manager,
persistent=persistent,
qemu_snapshot_manager=self.qemu_snapshot_manager,
)
self.executions[vm_hash] = execution

Expand Down Expand Up @@ -149,8 +158,12 @@ async def create_a_vm(
if execution.is_program and execution.vm:
await execution.vm.load_configuration()

if execution.vm and execution.vm.support_snapshot and self.snapshot_manager:
await self.snapshot_manager.start_for(vm=execution.vm)
if execution.vm and execution.vm.support_snapshot:
# Use appropriate snapshot manager based on VM type
if isinstance(execution.vm, AlephQemuInstance) and self.qemu_snapshot_manager:
await self.qemu_snapshot_manager.start_for(vm=execution.vm)
elif self.snapshot_manager:
await self.snapshot_manager.start_for(vm=execution.vm)
except Exception:
# ensure the VM is removed from the pool on creation error
self.forget_vm(vm_hash)
Expand Down Expand Up @@ -244,6 +257,7 @@ async def load_persistent_executions(self):
snapshot_manager=self.snapshot_manager,
systemd_manager=self.systemd_manager,
persistent=saved_execution.persistent,
qemu_snapshot_manager=self.qemu_snapshot_manager,
)

if execution.is_running:
Expand All @@ -266,8 +280,12 @@ async def load_persistent_executions(self):
self._schedule_forget_on_stop(execution)

# Start the snapshot manager for the VM
if vm.support_snapshot and self.snapshot_manager:
await self.snapshot_manager.start_for(vm=execution.vm)
if vm.support_snapshot:
# Use appropriate snapshot manager based on VM type
if isinstance(execution.vm, AlephQemuInstance) and self.qemu_snapshot_manager:
await self.qemu_snapshot_manager.start_for(vm=execution.vm)
elif self.snapshot_manager:
await self.snapshot_manager.start_for(vm=execution.vm)

self.executions[vm_hash] = execution
else:
Expand Down
Loading