@@ -325,6 +325,94 @@ def after_configured():
325
325
def before_configured ():
326
326
self .cls .__declare_first__ ()
327
327
328
+ def _cls_attr_override_checker (self , cls ):
329
+ """Produce a function that checks if a class has overridden an
330
+ attribute, taking SQLAlchemy-enabled dataclass fields into account.
331
+
332
+ """
333
+ sa_dataclass_metadata_key = _get_immediate_cls_attr (
334
+ cls , "__sa_dataclass_metadata_key__" , None
335
+ )
336
+
337
+ if sa_dataclass_metadata_key is None :
338
+
339
+ def attribute_is_overridden (key , obj ):
340
+ return getattr (cls , key ) is not obj
341
+
342
+ else :
343
+
344
+ all_datacls_fields = {
345
+ f .name : f .metadata [sa_dataclass_metadata_key ]
346
+ for f in util .dataclass_fields (cls )
347
+ if sa_dataclass_metadata_key in f .metadata
348
+ }
349
+ local_datacls_fields = {
350
+ f .name : f .metadata [sa_dataclass_metadata_key ]
351
+ for f in util .local_dataclass_fields (cls )
352
+ if sa_dataclass_metadata_key in f .metadata
353
+ }
354
+
355
+ absent = object ()
356
+
357
+ def attribute_is_overridden (key , obj ):
358
+ # this function likely has some failure modes still if
359
+ # someone is doing a deep mixing of the same attribute
360
+ # name as plain Python attribute vs. dataclass field.
361
+
362
+ ret = local_datacls_fields .get (key , absent )
363
+
364
+ if ret is obj :
365
+ return False
366
+ elif ret is not absent :
367
+ return True
368
+
369
+ ret = getattr (cls , key , obj )
370
+
371
+ if ret is obj :
372
+ return False
373
+ elif ret is not absent :
374
+ return True
375
+
376
+ ret = all_datacls_fields .get (key , absent )
377
+
378
+ if ret is obj :
379
+ return False
380
+ elif ret is not absent :
381
+ return True
382
+
383
+ # can't find another attribute
384
+ return False
385
+
386
+ return attribute_is_overridden
387
+
388
+ def _cls_attr_resolver (self , cls ):
389
+ """produce a function to iterate the "attributes" of a class,
390
+ adjusting for SQLAlchemy fields embedded in dataclass fields.
391
+
392
+ """
393
+ sa_dataclass_metadata_key = _get_immediate_cls_attr (
394
+ cls , "__sa_dataclass_metadata_key__" , None
395
+ )
396
+
397
+ if sa_dataclass_metadata_key is None :
398
+
399
+ def local_attributes_for_class ():
400
+ for name , obj in vars (cls ).items ():
401
+ yield name , obj
402
+
403
+ else :
404
+
405
+ def local_attributes_for_class ():
406
+ for name , obj in vars (cls ).items ():
407
+ yield name , obj
408
+ for field in util .local_dataclass_fields (cls ):
409
+ if sa_dataclass_metadata_key in field .metadata :
410
+ yield field .name , field .metadata [
411
+ sa_dataclass_metadata_key
412
+ ]
413
+
414
+ return local_attributes_for_class
415
+
328
416
def _scan_attributes (self ):
329
417
cls = self .cls
330
418
dict_ = self .dict_
@@ -333,9 +421,9 @@ def _scan_attributes(self):
333
421
table_args = inherited_table_args = None
334
422
tablename = None
335
423
336
- for base in cls . __mro__ :
424
+ attribute_is_overridden = self . _cls_attr_override_checker ( self . cls )
337
425
338
- sa_dataclass_metadata_key = None
426
+ for base in cls . __mro__ :
339
427
340
428
class_mapped = (
341
429
base is not cls
@@ -345,25 +433,14 @@ def _scan_attributes(self):
345
433
)
346
434
)
347
435
348
- if sa_dataclass_metadata_key is None :
349
- sa_dataclass_metadata_key = _get_immediate_cls_attr (
350
- base , "__sa_dataclass_metadata_key__" , None
351
- )
352
-
353
- def attributes_for_class (cls ):
354
- for name , obj in vars (cls ).items ():
355
- yield name , obj
356
- if sa_dataclass_metadata_key :
357
- for field in util .dataclass_fields (cls ):
358
- if sa_dataclass_metadata_key in field .metadata :
359
- yield field .name , field .metadata [
360
- sa_dataclass_metadata_key
361
- ]
436
+ local_attributes_for_class = self ._cls_attr_resolver (base )
362
437
363
438
if not class_mapped and base is not cls :
364
- self ._produce_column_copies (attributes_for_class , base )
439
+ self ._produce_column_copies (
440
+ local_attributes_for_class , attribute_is_overridden
441
+ )
365
442
366
- for name , obj in attributes_for_class ( base ):
443
+ for name , obj in local_attributes_for_class ( ):
367
444
if name == "__mapper_args__" :
368
445
check_decl = _check_declared_props_nocascade (
369
446
obj , name , cls
@@ -471,6 +548,15 @@ def mapper_args_fn():
471
548
else :
472
549
self ._warn_for_decl_attributes (base , name , obj )
473
550
elif name not in dict_ or dict_ [name ] is not obj :
551
+ # here, we are definitely looking at the target class
552
+ # and not a superclass. this is currently a
553
+ # dataclass-only path. if the name is only
554
+ # a dataclass field and isn't in local cls.__dict__,
555
+ # put the object there.
556
+
557
+ # assert that the dataclass-enabled resolver agrees
558
+ # with what we are seeing
559
+ assert not attribute_is_overridden (name , obj )
474
560
dict_ [name ] = obj
475
561
476
562
if inherited_table_args and not tablename :
@@ -489,14 +575,17 @@ def _warn_for_decl_attributes(self, cls, key, c):
489
575
% (key , cls )
490
576
)
491
577
492
- def _produce_column_copies (self , attributes_for_class , base ):
578
+ def _produce_column_copies (
579
+ self , attributes_for_class , attribute_is_overridden
580
+ ):
493
581
cls = self .cls
494
582
dict_ = self .dict_
495
583
column_copies = self .column_copies
496
584
# copy mixin columns to the mapped class
497
- for name , obj in attributes_for_class (base ):
585
+
586
+ for name , obj in attributes_for_class ():
498
587
if isinstance (obj , Column ):
499
- if getattr ( cls , name ) is not obj :
588
+ if attribute_is_overridden ( name , obj ) :
500
589
# if column has been overridden
501
590
# (like by the InstrumentedAttribute of the
502
591
# superclass), skip
0 commit comments