|
5 | 5 | from .errors import DataJointError
|
6 | 6 |
|
7 | 7 |
|
8 |
| -def unite_master_parts(lst): |
| 8 | +def extract_master(part_table): |
9 | 9 | """
|
10 |
| - re-order a list of table names so that part tables immediately follow their master tables without breaking |
11 |
| - the topological order. |
12 |
| - Without this correction, a simple topological sort may insert other descendants between master and parts. |
13 |
| - The input list must be topologically sorted. |
14 |
| - :example: |
15 |
| - unite_master_parts( |
16 |
| - ['`s`.`a`', '`s`.`a__q`', '`s`.`b`', '`s`.`c`', '`s`.`c__q`', '`s`.`b__q`', '`s`.`d`', '`s`.`a__r`']) -> |
17 |
| - ['`s`.`a`', '`s`.`a__q`', '`s`.`a__r`', '`s`.`b`', '`s`.`b__q`', '`s`.`c`', '`s`.`c__q`', '`s`.`d`'] |
| 10 | + given a part table name, return master part. None if not a part table |
18 | 11 | """
|
19 |
| - for i in range(2, len(lst)): |
20 |
| - name = lst[i] |
21 |
| - match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", name) |
22 |
| - if match: # name is a part table |
23 |
| - master = match.group("master") |
24 |
| - for j in range(i - 1, -1, -1): |
25 |
| - if lst[j] == master + "`" or lst[j].startswith(master + "__"): |
26 |
| - # move from the ith position to the (j+1)th position |
27 |
| - lst[j + 1 : i + 1] = [name] + lst[j + 1 : i] |
28 |
| - break |
29 |
| - return lst |
| 12 | + match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", part_table) |
| 13 | + return match["master"] + "`" if match else None |
| 14 | + |
| 15 | + |
| 16 | +def topo_sort(graph): |
| 17 | + """ |
| 18 | + topological sort of a dependency graph that keeps part tables together with their masters |
| 19 | + :return: list of table names in topological order |
| 20 | + """ |
| 21 | + |
| 22 | + graph = nx.DiGraph(graph) # make a copy |
| 23 | + |
| 24 | + # collapse alias nodes |
| 25 | + alias_nodes = [node for node in graph if node.isdigit()] |
| 26 | + for node in alias_nodes: |
| 27 | + try: |
| 28 | + direct_edge = ( |
| 29 | + next(x for x in graph.in_edges(node))[0], |
| 30 | + next(x for x in graph.out_edges(node))[1], |
| 31 | + ) |
| 32 | + except StopIteration: |
| 33 | + pass # a disconnected alias node |
| 34 | + else: |
| 35 | + graph.add_edge(*direct_edge) |
| 36 | + graph.remove_nodes_from(alias_nodes) |
| 37 | + |
| 38 | + # Add parts' dependencies to their masters' dependencies |
| 39 | + # to ensure correct topological ordering of the masters. |
| 40 | + for part in graph: |
| 41 | + # find the part's master |
| 42 | + if (master := extract_master(part)) in graph: |
| 43 | + for edge in graph.in_edges(part): |
| 44 | + parent = edge[0] |
| 45 | + if master not in (parent, extract_master(parent)): |
| 46 | + # if parent is neither master nor part of master |
| 47 | + graph.add_edge(parent, master) |
| 48 | + sorted_nodes = list(nx.topological_sort(graph)) |
| 49 | + |
| 50 | + # bring parts up to their masters |
| 51 | + pos = len(sorted_nodes) - 1 |
| 52 | + placed = set() |
| 53 | + while pos > 1: |
| 54 | + part = sorted_nodes[pos] |
| 55 | + if (master := extract_master(part)) not in graph or part in placed: |
| 56 | + pos -= 1 |
| 57 | + else: |
| 58 | + placed.add(part) |
| 59 | + j = sorted_nodes.index(master) |
| 60 | + if pos > j + 1: |
| 61 | + # move the part to its master |
| 62 | + del sorted_nodes[pos] |
| 63 | + sorted_nodes.insert(j + 1, part) |
| 64 | + |
| 65 | + return sorted_nodes |
30 | 66 |
|
31 | 67 |
|
32 | 68 | class Dependencies(nx.DiGraph):
|
@@ -131,6 +167,10 @@ def load(self, force=True):
|
131 | 167 | raise DataJointError("DataJoint can only work with acyclic dependencies")
|
132 | 168 | self._loaded = True
|
133 | 169 |
|
| 170 | + def topo_sort(self): |
| 171 | + """:return: list of tables names in topological order""" |
| 172 | + return topo_sort(self) |
| 173 | + |
134 | 174 | def parents(self, table_name, primary=None):
|
135 | 175 | """
|
136 | 176 | :param table_name: `schema`.`table`
|
@@ -167,22 +207,14 @@ def descendants(self, full_table_name):
|
167 | 207 | :return: all dependent tables sorted in topological order. Self is included.
|
168 | 208 | """
|
169 | 209 | self.load(force=False)
|
170 |
| - nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name)) |
171 |
| - return unite_master_parts( |
172 |
| - [full_table_name] + list(nx.algorithms.dag.topological_sort(nodes)) |
173 |
| - ) |
| 210 | + nodes = self.subgraph(nx.descendants(self, full_table_name)) |
| 211 | + return [full_table_name] + nodes.topo_sort() |
174 | 212 |
|
175 | 213 | def ancestors(self, full_table_name):
|
176 | 214 | """
|
177 | 215 | :param full_table_name: In form `schema`.`table_name`
|
178 | 216 | :return: all dependent tables sorted in topological order. Self is included.
|
179 | 217 | """
|
180 | 218 | self.load(force=False)
|
181 |
| - nodes = self.subgraph(nx.algorithms.dag.ancestors(self, full_table_name)) |
182 |
| - return list( |
183 |
| - reversed( |
184 |
| - unite_master_parts( |
185 |
| - list(nx.algorithms.dag.topological_sort(nodes)) + [full_table_name] |
186 |
| - ) |
187 |
| - ) |
188 |
| - ) |
| 219 | + nodes = self.subgraph(nx.ancestors(self, full_table_name)) |
| 220 | + return reversed(nodes.topo_sort() + [full_table_name]) |
0 commit comments