Skip to content

Commit 4ee287c

Browse files
committed
Merge branch 'master' into docker-img-refactor
2 parents 6256eda + 467990e commit 4ee287c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+162
-14893
lines changed

.github/workflows/development.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ jobs:
110110
flake8 --ignore=E203,E722,W503 datajoint \
111111
--count --max-complexity=62 --max-line-length=127 --statistics \
112112
--per-file-ignores='datajoint/diagram.py:C901'
113-
black --required-version '24.2.0' --check -v datajoint tests tests_old
113+
black --required-version '24.2.0' --check -v datajoint tests
114114
codespell:
115115
name: Check for spelling errors
116116
permissions:

CHANGELOG.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
- Fixed - Added encapsulating double quotes to comply with [DOT language](https://graphviz.org/doc/info/lang.html) - PR [#1177](https://github.com/datajoint/datajoint-python/pull/1177)
66
- Added - Datajoint python CLI ([#940](https://github.com/datajoint/datajoint-python/issues/940)) PR [#1095](https://github.com/datajoint/datajoint-python/pull/1095)
77
- Added - Ability to set hidden attributes on a table - PR [#1091](https://github.com/datajoint/datajoint-python/pull/1091)
8-
- Added - Ability to specify a list of keys to popuate - PR [#989](https://github.com/datajoint/datajoint-python/pull/989)
8+
- Added - Ability to specify a list of keys to populate - PR [#989](https://github.com/datajoint/datajoint-python/pull/989)
9+
- Fixed - fixed topological sort [#1057](https://github.com/datajoint/datajoint-python/issues/1057) - PR [#1184](https://github.com/datajoint/datajoint-python/pull/1184)
10+
- Fixed - .parts() not always returning parts [#1103](https://github.com/datajoint/datajoint-python/issues/1103) - PR [#1184](https://github.com/datajoint/datajoint-python/pull/1184)
911

1012
### 0.14.2 -- Aug 19, 2024
1113
- Added - Migrate nosetests to pytest - PR [#1142](https://github.com/datajoint/datajoint-python/pull/1142)

datajoint/dependencies.py

+64-32
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,64 @@
55
from .errors import DataJointError
66

77

8-
def unite_master_parts(lst):
8+
def extract_master(part_table):
99
"""
10-
re-order a list of table names so that part tables immediately follow their master tables without breaking
11-
the topological order.
12-
Without this correction, a simple topological sort may insert other descendants between master and parts.
13-
The input list must be topologically sorted.
14-
:example:
15-
unite_master_parts(
16-
['`s`.`a`', '`s`.`a__q`', '`s`.`b`', '`s`.`c`', '`s`.`c__q`', '`s`.`b__q`', '`s`.`d`', '`s`.`a__r`']) ->
17-
['`s`.`a`', '`s`.`a__q`', '`s`.`a__r`', '`s`.`b`', '`s`.`b__q`', '`s`.`c`', '`s`.`c__q`', '`s`.`d`']
10+
given a part table name, return the name of its master table; return None if the name is not a part table
1811
"""
19-
for i in range(2, len(lst)):
20-
name = lst[i]
21-
match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", name)
22-
if match: # name is a part table
23-
master = match.group("master")
24-
for j in range(i - 1, -1, -1):
25-
if lst[j] == master + "`" or lst[j].startswith(master + "__"):
26-
# move from the ith position to the (j+1)th position
27-
lst[j + 1 : i + 1] = [name] + lst[j + 1 : i]
28-
break
29-
return lst
12+
match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", part_table)
13+
return match["master"] + "`" if match else None
14+
15+
16+
def topo_sort(graph):
17+
"""
18+
topological sort of a dependency graph that keeps part tables together with their masters
19+
:return: list of table names in topological order
20+
"""
21+
22+
graph = nx.DiGraph(graph) # make a copy
23+
24+
# collapse alias nodes
25+
alias_nodes = [node for node in graph if node.isdigit()]
26+
for node in alias_nodes:
27+
try:
28+
direct_edge = (
29+
next(x for x in graph.in_edges(node))[0],
30+
next(x for x in graph.out_edges(node))[1],
31+
)
32+
except StopIteration:
33+
pass # a disconnected alias node
34+
else:
35+
graph.add_edge(*direct_edge)
36+
graph.remove_nodes_from(alias_nodes)
37+
38+
# Add parts' dependencies to their masters' dependencies
39+
# to ensure correct topological ordering of the masters.
40+
for part in graph:
41+
# find the part's master
42+
if (master := extract_master(part)) in graph:
43+
for edge in graph.in_edges(part):
44+
parent = edge[0]
45+
if master not in (parent, extract_master(parent)):
46+
# if parent is neither master nor part of master
47+
graph.add_edge(parent, master)
48+
sorted_nodes = list(nx.topological_sort(graph))
49+
50+
# bring parts up to their masters
51+
pos = len(sorted_nodes) - 1
52+
placed = set()
53+
while pos > 1:
54+
part = sorted_nodes[pos]
55+
if (master := extract_master(part)) not in graph or part in placed:
56+
pos -= 1
57+
else:
58+
placed.add(part)
59+
insert_pos = sorted_nodes.index(master) + 1
60+
if pos > insert_pos:
61+
# move the part to the position immediately after its master
62+
del sorted_nodes[pos]
63+
sorted_nodes.insert(insert_pos, part)
64+
65+
return sorted_nodes
3066

3167

3268
class Dependencies(nx.DiGraph):
@@ -131,6 +167,10 @@ def load(self, force=True):
131167
raise DataJointError("DataJoint can only work with acyclic dependencies")
132168
self._loaded = True
133169

170+
def topo_sort(self):
171+
""":return: list of table names in topological order"""
172+
return topo_sort(self)
173+
134174
def parents(self, table_name, primary=None):
135175
"""
136176
:param table_name: `schema`.`table`
@@ -167,22 +207,14 @@ def descendants(self, full_table_name):
167207
:return: all dependent tables sorted in topological order. Self is included.
168208
"""
169209
self.load(force=False)
170-
nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name))
171-
return unite_master_parts(
172-
[full_table_name] + list(nx.algorithms.dag.topological_sort(nodes))
173-
)
210+
nodes = self.subgraph(nx.descendants(self, full_table_name))
211+
return [full_table_name] + nodes.topo_sort()
174212

175213
def ancestors(self, full_table_name):
176214
"""
177215
:param full_table_name: In form `schema`.`table_name`
178216
:return: all dependent tables sorted in topological order. Self is included.
179217
"""
180218
self.load(force=False)
181-
nodes = self.subgraph(nx.algorithms.dag.ancestors(self, full_table_name))
182-
return list(
183-
reversed(
184-
unite_master_parts(
185-
list(nx.algorithms.dag.topological_sort(nodes)) + [full_table_name]
186-
)
187-
)
188-
)
219+
nodes = self.subgraph(nx.ancestors(self, full_table_name))
220+
return reversed(nodes.topo_sort() + [full_table_name])

datajoint/diagram.py

+14-41
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import networkx as nx
2-
import re
32
import functools
43
import io
54
import logging
65
import inspect
76
from .table import Table
8-
from .dependencies import unite_master_parts
9-
from .user_tables import Manual, Imported, Computed, Lookup, Part
7+
from .dependencies import topo_sort
8+
from .user_tables import Manual, Imported, Computed, Lookup, Part, _get_tier, _AliasNode
109
from .errors import DataJointError
1110
from .table import lookup_class_name
1211

@@ -27,29 +26,6 @@
2726

2827

2928
logger = logging.getLogger(__name__.split(".")[0])
30-
user_table_classes = (Manual, Lookup, Computed, Imported, Part)
31-
32-
33-
class _AliasNode:
34-
"""
35-
special class to indicate aliased foreign keys
36-
"""
37-
38-
pass
39-
40-
41-
def _get_tier(table_name):
42-
if not table_name.startswith("`"):
43-
return _AliasNode
44-
else:
45-
try:
46-
return next(
47-
tier
48-
for tier in user_table_classes
49-
if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2])
50-
)
51-
except StopIteration:
52-
return None
5329

5430

5531
if not diagram_active:
@@ -59,8 +35,7 @@ class Diagram:
5935
Entity relationship diagram, currently disabled due to the lack of required packages: matplotlib and pygraphviz.
6036
6137
To enable Diagram feature, please install both matplotlib and pygraphviz. For instructions on how to install
62-
these two packages, refer to http://docs.datajoint.io/setup/Install-and-connect.html#python and
63-
http://tutorials.datajoint.io/setting-up/datajoint-python.html
38+
these two packages, refer to https://datajoint.com/docs/core/datajoint-python/0.14/client/install/
6439
"""
6540

6641
def __init__(self, *args, **kwargs):
@@ -72,19 +47,22 @@ def __init__(self, *args, **kwargs):
7247

7348
class Diagram(nx.DiGraph):
7449
"""
75-
Entity relationship diagram.
50+
Schema diagram showing tables and the foreign keys between them in the form of a directed
51+
acyclic graph (DAG). The diagram is derived from the connection.dependencies object.
7652
7753
Usage:
7854
7955
>>> diag = Diagram(source)
8056
81-
source can be a base table object, a base table class, a schema, or a module that has a schema.
57+
source can be a table object, a table class, a schema, or a module that has a schema.
8258
8359
>>> diag.draw()
8460
8561
draws the diagram using pyplot
8662
8763
diag1 + diag2 - combines the two diagrams.
64+
diag1 - diag2 - difference between diagrams
65+
diag1 * diag2 - intersection of diagrams
8866
diag + n - expands n levels of successors
8967
diag - n - expands n levels of predecessors
9068
Thus dj.Diagram(schema.Table)+1-1 defines the diagram of immediate ancestors and descendants of schema.Table
@@ -94,6 +72,7 @@ class Diagram(nx.DiGraph):
9472
"""
9573

9674
def __init__(self, source, context=None):
75+
9776
if isinstance(source, Diagram):
9877
# copy constructor
9978
self.nodes_to_show = set(source.nodes_to_show)
@@ -154,7 +133,7 @@ def from_sequence(cls, sequence):
154133

155134
def add_parts(self):
156135
"""
157-
Adds to the diagram the part tables of tables already included in the diagram
136+
Adds to the diagram the part tables of all master tables already in the diagram
158137
:return:
159138
"""
160139

@@ -179,16 +158,6 @@ def is_part(part, master):
179158
)
180159
return self
181160

182-
def topological_sort(self):
183-
""":return: list of nodes in topological order"""
184-
return unite_master_parts(
185-
list(
186-
nx.algorithms.dag.topological_sort(
187-
nx.DiGraph(self).subgraph(self.nodes_to_show)
188-
)
189-
)
190-
)
191-
192161
def __add__(self, arg):
193162
"""
194163
:param arg: either another Diagram or a positive integer.
@@ -256,6 +225,10 @@ def __mul__(self, arg):
256225
self.nodes_to_show.intersection_update(arg.nodes_to_show)
257226
return self
258227

228+
def topo_sort(self):
229+
"""return nodes in topological order, keeping part tables together with their masters"""
230+
return topo_sort(self)
231+
259232
def _make_graph(self):
260233
"""
261234
Make the self.graph - a graph object ready for drawing

datajoint/external.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .table import Table, FreeTable
99
from .heading import Heading
1010
from .declare import EXTERNAL_TABLE_ROOT
11-
from . import s3
11+
from . import s3, errors
1212
from .utils import safe_write, safe_copy
1313

1414
logger = logging.getLogger(__name__.split(".")[0])
@@ -141,7 +141,12 @@ def _download_buffer(self, external_path):
141141
if self.spec["protocol"] == "s3":
142142
return self.s3.get(external_path)
143143
if self.spec["protocol"] == "file":
144-
return Path(external_path).read_bytes()
144+
try:
145+
return Path(external_path).read_bytes()
146+
except FileNotFoundError:
147+
raise errors.MissingExternalFile(
148+
f"Missing external file {external_path}"
149+
) from None
145150
assert False
146151

147152
def _remove_external_file(self, external_path):

datajoint/schemas.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,16 @@
22
import logging
33
import inspect
44
import re
5-
import itertools
65
import collections
6+
import itertools
77
from .connection import conn
8-
from .diagram import Diagram, _get_tier
98
from .settings import config
109
from .errors import DataJointError, AccessError
1110
from .jobs import JobTable
1211
from .external import ExternalMapping
1312
from .heading import Heading
1413
from .utils import user_choice, to_camel_case
15-
from .user_tables import Part, Computed, Imported, Manual, Lookup
14+
from .user_tables import Part, Computed, Imported, Manual, Lookup, _get_tier
1615
from .table import lookup_class_name, Log, FreeTable
1716
import types
1817

@@ -413,6 +412,7 @@ def save(self, python_filename=None):
413412
414413
:return: a string containing the body of a complete Python module defining this schema.
415414
"""
415+
self.connection.dependencies.load()
416416
self._assert_exists()
417417
module_count = itertools.count()
418418
# add virtual modules for referenced modules with names vmod0, vmod1, ...
@@ -451,10 +451,8 @@ def replace(s):
451451
).replace("\n", "\n " + indent),
452452
)
453453

454-
diagram = Diagram(self)
455-
body = "\n\n".join(
456-
make_class_definition(table) for table in diagram.topological_sort()
457-
)
454+
tables = self.connection.dependencies.topo_sort()
455+
body = "\n\n".join(make_class_definition(table) for table in tables)
458456
python_code = "\n\n".join(
459457
(
460458
'"""This module was auto-generated by datajoint from an existing schema"""',
@@ -480,11 +478,12 @@ def list_tables(self):
480478
481479
:return: A list of table names from the database schema.
482480
"""
481+
self.connection.dependencies.load()
483482
return [
484483
t
485484
for d, t in (
486485
full_t.replace("`", "").split(".")
487-
for full_t in Diagram(self).topological_sort()
486+
for full_t in self.connection.dependencies.topo_sort()
488487
)
489488
if d == self.database
490489
]
@@ -533,7 +532,6 @@ def __init__(
533532

534533
def list_schemas(connection=None):
535534
"""
536-
537535
:param connection: a dj.Connection object
538536
:return: list of all accessible schemas on the server
539537
"""

datajoint/settings.py

+7
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,11 @@ def __setitem__(self, key, value):
246246
self._conf[key] = value
247247
else:
248248
raise DataJointError("Validator for {0:s} did not pass".format(key))
249+
valid_logging_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
250+
if key == "loglevel":
251+
if value not in valid_logging_levels:
252+
raise ValueError(f"{'value'} is not a valid logging value")
253+
logger.setLevel(value)
249254

250255

251256
# Load configuration from file
@@ -270,6 +275,7 @@ def __setitem__(self, key, value):
270275
"database.password",
271276
"external.aws_access_key_id",
272277
"external.aws_secret_access_key",
278+
"loglevel",
273279
),
274280
map(
275281
os.getenv,
@@ -279,6 +285,7 @@ def __setitem__(self, key, value):
279285
"DJ_PASS",
280286
"DJ_AWS_ACCESS_KEY_ID",
281287
"DJ_AWS_SECRET_ACCESS_KEY",
288+
"DJ_LOG_LEVEL",
282289
),
283290
),
284291
)

0 commit comments

Comments
 (0)