29
29
import inspect
30
30
import struct
31
31
from typing import (
32
- Any ,
33
32
Awaitable ,
34
33
Callable ,
34
+ Generic ,
35
35
Dict ,
36
36
List ,
37
37
Optional ,
38
38
Type ,
39
+ TypeVar ,
39
40
Union ,
40
41
TYPE_CHECKING ,
41
42
)
42
43
43
44
from pyee import EventEmitter
44
45
45
46
from bumble import utils
46
- from bumble .core import UUID , name_or_number , ProtocolError
47
+ from bumble .core import UUID , name_or_number , InvalidOperationError , ProtocolError
47
48
from bumble .hci import HCI_Object , key_with_value
48
49
from bumble .colors import color
49
50
51
+ # -----------------------------------------------------------------------------
52
+ # Typing
53
+ # -----------------------------------------------------------------------------
50
54
if TYPE_CHECKING :
51
55
from bumble .device import Connection
52
56
57
+ _T = TypeVar ('_T' )
58
+
53
59
# -----------------------------------------------------------------------------
54
60
# Constants
55
61
# -----------------------------------------------------------------------------
@@ -748,7 +754,7 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
748
754
749
755
750
756
# -----------------------------------------------------------------------------
751
- class AttributeValue :
757
+ class AttributeValue ( Generic [ _T ]) :
752
758
'''
753
759
Attribute value where reading and/or writing is delegated to functions
754
760
passed as arguments to the constructor.
@@ -757,33 +763,34 @@ class AttributeValue:
757
763
def __init__ (
758
764
self ,
759
765
read : Union [
760
- Callable [[Optional [Connection ]], Any ],
761
- Callable [[Optional [Connection ]], Awaitable [Any ]],
766
+ Callable [[Optional [Connection ]], _T ],
767
+ Callable [[Optional [Connection ]], Awaitable [_T ]],
762
768
None ,
763
769
] = None ,
764
770
write : Union [
765
- Callable [[Optional [Connection ], Any ], None ],
766
- Callable [[Optional [Connection ], Any ], Awaitable [None ]],
771
+ Callable [[Optional [Connection ], _T ], None ],
772
+ Callable [[Optional [Connection ], _T ], Awaitable [None ]],
767
773
None ,
768
774
] = None ,
769
775
):
770
776
self ._read = read
771
777
self ._write = write
772
778
773
- def read (self , connection : Optional [Connection ]) -> Union [bytes , Awaitable [bytes ]]:
774
- return self ._read (connection ) if self ._read else b''
779
+ def read (self , connection : Optional [Connection ]) -> Union [_T , Awaitable [_T ]]:
780
+ if self ._read is None :
781
+ raise InvalidOperationError ('AttributeValue has no read function' )
782
+ return self ._read (connection )
775
783
776
784
def write (
777
- self , connection : Optional [Connection ], value : bytes
785
+ self , connection : Optional [Connection ], value : _T
778
786
) -> Union [Awaitable [None ], None ]:
779
- if self ._write :
780
- return self ._write (connection , value )
781
-
782
- return None
787
+ if self ._write is None :
788
+ raise InvalidOperationError ('AttributeValue has no write function' )
789
+ return self ._write (connection , value )
783
790
784
791
785
792
# -----------------------------------------------------------------------------
786
- class Attribute (EventEmitter ):
793
+ class Attribute (EventEmitter , Generic [ _T ] ):
787
794
class Permissions (enum .IntFlag ):
788
795
READABLE = 0x01
789
796
WRITEABLE = 0x02
@@ -822,13 +829,13 @@ def from_string(cls, permissions_str: str) -> Attribute.Permissions:
822
829
READ_REQUIRES_AUTHORIZATION = Permissions .READ_REQUIRES_AUTHORIZATION
823
830
WRITE_REQUIRES_AUTHORIZATION = Permissions .WRITE_REQUIRES_AUTHORIZATION
824
831
825
- value : Any
832
+ value : Union [ AttributeValue [ _T ], _T , None ]
826
833
827
834
def __init__ (
828
835
self ,
829
836
attribute_type : Union [str , bytes , UUID ],
830
837
permissions : Union [str , Attribute .Permissions ],
831
- value : Any = b'' ,
838
+ value : Union [ AttributeValue [ _T ], _T , None ] = None ,
832
839
) -> None :
833
840
EventEmitter .__init__ (self )
834
841
self .handle = 0
@@ -848,11 +855,11 @@ def __init__(
848
855
849
856
self .value = value
850
857
851
- def encode_value (self , value : Any ) -> bytes :
852
- return value
858
+ def encode_value (self , value : _T ) -> bytes :
859
+ return value # type: ignore
853
860
854
- def decode_value (self , value_bytes : bytes ) -> Any :
855
- return value_bytes
861
+ def decode_value (self , value : bytes ) -> _T :
862
+ return value # type: ignore
856
863
857
864
async def read_value (self , connection : Optional [Connection ]) -> bytes :
858
865
if (
@@ -877,32 +884,39 @@ async def read_value(self, connection: Optional[Connection]) -> bytes:
877
884
error_code = ATT_INSUFFICIENT_AUTHORIZATION_ERROR , att_handle = self .handle
878
885
)
879
886
880
- if hasattr (self .value , 'read' ):
887
+ value : Union [_T , None ]
888
+ if isinstance (self .value , AttributeValue ):
881
889
try :
882
- value = self .value .read (connection )
883
- if inspect .isawaitable (value ):
884
- value = await value
890
+ read_value = self .value .read (connection )
891
+ if inspect .isawaitable (read_value ):
892
+ value = await read_value
893
+ else :
894
+ value = read_value
885
895
except ATT_Error as error :
886
896
raise ATT_Error (
887
897
error_code = error .error_code , att_handle = self .handle
888
898
) from error
889
899
else :
890
900
value = self .value
891
901
892
- self .emit ('read' , connection , value )
902
+ self .emit ('read' , connection , b'' if value is None else value )
893
903
894
- return self .encode_value (value )
904
+ return b'' if value is None else self .encode_value (value )
895
905
896
- async def write_value (self , connection : Connection , value_bytes : bytes ) -> None :
906
+ async def write_value (self , connection : Optional [ Connection ], value : bytes ) -> None :
897
907
if (
898
- self .permissions & self .WRITE_REQUIRES_ENCRYPTION
899
- ) and not connection .encryption :
908
+ (self .permissions & self .WRITE_REQUIRES_ENCRYPTION )
909
+ and connection is not None
910
+ and not connection .encryption
911
+ ):
900
912
raise ATT_Error (
901
913
error_code = ATT_INSUFFICIENT_ENCRYPTION_ERROR , att_handle = self .handle
902
914
)
903
915
if (
904
- self .permissions & self .WRITE_REQUIRES_AUTHENTICATION
905
- ) and not connection .authenticated :
916
+ (self .permissions & self .WRITE_REQUIRES_AUTHENTICATION )
917
+ and connection is not None
918
+ and not connection .authenticated
919
+ ):
906
920
raise ATT_Error (
907
921
error_code = ATT_INSUFFICIENT_AUTHENTICATION_ERROR , att_handle = self .handle
908
922
)
@@ -912,21 +926,21 @@ async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
912
926
error_code = ATT_INSUFFICIENT_AUTHORIZATION_ERROR , att_handle = self .handle
913
927
)
914
928
915
- value = self .decode_value (value_bytes )
929
+ decoded_value = self .decode_value (value )
916
930
917
- if hasattr (self .value , 'write' ):
931
+ if isinstance (self .value , AttributeValue ):
918
932
try :
919
- result = self .value .write (connection , value )
933
+ result = self .value .write (connection , decoded_value )
920
934
if inspect .isawaitable (result ):
921
935
await result
922
936
except ATT_Error as error :
923
937
raise ATT_Error (
924
938
error_code = error .error_code , att_handle = self .handle
925
939
) from error
926
940
else :
927
- self .value = value
941
+ self .value = decoded_value
928
942
929
- self .emit ('write' , connection , value )
943
+ self .emit ('write' , connection , decoded_value )
930
944
931
945
def __repr__ (self ):
932
946
if isinstance (self .value , bytes ):
0 commit comments