Skip to content

Commit 9205e91

Browse files
committed
Fill-out dataclass-related attr resolution
Fixed issue where mixin attribute rules were not taking effect correctly for attributes pulled from dataclasses using the approach added in sqlalchemy#5745. Fixes: sqlalchemy#5876 Change-Id: I45099a42de1d9611791e72250fe0edc69bed684c
1 parent 57db20a commit 9205e91

File tree

5 files changed

+358
-28
lines changed

5 files changed

+358
-28
lines changed

lib/sqlalchemy/orm/decl_base.py

Lines changed: 110 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,94 @@ def after_configured():
325325
def before_configured():
326326
self.cls.__declare_first__()
327327

328+
def _cls_attr_override_checker(self, cls):
329+
"""Produce a function that checks if a class has overridden an
330+
attribute, taking SQLAlchemy-enabled dataclass fields into account.
331+
332+
"""
333+
sa_dataclass_metadata_key = _get_immediate_cls_attr(
334+
cls, "__sa_dataclass_metadata_key__", None
335+
)
336+
337+
if sa_dataclass_metadata_key is None:
338+
339+
def attribute_is_overridden(key, obj):
340+
return getattr(cls, key) is not obj
341+
342+
else:
343+
344+
all_datacls_fields = {
345+
f.name: f.metadata[sa_dataclass_metadata_key]
346+
for f in util.dataclass_fields(cls)
347+
if sa_dataclass_metadata_key in f.metadata
348+
}
349+
local_datacls_fields = {
350+
f.name: f.metadata[sa_dataclass_metadata_key]
351+
for f in util.local_dataclass_fields(cls)
352+
if sa_dataclass_metadata_key in f.metadata
353+
}
354+
355+
absent = object()
356+
357+
def attribute_is_overridden(key, obj):
358+
# this function likely has some failure modes still if
359+
# someone is doing a deep mixing of the same attribute
360+
# name as plain Python attribute vs. dataclass field.
361+
362+
ret = local_datacls_fields.get(key, absent)
363+
364+
if ret is obj:
365+
return False
366+
elif ret is not absent:
367+
return True
368+
369+
ret = getattr(cls, key, obj)
370+
371+
if ret is obj:
372+
return False
373+
elif ret is not absent:
374+
return True
375+
376+
ret = all_datacls_fields.get(key, absent)
377+
378+
if ret is obj:
379+
return False
380+
elif ret is not absent:
381+
return True
382+
383+
# can't find another attribute
384+
return False
385+
386+
return attribute_is_overridden
387+
388+
def _cls_attr_resolver(self, cls):
389+
"""produce a function to iterate the "attributes" of a class,
390+
adjusting for SQLAlchemy fields embedded in dataclass fields.
391+
392+
"""
393+
sa_dataclass_metadata_key = _get_immediate_cls_attr(
394+
cls, "__sa_dataclass_metadata_key__", None
395+
)
396+
397+
if sa_dataclass_metadata_key is None:
398+
399+
def local_attributes_for_class():
400+
for name, obj in vars(cls).items():
401+
yield name, obj
402+
403+
else:
404+
405+
def local_attributes_for_class():
406+
for name, obj in vars(cls).items():
407+
yield name, obj
408+
for field in util.local_dataclass_fields(cls):
409+
if sa_dataclass_metadata_key in field.metadata:
410+
yield field.name, field.metadata[
411+
sa_dataclass_metadata_key
412+
]
413+
414+
return local_attributes_for_class
415+
328416
def _scan_attributes(self):
329417
cls = self.cls
330418
dict_ = self.dict_
@@ -333,9 +421,9 @@ def _scan_attributes(self):
333421
table_args = inherited_table_args = None
334422
tablename = None
335423

336-
for base in cls.__mro__:
424+
attribute_is_overridden = self._cls_attr_override_checker(self.cls)
337425

338-
sa_dataclass_metadata_key = None
426+
for base in cls.__mro__:
339427

340428
class_mapped = (
341429
base is not cls
@@ -345,25 +433,14 @@ def _scan_attributes(self):
345433
)
346434
)
347435

348-
if sa_dataclass_metadata_key is None:
349-
sa_dataclass_metadata_key = _get_immediate_cls_attr(
350-
base, "__sa_dataclass_metadata_key__", None
351-
)
352-
353-
def attributes_for_class(cls):
354-
for name, obj in vars(cls).items():
355-
yield name, obj
356-
if sa_dataclass_metadata_key:
357-
for field in util.dataclass_fields(cls):
358-
if sa_dataclass_metadata_key in field.metadata:
359-
yield field.name, field.metadata[
360-
sa_dataclass_metadata_key
361-
]
436+
local_attributes_for_class = self._cls_attr_resolver(base)
362437

363438
if not class_mapped and base is not cls:
364-
self._produce_column_copies(attributes_for_class, base)
439+
self._produce_column_copies(
440+
local_attributes_for_class, attribute_is_overridden
441+
)
365442

366-
for name, obj in attributes_for_class(base):
443+
for name, obj in local_attributes_for_class():
367444
if name == "__mapper_args__":
368445
check_decl = _check_declared_props_nocascade(
369446
obj, name, cls
@@ -471,6 +548,15 @@ def mapper_args_fn():
471548
else:
472549
self._warn_for_decl_attributes(base, name, obj)
473550
elif name not in dict_ or dict_[name] is not obj:
551+
# here, we are definitely looking at the target class
552+
# and not a superclass. this is currently a
553+
# dataclass-only path. if the name is only
554+
# a dataclass field and isn't in local cls.__dict__,
555+
# put the object there.
556+
557+
# assert that the dataclass-enabled resolver agrees
558+
# with what we are seeing
559+
assert not attribute_is_overridden(name, obj)
474560
dict_[name] = obj
475561

476562
if inherited_table_args and not tablename:
@@ -489,14 +575,17 @@ def _warn_for_decl_attributes(self, cls, key, c):
489575
% (key, cls)
490576
)
491577

492-
def _produce_column_copies(self, attributes_for_class, base):
578+
def _produce_column_copies(
579+
self, attributes_for_class, attribute_is_overridden
580+
):
493581
cls = self.cls
494582
dict_ = self.dict_
495583
column_copies = self.column_copies
496584
# copy mixin columns to the mapped class
497-
for name, obj in attributes_for_class(base):
585+
586+
for name, obj in attributes_for_class():
498587
if isinstance(obj, Column):
499-
if getattr(cls, name) is not obj:
588+
if attribute_is_overridden(name, obj):
500589
# if column has been overridden
501590
# (like by the InstrumentedAttribute of the
502591
# superclass), skip

lib/sqlalchemy/testing/fixtures.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,7 @@ class DeclarativeBasic(object):
552552
metaclass=FindFixtureDeclarative,
553553
cls=DeclarativeBasic,
554554
)
555+
555556
cls.DeclarativeBasic = _DeclBase
556557

557558
# sets up cls.Basic which is helpful for things like composite

lib/sqlalchemy/util/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from .compat import iterbytes # noqa
6767
from .compat import itertools_filter # noqa
6868
from .compat import itertools_filterfalse # noqa
69+
from .compat import local_dataclass_fields # noqa
6970
from .compat import namedtuple # noqa
7071
from .compat import next # noqa
7172
from .compat import nullcontext # noqa

lib/sqlalchemy/util/compat.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,17 +425,37 @@ def inspect_formatargspec(*spec, **kw):
425425
import dataclasses
426426

427427
def dataclass_fields(cls):
428+
"""Return a sequence of all dataclasses.Field objects associated
429+
with a class."""
430+
428431
if dataclasses.is_dataclass(cls):
429432
return dataclasses.fields(cls)
430433
else:
431434
return []
432435

436+
def local_dataclass_fields(cls):
437+
"""Return a sequence of all dataclasses.Field objects associated with
438+
a class, excluding those that originate from a superclass."""
439+
440+
if dataclasses.is_dataclass(cls):
441+
super_fields = set()
442+
for sup in cls.__bases__:
443+
super_fields.update(dataclass_fields(sup))
444+
return [
445+
f for f in dataclasses.fields(cls) if f not in super_fields
446+
]
447+
else:
448+
return []
449+
433450

434451
else:
435452

436453
def dataclass_fields(cls):
437454
return []
438455

456+
def local_dataclass_fields(cls):
457+
return []
458+
439459

440460
def raise_from_cause(exception, exc_info=None):
441461
r"""legacy. use raise\_()"""

0 commit comments

Comments
 (0)