Skip to content

Commit 285605d

Browse files
authored
Merge pull request #1184 from dimitri-yatsenko/master
Fix #1103, #1057
2 parents bae66d2 + 477c326 commit 285605d

9 files changed

+126
-138
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
- Added - Datajoint python CLI ([#940](https://github.com/datajoint/datajoint-python/issues/940)) PR [#1095](https://github.com/datajoint/datajoint-python/pull/1095)
77
- Added - Ability to set hidden attributes on a table - PR [#1091](https://github.com/datajoint/datajoint-python/pull/1091)
88
- Added - Ability to specify a list of keys to populate - PR [#989](https://github.com/datajoint/datajoint-python/pull/989)
9+
- Fixed - fixed topological sort [#1057](https://github.com/datajoint/datajoint-python/issues/1057) - PR [#1184](https://github.com/datajoint/datajoint-python/pull/1184)
10+
- Fixed - .parts() not always returning parts [#1103](https://github.com/datajoint/datajoint-python/issues/1103) - PR [#1184](https://github.com/datajoint/datajoint-python/pull/1184)
911

1012
### 0.14.2 -- Aug 19, 2024
1113
- Added - Migrate nosetests to pytest - PR [#1142](https://github.com/datajoint/datajoint-python/pull/1142)

datajoint/dependencies.py

+64-32
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,64 @@
55
from .errors import DataJointError
66

77

8-
def unite_master_parts(lst):
8+
def extract_master(part_table):
99
"""
10-
re-order a list of table names so that part tables immediately follow their master tables without breaking
11-
the topological order.
12-
Without this correction, a simple topological sort may insert other descendants between master and parts.
13-
The input list must be topologically sorted.
14-
:example:
15-
unite_master_parts(
16-
['`s`.`a`', '`s`.`a__q`', '`s`.`b`', '`s`.`c`', '`s`.`c__q`', '`s`.`b__q`', '`s`.`d`', '`s`.`a__r`']) ->
17-
['`s`.`a`', '`s`.`a__q`', '`s`.`a__r`', '`s`.`b`', '`s`.`b__q`', '`s`.`c`', '`s`.`c__q`', '`s`.`d`']
10+
given a part table name, return the name of its master table; None if not a part table
1811
"""
19-
for i in range(2, len(lst)):
20-
name = lst[i]
21-
match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", name)
22-
if match: # name is a part table
23-
master = match.group("master")
24-
for j in range(i - 1, -1, -1):
25-
if lst[j] == master + "`" or lst[j].startswith(master + "__"):
26-
# move from the ith position to the (j+1)th position
27-
lst[j + 1 : i + 1] = [name] + lst[j + 1 : i]
28-
break
29-
return lst
12+
match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", part_table)
13+
return match["master"] + "`" if match else None
14+
15+
16+
def topo_sort(graph):
17+
"""
18+
topological sort of a dependency graph that keeps part tables together with their masters
19+
:return: list of table names in topological order
20+
"""
21+
22+
graph = nx.DiGraph(graph) # make a copy
23+
24+
# collapse alias nodes
25+
alias_nodes = [node for node in graph if node.isdigit()]
26+
for node in alias_nodes:
27+
try:
28+
direct_edge = (
29+
next(x for x in graph.in_edges(node))[0],
30+
next(x for x in graph.out_edges(node))[1],
31+
)
32+
except StopIteration:
33+
pass # a disconnected alias node
34+
else:
35+
graph.add_edge(*direct_edge)
36+
graph.remove_nodes_from(alias_nodes)
37+
38+
# Add parts' dependencies to their masters' dependencies
39+
# to ensure correct topological ordering of the masters.
40+
for part in graph:
41+
# find the part's master
42+
if (master := extract_master(part)) in graph:
43+
for edge in graph.in_edges(part):
44+
parent = edge[0]
45+
if master not in (parent, extract_master(parent)):
46+
# if parent is neither master nor part of master
47+
graph.add_edge(parent, master)
48+
sorted_nodes = list(nx.topological_sort(graph))
49+
50+
# bring parts up to their masters
51+
pos = len(sorted_nodes) - 1
52+
placed = set()
53+
while pos > 1:
54+
part = sorted_nodes[pos]
55+
if (master := extract_master(part)) not in graph or part in placed:
56+
pos -= 1
57+
else:
58+
placed.add(part)
59+
j = sorted_nodes.index(master)
60+
if pos > j + 1:
61+
# move the part to its master
62+
del sorted_nodes[pos]
63+
sorted_nodes.insert(j + 1, part)
64+
65+
return sorted_nodes
3066

3167

3268
class Dependencies(nx.DiGraph):
@@ -131,6 +167,10 @@ def load(self, force=True):
131167
raise DataJointError("DataJoint can only work with acyclic dependencies")
132168
self._loaded = True
133169

170+
def topo_sort(self):
171+
""":return: list of table names in topological order"""
172+
return topo_sort(self)
173+
134174
def parents(self, table_name, primary=None):
135175
"""
136176
:param table_name: `schema`.`table`
@@ -167,22 +207,14 @@ def descendants(self, full_table_name):
167207
:return: all dependent tables sorted in topological order. Self is included.
168208
"""
169209
self.load(force=False)
170-
nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name))
171-
return unite_master_parts(
172-
[full_table_name] + list(nx.algorithms.dag.topological_sort(nodes))
173-
)
210+
nodes = self.subgraph(nx.descendants(self, full_table_name))
211+
return [full_table_name] + nodes.topo_sort()
174212

175213
def ancestors(self, full_table_name):
176214
"""
177215
:param full_table_name: In form `schema`.`table_name`
178216
:return: all ancestor tables sorted in reverse topological order. Self is included.
179217
"""
180218
self.load(force=False)
181-
nodes = self.subgraph(nx.algorithms.dag.ancestors(self, full_table_name))
182-
return list(
183-
reversed(
184-
unite_master_parts(
185-
list(nx.algorithms.dag.topological_sort(nodes)) + [full_table_name]
186-
)
187-
)
188-
)
219+
nodes = self.subgraph(nx.ancestors(self, full_table_name))
220+
return reversed(nodes.topo_sort() + [full_table_name])

datajoint/diagram.py

+14-41
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import networkx as nx
2-
import re
32
import functools
43
import io
54
import logging
65
import inspect
76
from .table import Table
8-
from .dependencies import unite_master_parts
9-
from .user_tables import Manual, Imported, Computed, Lookup, Part
7+
from .dependencies import topo_sort
8+
from .user_tables import Manual, Imported, Computed, Lookup, Part, _get_tier, _AliasNode
109
from .errors import DataJointError
1110
from .table import lookup_class_name
1211

@@ -27,29 +26,6 @@
2726

2827

2928
logger = logging.getLogger(__name__.split(".")[0])
30-
user_table_classes = (Manual, Lookup, Computed, Imported, Part)
31-
32-
33-
class _AliasNode:
34-
"""
35-
special class to indicate aliased foreign keys
36-
"""
37-
38-
pass
39-
40-
41-
def _get_tier(table_name):
42-
if not table_name.startswith("`"):
43-
return _AliasNode
44-
else:
45-
try:
46-
return next(
47-
tier
48-
for tier in user_table_classes
49-
if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2])
50-
)
51-
except StopIteration:
52-
return None
5329

5430

5531
if not diagram_active:
@@ -59,8 +35,7 @@ class Diagram:
5935
Entity relationship diagram, currently disabled due to the lack of required packages: matplotlib and pygraphviz.
6036
6137
To enable Diagram feature, please install both matplotlib and pygraphviz. For instructions on how to install
62-
these two packages, refer to http://docs.datajoint.io/setup/Install-and-connect.html#python and
63-
http://tutorials.datajoint.io/setting-up/datajoint-python.html
38+
these two packages, refer to https://datajoint.com/docs/core/datajoint-python/0.14/client/install/
6439
"""
6540

6641
def __init__(self, *args, **kwargs):
@@ -72,19 +47,22 @@ def __init__(self, *args, **kwargs):
7247

7348
class Diagram(nx.DiGraph):
7449
"""
75-
Entity relationship diagram.
50+
Schema diagram showing tables and the foreign keys between them in the form of a directed
51+
acyclic graph (DAG). The diagram is derived from the connection.dependencies object.
7652
7753
Usage:
7854
7955
>>> diag = Diagram(source)
8056
81-
source can be a base table object, a base table class, a schema, or a module that has a schema.
57+
source can be a table object, a table class, a schema, or a module that has a schema.
8258
8359
>>> diag.draw()
8460
8561
draws the diagram using pyplot
8662
8763
diag1 + diag2 - combines the two diagrams.
64+
diag1 - diag2 - difference between diagrams
65+
diag1 * diag2 - intersection of diagrams
8866
diag + n - expands n levels of successors
8967
diag - n - expands n levels of predecessors
9068
Thus dj.Diagram(schema.Table)+1-1 defines the diagram of immediate ancestors and descendants of schema.Table
@@ -94,6 +72,7 @@ class Diagram(nx.DiGraph):
9472
"""
9573

9674
def __init__(self, source, context=None):
75+
9776
if isinstance(source, Diagram):
9877
# copy constructor
9978
self.nodes_to_show = set(source.nodes_to_show)
@@ -154,7 +133,7 @@ def from_sequence(cls, sequence):
154133

155134
def add_parts(self):
156135
"""
157-
Adds to the diagram the part tables of tables already included in the diagram
136+
Adds to the diagram the part tables of all master tables already in the diagram
158137
:return:
159138
"""
160139

@@ -179,16 +158,6 @@ def is_part(part, master):
179158
)
180159
return self
181160

182-
def topological_sort(self):
183-
""":return: list of nodes in topological order"""
184-
return unite_master_parts(
185-
list(
186-
nx.algorithms.dag.topological_sort(
187-
nx.DiGraph(self).subgraph(self.nodes_to_show)
188-
)
189-
)
190-
)
191-
192161
def __add__(self, arg):
193162
"""
194163
:param arg: either another Diagram or a positive integer.
@@ -256,6 +225,10 @@ def __mul__(self, arg):
256225
self.nodes_to_show.intersection_update(arg.nodes_to_show)
257226
return self
258227

228+
def topo_sort(self):
229+
"""return nodes in topological order, keeping part tables together with their masters"""
230+
return topo_sort(self)
231+
259232
def _make_graph(self):
260233
"""
261234
Make the self.graph - a graph object ready for drawing

datajoint/schemas.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,16 @@
22
import logging
33
import inspect
44
import re
5-
import itertools
65
import collections
6+
import itertools
77
from .connection import conn
8-
from .diagram import Diagram, _get_tier
98
from .settings import config
109
from .errors import DataJointError, AccessError
1110
from .jobs import JobTable
1211
from .external import ExternalMapping
1312
from .heading import Heading
1413
from .utils import user_choice, to_camel_case
15-
from .user_tables import Part, Computed, Imported, Manual, Lookup
14+
from .user_tables import Part, Computed, Imported, Manual, Lookup, _get_tier
1615
from .table import lookup_class_name, Log, FreeTable
1716
import types
1817

@@ -413,6 +412,7 @@ def save(self, python_filename=None):
413412
414413
:return: a string containing the body of a complete Python module defining this schema.
415414
"""
415+
self.connection.dependencies.load()
416416
self._assert_exists()
417417
module_count = itertools.count()
418418
# add virtual modules for referenced modules with names vmod0, vmod1, ...
@@ -451,10 +451,8 @@ def replace(s):
451451
).replace("\n", "\n " + indent),
452452
)
453453

454-
diagram = Diagram(self)
455-
body = "\n\n".join(
456-
make_class_definition(table) for table in diagram.topological_sort()
457-
)
454+
tables = self.connection.dependencies.topo_sort()
455+
body = "\n\n".join(make_class_definition(table) for table in tables)
458456
python_code = "\n\n".join(
459457
(
460458
'"""This module was auto-generated by datajoint from an existing schema"""',
@@ -480,11 +478,12 @@ def list_tables(self):
480478
481479
:return: A list of table names from the database schema.
482480
"""
481+
self.connection.dependencies.load()
483482
return [
484483
t
485484
for d, t in (
486485
full_t.replace("`", "").split(".")
487-
for full_t in Diagram(self).topological_sort()
486+
for full_t in self.connection.dependencies.topo_sort()
488487
)
489488
if d == self.database
490489
]
@@ -533,7 +532,6 @@ def __init__(
533532

534533
def list_schemas(connection=None):
535534
"""
536-
537535
:param connection: a dj.Connection object
538536
:return: list of all accessible schemas on the server
539537
"""

datajoint/table.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,6 @@ def parents(self, primary=None, as_objects=False, foreign_key_info=False):
196196

197197
def children(self, primary=None, as_objects=False, foreign_key_info=False):
198198
"""
199-
200199
:param primary: if None, then all children are returned. If True, then only foreign keys composed of
201200
primary key attributes are considered. If False, return foreign keys including at least one
202201
secondary attribute.
@@ -218,7 +217,6 @@ def children(self, primary=None, as_objects=False, foreign_key_info=False):
218217

219218
def descendants(self, as_objects=False):
220219
"""
221-
222220
:param as_objects: False - a list of table names; True - a list of table objects.
223221
:return: list of tables descendants in topological order.
224222
"""
@@ -230,7 +228,6 @@ def descendants(self, as_objects=False):
230228

231229
def ancestors(self, as_objects=False):
232230
"""
233-
234231
:param as_objects: False - a list of table names; True - a list of table objects.
235232
:return: list of tables ancestors in topological order.
236233
"""
@@ -246,6 +243,7 @@ def parts(self, as_objects=False):
246243
247244
:param as_objects: if False (default), the output is a dict describing the foreign keys. If True, return table objects.
248245
"""
246+
self.connection.dependencies.load(force=False)
249247
nodes = [
250248
node
251249
for node in self.connection.dependencies.nodes
@@ -427,7 +425,8 @@ def insert(
427425
self.connection.query(query)
428426
return
429427

430-
field_list = [] # collects the field list from first row (passed by reference)
428+
# collects the field list from first row (passed by reference)
429+
field_list = []
431430
rows = list(
432431
self.__make_row_to_insert(row, field_list, ignore_extra_fields)
433432
for row in rows
@@ -520,7 +519,8 @@ def cascade(table):
520519
delete_count = table.delete_quick(get_count=True)
521520
except IntegrityError as error:
522521
match = foreign_key_error_regexp.match(error.args[0]).groupdict()
523-
if "`.`" not in match["child"]: # if schema name missing, use table
522+
# if schema name missing, use table
523+
if "`.`" not in match["child"]:
524524
match["child"] = "{}.{}".format(
525525
table.full_table_name.split(".")[0], match["child"]
526526
)
@@ -964,7 +964,8 @@ def lookup_class_name(name, context, depth=3):
964964
while nodes:
965965
node = nodes.pop(0)
966966
for member_name, member in node["context"].items():
967-
if not member_name.startswith("_"): # skip IPython's implicit variables
967+
# skip IPython's implicit variables
968+
if not member_name.startswith("_"):
968969
if inspect.isclass(member) and issubclass(member, Table):
969970
if member.full_table_name == name: # found it!
970971
return ".".join([node["context_name"], member_name]).lstrip(".")

0 commit comments

Comments
 (0)