Skip to content

Commit e58d2a2

Browse files
authored
refactor(internal): simplify and cleanup dependency resolution (#1691)
Summary: - refactor: add an OS.interpreter constructor - refactor: add an Arch.interpreter constructor - test: add a specializations tests - refactor: simplify sorting implementation - refactor: introduce `__str__` to enums Minor cleanup that will be useful to make this codebase work for multiple Python versions. Work towards #1643.
1 parent 8ecad9d commit e58d2a2

File tree

3 files changed

+113
-50
lines changed

3 files changed

+113
-50
lines changed

python/pip_install/tools/wheel_installer/wheel.py

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@ class OS(Enum):
3636
darwin = osx
3737
win32 = windows
3838

39+
@classmethod
40+
def interpreter(cls) -> "OS":
41+
"Return the interpreter operating system."
42+
return cls[sys.platform.lower()]
43+
44+
def __str__(self) -> str:
45+
return self.name.lower()
46+
3947

4048
class Arch(Enum):
4149
x86_64 = 1
@@ -50,6 +58,31 @@ class Arch(Enum):
5058
x86 = x86_32
5159
ppc64le = ppc
5260

61+
@classmethod
62+
def interpreter(cls) -> "OS":
63+
"Return the currently running interpreter architecture."
64+
# FIXME @aignas 2023-12-13: Hermetic toolchain on Windows 3.11.6
65+
# is returning an empty string here, so lets default to x86_64
66+
return cls[platform.machine().lower() or "x86_64"]
67+
68+
def __str__(self) -> str:
69+
return self.name.lower()
70+
71+
72+
def _as_int(value: Optional[Union[OS, Arch]]) -> int:
73+
"""Convert one of the enums above to an int for easier sorting algorithms.
74+
75+
Args:
76+
value: The value of an enum or None.
77+
78+
Returns:
79+
-1 if we get None, otherwise, the numeric value of the given enum.
80+
"""
81+
if value is None:
82+
return -1
83+
84+
return int(value.value)
85+
5386

5487
@dataclass(frozen=True)
5588
class Platform:
@@ -77,14 +110,7 @@ def host(cls) -> List["Platform"]:
77110
A list of parsed values which makes the signature the same as
78111
`Platform.all` and `Platform.from_string`.
79112
"""
80-
return [
81-
cls(
82-
os=OS[sys.platform.lower()],
83-
# FIXME @aignas 2023-12-13: Hermetic toolchain on Windows 3.11.6
84-
# is returning an empty string here, so lets default to x86_64
85-
arch=Arch[platform.machine().lower() or "x86_64"],
86-
)
87-
]
113+
return [cls(os=OS.interpreter(), arch=Arch.interpreter())]
88114

89115
def all_specializations(self) -> Iterator["Platform"]:
90116
"""Return the platform itself and all its unambiguous specializations.
@@ -102,30 +128,19 @@ def __lt__(self, other: Any) -> bool:
102128
if not isinstance(other, Platform) or other is None:
103129
raise ValueError(f"cannot compare {other} with Platform")
104130

105-
if self.arch is None and other.arch is not None:
106-
return True
107-
108-
if self.arch is not None and other.arch is None:
109-
return True
131+
self_arch, self_os = _as_int(self.arch), _as_int(self.os)
132+
other_arch, other_os = _as_int(other.arch), _as_int(other.os)
110133

111-
# Here we ensure that we sort by OS before sorting by arch
112-
113-
if self.arch is None and other.arch is None:
114-
return self.os.value < other.os.value
115-
116-
if self.os.value < other.os.value:
117-
return True
118-
119-
if self.os.value == other.os.value:
120-
return self.arch.value < other.arch.value
121-
122-
return False
134+
if self_os == other_os:
135+
return self_arch < other_arch
136+
else:
137+
return self_os < other_os
123138

124139
def __str__(self) -> str:
125140
if self.arch is None:
126-
return f"@platforms//os:{self.os.name.lower()}"
141+
return f"@platforms//os:{self.os}"
127142

128-
return self.os.name.lower() + "_" + self.arch.name.lower()
143+
return f"{self.os}_{self.arch}"
129144

130145
@classmethod
131146
def from_string(cls, platform: Union[str, List[str]]) -> List["Platform"]:

python/pip_install/tools/wheel_installer/wheel_installer_test.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -101,28 +101,5 @@ def test_wheel_exists(self) -> None:
101101
self.assertEqual(want, metadata_file_content)
102102

103103

104-
class TestWheelPlatform(unittest.TestCase):
105-
def test_wheel_os_alias(self):
106-
self.assertEqual("OS.osx", str(wheel.OS.osx))
107-
self.assertEqual(str(wheel.OS.darwin), str(wheel.OS.osx))
108-
109-
def test_wheel_arch_alias(self):
110-
self.assertEqual("Arch.x86_64", str(wheel.Arch.x86_64))
111-
self.assertEqual(str(wheel.Arch.amd64), str(wheel.Arch.x86_64))
112-
113-
def test_wheel_platform_alias(self):
114-
give = wheel.Platform(
115-
os=wheel.OS.darwin,
116-
arch=wheel.Arch.amd64,
117-
)
118-
alias = wheel.Platform(
119-
os=wheel.OS.osx,
120-
arch=wheel.Arch.x86_64,
121-
)
122-
123-
self.assertEqual("osx_x86_64", str(give))
124-
self.assertEqual(str(alias), str(give))
125-
126-
127104
if __name__ == "__main__":
128105
unittest.main()

python/pip_install/tools/wheel_installer/wheel_test.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
from random import shuffle
23

34
from python.pip_install.tools.wheel_installer import wheel
45

@@ -169,6 +170,76 @@ def test_can_get_all_for_os(self):
169170
self.assertEqual(5, len(linuxes))
170171
self.assertEqual(linuxes, wheel.Platform.from_string("linux_*"))
171172

173+
def test_linux_specializations(self):
174+
any_linux = wheel.Platform(os=wheel.OS.linux)
175+
all_specializations = list(any_linux.all_specializations())
176+
want = [
177+
wheel.Platform(os=wheel.OS.linux, arch=None),
178+
wheel.Platform(os=wheel.OS.linux, arch=wheel.Arch.x86_64),
179+
wheel.Platform(os=wheel.OS.linux, arch=wheel.Arch.x86_32),
180+
wheel.Platform(os=wheel.OS.linux, arch=wheel.Arch.aarch64),
181+
wheel.Platform(os=wheel.OS.linux, arch=wheel.Arch.ppc),
182+
wheel.Platform(os=wheel.OS.linux, arch=wheel.Arch.s390x),
183+
]
184+
self.assertEqual(want, all_specializations)
185+
186+
def test_osx_specializations(self):
187+
any_osx = wheel.Platform(os=wheel.OS.osx)
188+
all_specializations = list(any_osx.all_specializations())
189+
# NOTE @aignas 2024-01-14: even though in practice we would only have
190+
# Python on osx aarch64 and osx x86_64, we return all arch posibilities
191+
# to make the code simpler.
192+
want = [
193+
wheel.Platform(os=wheel.OS.osx, arch=None),
194+
wheel.Platform(os=wheel.OS.osx, arch=wheel.Arch.x86_64),
195+
wheel.Platform(os=wheel.OS.osx, arch=wheel.Arch.x86_32),
196+
wheel.Platform(os=wheel.OS.osx, arch=wheel.Arch.aarch64),
197+
wheel.Platform(os=wheel.OS.osx, arch=wheel.Arch.ppc),
198+
wheel.Platform(os=wheel.OS.osx, arch=wheel.Arch.s390x),
199+
]
200+
self.assertEqual(want, all_specializations)
201+
202+
def test_platform_sort(self):
203+
platforms = [
204+
wheel.Platform(os=wheel.OS.linux, arch=None),
205+
wheel.Platform(os=wheel.OS.linux, arch=wheel.Arch.x86_64),
206+
wheel.Platform(os=wheel.OS.osx, arch=None),
207+
wheel.Platform(os=wheel.OS.osx, arch=wheel.Arch.x86_64),
208+
wheel.Platform(os=wheel.OS.osx, arch=wheel.Arch.aarch64),
209+
]
210+
shuffle(platforms)
211+
platforms.sort()
212+
want = [
213+
wheel.Platform(os=wheel.OS.linux, arch=None),
214+
wheel.Platform(os=wheel.OS.linux, arch=wheel.Arch.x86_64),
215+
wheel.Platform(os=wheel.OS.osx, arch=None),
216+
wheel.Platform(os=wheel.OS.osx, arch=wheel.Arch.x86_64),
217+
wheel.Platform(os=wheel.OS.osx, arch=wheel.Arch.aarch64),
218+
]
219+
220+
self.assertEqual(want, platforms)
221+
222+
def test_wheel_os_alias(self):
223+
self.assertEqual("osx", str(wheel.OS.osx))
224+
self.assertEqual(str(wheel.OS.darwin), str(wheel.OS.osx))
225+
226+
def test_wheel_arch_alias(self):
227+
self.assertEqual("x86_64", str(wheel.Arch.x86_64))
228+
self.assertEqual(str(wheel.Arch.amd64), str(wheel.Arch.x86_64))
229+
230+
def test_wheel_platform_alias(self):
231+
give = wheel.Platform(
232+
os=wheel.OS.darwin,
233+
arch=wheel.Arch.amd64,
234+
)
235+
alias = wheel.Platform(
236+
os=wheel.OS.osx,
237+
arch=wheel.Arch.x86_64,
238+
)
239+
240+
self.assertEqual("osx_x86_64", str(give))
241+
self.assertEqual(str(alias), str(give))
242+
172243

173244
if __name__ == "__main__":
174245
unittest.main()

0 commit comments

Comments
 (0)