Skip to content

Commit

Permalink
Merge pull request #2144 from ogayot/pc-kernel
Browse files Browse the repository at this point in the history
Fix the pc-kernel issue
  • Loading branch information
ogayot authored Jan 28, 2025
2 parents bc5a03a + 9218456 commit 23571a5
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 19 deletions.
20 changes: 14 additions & 6 deletions subiquity/models/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,19 +147,27 @@ def get_matching_source(self, id_: str) -> CatalogEntry:
raise KeyError

def get_source(
self, variation_name: typing.Optional[str] = None
self,
variation_name: typing.Optional[str] = None,
*,
source_id: typing.Optional[str] = None,
) -> typing.Optional[str]:
scheme = self.current.type
if source_id is None:
source = self.current
else:
source = self.get_matching_source(source_id)

scheme = source.type
if scheme is None:
return None
if variation_name is None:
variation = next(iter(self.current.variations.values()))
variation = next(iter(source.variations.values()))
else:
variation = self.current.variations[variation_name]
variation = source.variations[variation_name]
path = os.path.join(self._dir, variation.path)
if self.current.preinstalled_langs:
if source.preinstalled_langs:
base, ext = os.path.splitext(path)
if self.lang in self.current.preinstalled_langs:
if self.lang in source.preinstalled_langs:
suffix = self.lang
else:
suffix = "no-languages"
Expand Down
28 changes: 27 additions & 1 deletion subiquity/server/controllers/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,33 @@ async def _examine_systems(self):
system = None
label = variation.snapd_system_label
if label is not None:
system, in_live_layer = await self._system_getter.get(name, label)
# We do not want to unconditionally propagate cancellation to
# _system_getter.get. If it gets cancelled during its critical
# section, it won't be able to properly clean up after itself
# (see LP: #2084032).
# Therefore we use an asyncio.Task (coupled with
# asyncio.shield) so we can prevent propagation.
in_critical_section = asyncio.Event()
task = asyncio.create_task(
self._system_getter.get(
name,
label,
source_id=catalog_entry.id,
started_event=in_critical_section,
)
)
try:
system, in_live_layer = await asyncio.shield(task)
except asyncio.CancelledError:
if not in_critical_section.is_set():
# Just to make sure we don't end up with a large queue of
# _system_getter.get() tasks.
task.cancel()
# _system_getter.get is marked async_helpers.exclusive
# so it should be safe to let it finish "unsupervised" even
# though it might be called again concurrently.
raise

log.debug("got system %s for variation %s", system, name)
if system is not None and len(system.volumes) > 0:
if not self.app.opts.enhanced_secureboot:
Expand Down
7 changes: 5 additions & 2 deletions subiquity/server/controllers/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,12 @@ async def GET(self) -> SourceSelectionAndSetting:
)

def get_handler(
self, variation_name: Optional[str] = None
self,
variation_name: Optional[str] = None,
*,
source_id: Optional[str] = None,
) -> Optional[AbstractSourceHandler]:
source = self.model.get_source(variation_name)
source = self.model.get_source(variation_name, source_id=source_id)
if source is None:
return None
handler = get_handler_for_source(sanitize_source(source))
Expand Down
8 changes: 5 additions & 3 deletions subiquity/server/controllers/tests/test_filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ async def test__get_system_api_error_logged(self):
getter = SystemGetter(self.app)

@contextlib.asynccontextmanager
async def mounted(self):
async def mounted(self, *, source_id):
yield

mount_mock = mock.patch(
Expand Down Expand Up @@ -612,7 +612,9 @@ async def mounted(self):
"subiquity.server.snapd.system_getter", level="WARNING"
) as logs:
await getter.get(
variation_name="minimal", label="enhanced-secureboot-desktop"
variation_name="minimal",
label="enhanced-secureboot-desktop",
source_id="default",
)

self.assertIn("cannot load assertions for label", logs.output[0])
Expand Down Expand Up @@ -2008,7 +2010,7 @@ def setUp(self):
self.fsc.model = make_model(Bootloader.UEFI)

@contextlib.asynccontextmanager
async def mounted(self):
async def mounted(self, *, source_id):
yield

p = mock.patch(
Expand Down
18 changes: 12 additions & 6 deletions subiquity/server/snapd/system_getter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from subiquity.server.mounter import Mounter
from subiquity.server.snapd.types import SystemDetails
from subiquitycore import async_helpers

log = logging.getLogger("subiquity.server.snapd.system_getter")

Expand All @@ -36,8 +37,10 @@ def __init__(self, app, variation_name):
self.app = app
self.variation_name = variation_name

async def mount(self):
source_handler = self.app.controllers.Source.get_handler(self.variation_name)
async def mount(self, *, source_id: Optional[str] = None):
source_handler = self.app.controllers.Source.get_handler(
self.variation_name, source_id=source_id
)
if source_handler is None:
raise NoSnapdSystemsOnSource
mounter = Mounter(self.app)
Expand All @@ -61,8 +64,8 @@ async def mount(self):
return source_handler, mounter

@contextlib.asynccontextmanager
async def mounted(self):
source_handler, mounter = await self.mount()
async def mounted(self, *, source_id: Optional[str] = None):
source_handler, mounter = await self.mount(source_id=source_id)
try:
yield
finally:
Expand All @@ -81,8 +84,9 @@ async def _get(self, label: str) -> SystemDetails:
log.warning("v2/systems/%s returned %s", label, http_err.response.text)
raise

@async_helpers.exclusive
async def get(
self, variation_name: str, label: str
self, variation_name: str, label: str, *, source_id: str
) -> Tuple[Optional[SystemDetails], bool]:
"""Return system information for a given system label.
Expand All @@ -97,7 +101,9 @@ async def get(
return await self._get(label), True
else:
try:
async with SystemsDirMounter(self.app, variation_name).mounted():
async with SystemsDirMounter(self.app, variation_name).mounted(
source_id=source_id
):
return await self._get(label), False
except NoSnapdSystemsOnSource:
return None, False
20 changes: 20 additions & 0 deletions subiquitycore/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import asyncio
import concurrent.futures
import logging
from typing import Optional

log = logging.getLogger("subiquitycore.async_helpers")

Expand Down Expand Up @@ -117,3 +118,22 @@ def done(self):
if self.task is None:
return False
return self.task.done()


def exclusive(coroutine_function):
"""Can be used to decorate a coroutine function that we do not want to run
multiple times concurrently. It uses a lock internally.
If the caller needs to know when the decorated coroutine starts executing
(i.e., when it has acquired the exclusive lock), they can pass an
asyncio.Event as the "started_event" keyword-only argument.
"""
lock = asyncio.Lock()

async def wrapped(*args, started_event: Optional[asyncio.Event] = None, **kwargs):
async with lock:
if started_event is not None:
started_event.set()

return await coroutine_function(*args, **kwargs)

return wrapped
33 changes: 32 additions & 1 deletion subiquitycore/tests/test_async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
import unittest
from unittest.mock import AsyncMock

from subiquitycore.async_helpers import SingleInstanceTask, TaskAlreadyRunningError
from subiquitycore.async_helpers import (
SingleInstanceTask,
TaskAlreadyRunningError,
exclusive,
)
from subiquitycore.tests.parameterized import parameterized


Expand Down Expand Up @@ -63,3 +67,30 @@ async def fn():
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(sit.wait(), timeout=0.1)
self.assertFalse(sit.done())


class TestExclusive(unittest.IsolatedAsyncioTestCase):
async def test_concurrency(self):
timeout = 0.1
barrier = asyncio.Barrier(parties=2)

async def f():
async with barrier:
pass

await asyncio.wait_for(asyncio.gather(f(), f()), timeout=timeout)

g = exclusive(f)

with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(asyncio.gather(g(), g()), timeout=timeout)

# This is the same as g, but just to show an example of the intended
# usage.
@exclusive
async def e():
async with barrier:
pass

with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(asyncio.gather(e(), e()), timeout=timeout)

0 comments on commit 23571a5

Please sign in to comment.