1
1
# This file is part of the QuestionPy SDK. (https://questionpy.org)
2
2
# The QuestionPy SDK is free software released under terms of the MIT license. See LICENSE.md.
3
3
# (c) Technische Universität Berlin, innoCampus <[email protected] >
4
+ from collections .abc import Collection
4
5
from typing import Any , Literal , Optional , TypeAlias , TypeVar , cast , overload
5
6
7
+ from pydantic import BeforeValidator
6
8
from pydantic .fields import FieldInfo
7
9
from pydantic_core import PydanticUndefined
8
10
27
29
# TODO: - Add support for numeric inputs (and maybe others?)
28
30
# - Make labels optional
29
31
32
+
30
33
_S = TypeVar ("_S" , bound = str )
31
34
_F = TypeVar ("_F" , bound = FormModel )
32
35
_E = TypeVar ("_E" , bound = OptionEnum )
33
36
34
37
_OneOrMoreConditions : TypeAlias = Condition | list [Condition ]
35
38
_ZeroOrMoreConditions : TypeAlias = _OneOrMoreConditions | None
36
39
40
+ _T = TypeVar ("_T" )
41
+
42
+
43
+ @overload
44
+ def _wrap_in (coll_type : type [set ], value : _T | Collection [_T ] | None ) -> set [_T ]: ...
45
+
46
+
47
+ @overload
48
+ def _wrap_in (coll_type : type [list ], value : _T | Collection [_T ] | None ) -> list [_T ]: ...
49
+
37
50
38
- def _listify ( value : _ZeroOrMoreConditions ) -> list [ Condition ]:
51
+ def _wrap_in ( coll_type : type [ set ] | type [ list ], value : _T | Collection [ _T ] | None ) -> Collection [ _T ]:
39
52
if value is None :
40
- return []
41
- if isinstance (value , list ):
42
- return value
43
- return [value ]
53
+ return coll_type ()
54
+ if isinstance (value , coll_type ):
55
+ return cast (Collection [_T ], value )
56
+ if isinstance (value , Collection ) and not isinstance (value , str ): # (str is a subclass of Collection)
57
+ return coll_type (value )
58
+
59
+ return coll_type ((cast (_T , value ),)) # MyPy gets confused here without the cast.
44
60
45
61
46
62
@overload
@@ -132,8 +148,8 @@ def text_input(
132
148
default = default ,
133
149
placeholder = placeholder ,
134
150
help = help ,
135
- disable_if = _listify ( disable_if ),
136
- hide_if = _listify ( hide_if ),
151
+ disable_if = _wrap_in ( list , disable_if ),
152
+ hide_if = _wrap_in ( list , hide_if ),
137
153
),
138
154
pydantic_field_info = FieldInfo (
139
155
default = None if not required or disable_if or hide_if else PydanticUndefined ,
@@ -231,8 +247,8 @@ def text_area(
231
247
default = default ,
232
248
placeholder = placeholder ,
233
249
help = help ,
234
- disable_if = _listify ( disable_if ),
235
- hide_if = _listify ( hide_if ),
250
+ disable_if = _wrap_in ( list , disable_if ),
251
+ hide_if = _wrap_in ( list , hide_if ),
236
252
),
237
253
pydantic_field_info = FieldInfo (
238
254
default = None if not required or disable_if or hide_if else PydanticUndefined ,
@@ -265,7 +281,12 @@ def static_text(
265
281
StaticTextElement ,
266
282
_StaticElementInfo (
267
283
lambda name : StaticTextElement (
268
- name = name , label = label , text = text , help = help , disable_if = _listify (disable_if ), hide_if = _listify (hide_if )
284
+ name = name ,
285
+ label = label ,
286
+ text = text ,
287
+ help = help ,
288
+ disable_if = _wrap_in (list , disable_if ),
289
+ hide_if = _wrap_in (list , hide_if ),
269
290
)
270
291
),
271
292
)
@@ -360,8 +381,8 @@ def checkbox(
360
381
required = required ,
361
382
help = help ,
362
383
selected = selected ,
363
- disable_if = _listify ( disable_if ),
364
- hide_if = _listify ( hide_if ),
384
+ disable_if = _wrap_in ( list , disable_if ),
385
+ hide_if = _wrap_in ( list , hide_if ),
365
386
),
366
387
pydantic_field_info = FieldInfo (default = False if not required or disable_if or hide_if else PydanticUndefined ),
367
388
)
@@ -451,8 +472,8 @@ def radio_group(
451
472
options = options ,
452
473
required = required ,
453
474
help = help ,
454
- disable_if = _listify ( disable_if ),
455
- hide_if = _listify ( hide_if ),
475
+ disable_if = _wrap_in ( list , disable_if ),
476
+ hide_if = _wrap_in ( list , hide_if ),
456
477
),
457
478
pydantic_field_info = FieldInfo (default = None if not required or disable_if or hide_if else PydanticUndefined ),
458
479
)
@@ -556,9 +577,11 @@ def select(
556
577
557
578
expected_type : type
558
579
default : object
580
+ annotate_with : tuple [object , ...] = ()
559
581
if multiple :
560
582
expected_type = set [enum ] # type: ignore[valid-type]
561
583
default = set () if not required or disable_if or hide_if else PydanticUndefined
584
+ annotate_with = (BeforeValidator (lambda value : _wrap_in (set , value )),)
562
585
elif not required or disable_if or hide_if :
563
586
expected_type = enum | None # type: ignore[assignment]
564
587
default = None
@@ -575,10 +598,11 @@ def select(
575
598
required = required ,
576
599
options = options ,
577
600
help = help ,
578
- disable_if = _listify ( disable_if ),
579
- hide_if = _listify ( hide_if ),
601
+ disable_if = _wrap_in ( list , disable_if ),
602
+ hide_if = _wrap_in ( list , hide_if ),
580
603
),
581
604
pydantic_field_info = FieldInfo (default = default ),
605
+ annotate_with = annotate_with ,
582
606
)
583
607
584
608
@@ -635,7 +659,7 @@ def hidden(value: _S, *, disable_if: _ZeroOrMoreConditions = None, hide_if: _Zer
635
659
_FieldInfo (
636
660
type = Optional [Literal [value ]] if disable_if or hide_if else Literal [value ], # noqa: UP007
637
661
build = lambda name : HiddenElement (
638
- name = name , value = value , disable_if = _listify ( disable_if ), hide_if = _listify ( hide_if )
662
+ name = name , value = value , disable_if = _wrap_in ( list , disable_if ), hide_if = _wrap_in ( list , hide_if )
639
663
),
640
664
pydantic_field_info = FieldInfo (default = None if disable_if or hide_if else PydanticUndefined ),
641
665
),
@@ -715,8 +739,8 @@ def group(
715
739
label = label ,
716
740
elements = model .qpy_form .general ,
717
741
help = help ,
718
- disable_if = _listify ( disable_if ),
719
- hide_if = _listify ( hide_if ),
742
+ disable_if = _wrap_in ( list , disable_if ),
743
+ hide_if = _wrap_in ( list , hide_if ),
720
744
),
721
745
# When the group dict is not provided at all in the form data, we want Pydantic to use the default values
722
746
# for all grouped fields and raise if there are any required ones. Creating the nested model in a
@@ -778,6 +802,7 @@ def repeat(
778
802
button_label = button_label ,
779
803
elements = model .qpy_form .general ,
780
804
),
805
+ annotate_with = (BeforeValidator (lambda value : _wrap_in (list , value )),),
781
806
),
782
807
)
783
808
0 commit comments