Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions src/fromager/bootstrap_requirement_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,13 @@ def __init__(
self.multiple_versions = multiple_versions
self.cache_wheel_server_url = cache_wheel_server_url
# Session-level resolution cache to avoid re-resolving same requirements
# Key: (requirement_string, pre_built) to distinguish source vs prebuilt
# Key: (requirement_string, pre_built, is_top_level) to distinguish
# source vs prebuilt and top-level vs non-top-level (cooldown bypass)
# Value: tuple of (url, version) tuples sorted by version (highest first)
# Values are stored as immutable tuples to prevent accidental corruption
# when callers modify the returned reference.
self._resolved_requirements: dict[
tuple[str, bool], tuple[tuple[str, Version], ...]
tuple[str, bool, bool], tuple[tuple[str, Version], ...]
] = {}

def resolve(
Expand Down Expand Up @@ -116,8 +117,10 @@ def resolve(
f"Git URL requirements must be handled by Bootstrapper: {req}"
)

# Check session cache (keyed by requirement + pre_built)
cached_result = self.get_cached_resolution(req, pre_built)
# Check session cache
cached_result = self.get_cached_resolution(
req, pre_built, req_type == RequirementType.TOP_LEVEL
)
if cached_result is not None:
logger.debug(f"resolved {req} from cache")
return list(cached_result) if return_all_versions else [cached_result[0]]
Expand Down Expand Up @@ -180,7 +183,9 @@ def resolve(

# Only cache non-empty results.
if results:
self.cache_resolution(req, pre_built, results)
self.cache_resolution(
req, pre_built, req_type == RequirementType.TOP_LEVEL, results
)

if not results:
return []
Expand Down Expand Up @@ -229,6 +234,7 @@ def get_cached_resolution(
self,
req: Requirement,
pre_built: bool,
is_top_level: bool,
) -> tuple[tuple[str, Version], ...] | None:
"""Get a cached resolution result if it exists.

Expand All @@ -237,17 +243,19 @@ def get_cached_resolution(
Args:
req: Package requirement to look up in cache
pre_built: Whether looking for prebuilt or source resolution
is_top_level: Whether this is a top-level requirement

Returns:
Tuple of (url, version) tuples if cached, None otherwise
"""
cache_key = (str(req), pre_built)
cache_key = (str(req), pre_built, is_top_level)
return self._resolved_requirements.get(cache_key)

def cache_resolution(
self,
req: Requirement,
pre_built: bool,
is_top_level: bool,
result: list[tuple[str, Version]],
) -> None:
"""Cache a resolution result.
Expand All @@ -261,9 +269,10 @@ def cache_resolution(
Args:
req: Package requirement to cache
pre_built: Whether this is a prebuilt or source resolution
is_top_level: Whether this is a top-level requirement
result: List of (url, version) tuples
"""
cache_key = (str(req), pre_built)
cache_key = (str(req), pre_built, is_top_level)
self._resolved_requirements[cache_key] = tuple(result)

def _resolve_from_graph(
Expand Down
8 changes: 6 additions & 2 deletions src/fromager/bootstrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,9 @@ def resolve_versions(

# Check cache first to avoid re-resolving
# Git URLs are always source (not prebuilt)
cached_result = self._resolver.get_cached_resolution(req, pre_built=False)
cached_result = self._resolver.get_cached_resolution(
req, pre_built=False, is_top_level=True
)
if cached_result is not None:
logger.debug(f"resolved {req} from cache")
return (
Expand All @@ -299,7 +301,9 @@ def resolve_versions(
# Cache the git URL resolution (always source, not prebuilt)
# Store as list for consistency with cache structure
result = [(source_url, resolved_version)]
self._resolver.cache_resolution(req, pre_built=False, result=result)
self._resolver.cache_resolution(
req, pre_built=False, is_top_level=True, result=result
)
return result # Git URLs always return single version

# Delegate to RequirementResolver
Expand Down
51 changes: 47 additions & 4 deletions tests/test_bootstrap_requirement_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,15 +568,17 @@ def test_cache_resolution_stores_immutable_tuple(tmp_context: WorkContext) -> No
req = Requirement("mypkg>=1.0")
original = [("https://example.com/mypkg-1.0.tar.gz", Version("1.0"))]

resolver.cache_resolution(req, pre_built=False, result=original)
cached = resolver.get_cached_resolution(req, pre_built=False)
resolver.cache_resolution(req, pre_built=False, is_top_level=True, result=original)
cached = resolver.get_cached_resolution(req, pre_built=False, is_top_level=True)

# Cached value should be a tuple
assert isinstance(cached, tuple)

# Mutating the original list must not affect the cache
original.append(("https://example.com/mypkg-2.0.tar.gz", Version("2.0")))
cached_after = resolver.get_cached_resolution(req, pre_built=False)
cached_after = resolver.get_cached_resolution(
req, pre_built=False, is_top_level=True
)
assert cached_after is not None
assert len(cached_after) == 1

Expand All @@ -589,9 +591,10 @@ def test_get_cached_resolution_returns_immutable(tmp_context: WorkContext) -> No
resolver.cache_resolution(
req,
pre_built=False,
is_top_level=True,
result=[("https://example.com/mypkg-1.0.tar.gz", Version("1.0"))],
)
cached = resolver.get_cached_resolution(req, pre_built=False)
cached = resolver.get_cached_resolution(req, pre_built=False, is_top_level=True)
assert cached is not None

with pytest.raises(AttributeError):
Expand Down Expand Up @@ -642,6 +645,46 @@ def test_resolve_cache_returns_independent_lists(
mock_resolve.assert_called_once()


@patch("fromager.resolver.find_all_matching_from_provider")
def test_resolve_toplevel_after_transitive_uses_separate_cache(
mock_resolve: MagicMock,
tmp_context: WorkContext,
) -> None:
"""Top-level resolution re-resolves even if same req was cached as transitive."""
req = Requirement("mypkg>=1.0")
mock_resolve.side_effect = [
# First call: transitive resolution (cooldown blocks v2.0)
[("https://files.test/mypkg-1.5.tar.gz", Version("1.5"))],
# Second call: top-level resolution (cooldown bypassed, v2.0 included)
[
("https://files.test/mypkg-2.0.tar.gz", Version("2.0")),
("https://files.test/mypkg-1.5.tar.gz", Version("1.5")),
],
]

resolver = BootstrapRequirementResolver(tmp_context)

# First: resolve as transitive dependency
results_install = resolver.resolve(
req=req,
req_type=RequirementType.INSTALL,
parent_req=None,
pre_built=False,
)
assert results_install[0][1] == Version("1.5")
assert mock_resolve.call_count == 1

# Second: resolve same req as top-level — must NOT use transitive cache
results_toplevel = resolver.resolve(
req=req,
req_type=RequirementType.TOP_LEVEL,
parent_req=None,
pre_built=False,
)
assert mock_resolve.call_count == 2
assert results_toplevel[0][1] == Version("2.0")


@patch("fromager.resolver.find_all_matching_from_provider")
def test_resolve_prebuilt_after_source_uses_separate_cache(
mock_resolve: MagicMock,
Expand Down
Loading