|
10 | 10 | from concurrent.futures import as_completed
|
11 | 11 | from concurrent.futures import ThreadPoolExecutor
|
12 | 12 | from dataclasses import dataclass
|
| 13 | +from enum import Enum |
13 | 14 | from typing import TextIO
|
14 | 15 | from typing import TypeGuard
|
15 | 16 |
|
|
53 | 54 | ]
|
54 | 55 |
|
55 | 56 |
|
| 57 | +class DependencyType(str, Enum): |
| 58 | + SERVICE = "service" |
| 59 | + COMPOSE = "compose" |
| 60 | + |
| 61 | + |
| 62 | +@dataclass(frozen=True, eq=True) |
| 63 | +class DependencyNode: |
| 64 | + name: str |
| 65 | + dependency_type: DependencyType |
| 66 | + |
| 67 | + def __str__(self) -> str: |
| 68 | + return self.name |
| 69 | + |
| 70 | + |
56 | 71 | class DependencyGraph:
|
57 | 72 | def __init__(self) -> None:
|
58 |
| - self.graph: dict[str, set[str]] = dict() |
| 73 | + self.graph: dict[DependencyNode, set[DependencyNode]] = dict() |
59 | 74 |
|
60 |
| - def add_dependency(self, service_name: str) -> None: |
61 |
| - if service_name not in self.graph: |
62 |
| - self.graph[service_name] = set() |
| 75 | + def add_node(self, node: DependencyNode) -> None: |
| 76 | + if node not in self.graph: |
| 77 | + self.graph[node] = set() |
63 | 78 |
|
64 |
| - def add_edge(self, service_name: str, dependency_name: str) -> None: |
65 |
| - # TODO: We should rename services that depend on themselves |
66 |
| - if service_name == dependency_name: |
67 |
| - return |
68 |
| - if service_name not in self.graph: |
69 |
| - self.add_dependency(service_name) |
70 |
| - if dependency_name not in self.graph: |
71 |
| - self.add_dependency(dependency_name) |
| 79 | + def add_edge(self, from_node: DependencyNode, to_node: DependencyNode) -> None: |
| 80 | + if from_node == to_node: |
| 81 | + # TODO: Add a better exception |
| 82 | + raise ValueError("Cannot add an edge from a node to itself") |
| 83 | + if from_node not in self.graph: |
| 84 | + self.add_node(from_node) |
| 85 | + if to_node not in self.graph: |
| 86 | + self.add_node(to_node) |
72 | 87 |
|
73 | 88 | # TODO: Should we check for cycles here?
|
74 | 89 |
|
75 |
| - self.graph[service_name].add(dependency_name) |
| 90 | + self.graph[from_node].add(to_node) |
76 | 91 |
|
77 |
| - def topological_sort(self) -> list[str]: |
| 92 | + def topological_sort(self) -> list[DependencyNode]: |
78 | 93 | in_degree = {service_name: 0 for service_name in self.graph}
|
79 | 94 |
|
80 |
| - for service_name in self.graph.keys(): |
81 |
| - for dependency in self.graph[service_name]: |
82 |
| - in_degree[dependency] += 1 |
| 95 | + for service_node in self.graph.keys(): |
| 96 | + for dependency_node in self.graph[service_node]: |
| 97 | + in_degree[dependency_node] += 1 |
83 | 98 |
|
84 | 99 | queue = deque(
|
85 | 100 | [
|
86 |
| - service_name |
87 |
| - for service_name in self.graph |
88 |
| - if in_degree[service_name] == 0 |
| 101 | + dependency_node |
| 102 | + for dependency_node in self.graph |
| 103 | + if in_degree[dependency_node] == 0 |
89 | 104 | ]
|
90 | 105 | )
|
91 | 106 | topological_order = list()
|
92 | 107 |
|
93 | 108 | while queue:
|
94 |
| - service_name = queue.popleft() |
95 |
| - topological_order.append(service_name) |
| 109 | + service_node = queue.popleft() |
| 110 | + topological_order.append(service_node) |
96 | 111 |
|
97 |
| - for dependency in self.graph[service_name]: |
98 |
| - in_degree[dependency] -= 1 |
99 |
| - if in_degree[dependency] == 0: |
100 |
| - queue.append(dependency) |
| 112 | + for dependency_node in self.graph[service_node]: |
| 113 | + in_degree[dependency_node] -= 1 |
| 114 | + if in_degree[dependency_node] == 0: |
| 115 | + queue.append(dependency_node) |
101 | 116 |
|
102 | 117 | if len(topological_order) != len(self.graph):
|
103 | 118 | # TODO: Add a better exception
|
104 | 119 | raise ValueError("Cycle detected in the dependency graph")
|
105 | 120 |
|
106 | 121 | return topological_order
|
107 | 122 |
|
108 |
| - def get_starting_order(self) -> list[str]: |
| 123 | + def get_starting_order(self) -> list[DependencyNode]: |
109 | 124 | return list(reversed(self.topological_sort()))
|
110 | 125 |
|
111 | 126 |
|
@@ -729,7 +744,18 @@ def _construct_dependency_graph(
|
729 | 744 | # Skip the dependency if it's not in the modes (since it may not be installed and we don't care about it)
|
730 | 745 | if dependency_name not in service_mode_dependencies:
|
731 | 746 | continue
|
732 |
| - dependency_graph.add_edge(service_config.service_name, dependency_name) |
| 747 | + dependency_graph.add_edge( |
| 748 | + DependencyNode( |
| 749 | + name=service_config.service_name, |
| 750 | + dependency_type=DependencyType.SERVICE, |
| 751 | + ), |
| 752 | + DependencyNode( |
| 753 | + name=dependency_name, |
| 754 | + dependency_type=DependencyType.SERVICE |
| 755 | + if _has_remote_config(dependency.remote) |
| 756 | + else DependencyType.COMPOSE, |
| 757 | + ), |
| 758 | + ) |
733 | 759 | if _has_remote_config(dependency.remote):
|
734 | 760 | dependency_config = get_remote_dependency_config(dependency.remote)
|
735 | 761 | _construct_dependency_graph(dependency_config, [dependency.remote.mode])
|
|
0 commit comments