Skip to content

Commit ed0eb91

Browse files
authored
Merge pull request #650 from google/gbg/gatt-adapter-typing
new GATT adapter classes with proper typing support
2 parents 752ce6c + 82d8250 commit ed0eb91

31 files changed

+1256
-622
lines changed

apps/auracast.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -514,14 +514,19 @@ async def run_assist(
514514
return
515515

516516
# Subscribe to and read the broadcast receive state characteristics
517+
def on_broadcast_receive_state_update(
518+
value: bass.BroadcastReceiveState, index: int
519+
) -> None:
520+
print(
521+
f"{color(f'Broadcast Receive State Update [{index}]:', 'green')} {value}"
522+
)
523+
517524
for i, broadcast_receive_state in enumerate(
518525
bass_client.broadcast_receive_states
519526
):
520527
try:
521528
await broadcast_receive_state.subscribe(
522-
lambda value, i=i: print(
523-
f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}"
524-
)
529+
functools.partial(on_broadcast_receive_state_update, index=i)
525530
)
526531
except core.ProtocolError as error:
527532
print(

apps/gg_bridge.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def __init__(self, device: Device):
234234
Characteristic.WRITEABLE,
235235
CharacteristicValue(write=self.on_rx_write),
236236
)
237-
self.tx_characteristic = Characteristic(
237+
self.tx_characteristic: Characteristic[bytes] = Characteristic(
238238
GG_GATTLINK_TX_CHARACTERISTIC_UUID,
239239
Characteristic.Properties.NOTIFY,
240240
Characteristic.READABLE,

bumble/att.py

+51-37
Original file line numberDiff line numberDiff line change
@@ -29,27 +29,33 @@
2929
import inspect
3030
import struct
3131
from typing import (
32-
Any,
3332
Awaitable,
3433
Callable,
34+
Generic,
3535
Dict,
3636
List,
3737
Optional,
3838
Type,
39+
TypeVar,
3940
Union,
4041
TYPE_CHECKING,
4142
)
4243

4344
from pyee import EventEmitter
4445

4546
from bumble import utils
46-
from bumble.core import UUID, name_or_number, ProtocolError
47+
from bumble.core import UUID, name_or_number, InvalidOperationError, ProtocolError
4748
from bumble.hci import HCI_Object, key_with_value
4849
from bumble.colors import color
4950

51+
# -----------------------------------------------------------------------------
52+
# Typing
53+
# -----------------------------------------------------------------------------
5054
if TYPE_CHECKING:
5155
from bumble.device import Connection
5256

57+
_T = TypeVar('_T')
58+
5359
# -----------------------------------------------------------------------------
5460
# Constants
5561
# -----------------------------------------------------------------------------
@@ -748,7 +754,7 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
748754

749755

750756
# -----------------------------------------------------------------------------
751-
class AttributeValue:
757+
class AttributeValue(Generic[_T]):
752758
'''
753759
Attribute value where reading and/or writing is delegated to functions
754760
passed as arguments to the constructor.
@@ -757,33 +763,34 @@ class AttributeValue:
757763
def __init__(
758764
self,
759765
read: Union[
760-
Callable[[Optional[Connection]], Any],
761-
Callable[[Optional[Connection]], Awaitable[Any]],
766+
Callable[[Optional[Connection]], _T],
767+
Callable[[Optional[Connection]], Awaitable[_T]],
762768
None,
763769
] = None,
764770
write: Union[
765-
Callable[[Optional[Connection], Any], None],
766-
Callable[[Optional[Connection], Any], Awaitable[None]],
771+
Callable[[Optional[Connection], _T], None],
772+
Callable[[Optional[Connection], _T], Awaitable[None]],
767773
None,
768774
] = None,
769775
):
770776
self._read = read
771777
self._write = write
772778

773-
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]:
774-
return self._read(connection) if self._read else b''
779+
def read(self, connection: Optional[Connection]) -> Union[_T, Awaitable[_T]]:
780+
if self._read is None:
781+
raise InvalidOperationError('AttributeValue has no read function')
782+
return self._read(connection)
775783

776784
def write(
777-
self, connection: Optional[Connection], value: bytes
785+
self, connection: Optional[Connection], value: _T
778786
) -> Union[Awaitable[None], None]:
779-
if self._write:
780-
return self._write(connection, value)
781-
782-
return None
787+
if self._write is None:
788+
raise InvalidOperationError('AttributeValue has no write function')
789+
return self._write(connection, value)
783790

784791

785792
# -----------------------------------------------------------------------------
786-
class Attribute(EventEmitter):
793+
class Attribute(EventEmitter, Generic[_T]):
787794
class Permissions(enum.IntFlag):
788795
READABLE = 0x01
789796
WRITEABLE = 0x02
@@ -822,13 +829,13 @@ def from_string(cls, permissions_str: str) -> Attribute.Permissions:
822829
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
823830
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
824831

825-
value: Any
832+
value: Union[AttributeValue[_T], _T, None]
826833

827834
def __init__(
828835
self,
829836
attribute_type: Union[str, bytes, UUID],
830837
permissions: Union[str, Attribute.Permissions],
831-
value: Any = b'',
838+
value: Union[AttributeValue[_T], _T, None] = None,
832839
) -> None:
833840
EventEmitter.__init__(self)
834841
self.handle = 0
@@ -848,11 +855,11 @@ def __init__(
848855

849856
self.value = value
850857

851-
def encode_value(self, value: Any) -> bytes:
852-
return value
858+
def encode_value(self, value: _T) -> bytes:
859+
return value # type: ignore
853860

854-
def decode_value(self, value_bytes: bytes) -> Any:
855-
return value_bytes
861+
def decode_value(self, value: bytes) -> _T:
862+
return value # type: ignore
856863

857864
async def read_value(self, connection: Optional[Connection]) -> bytes:
858865
if (
@@ -877,32 +884,39 @@ async def read_value(self, connection: Optional[Connection]) -> bytes:
877884
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
878885
)
879886

880-
if hasattr(self.value, 'read'):
887+
value: Union[_T, None]
888+
if isinstance(self.value, AttributeValue):
881889
try:
882-
value = self.value.read(connection)
883-
if inspect.isawaitable(value):
884-
value = await value
890+
read_value = self.value.read(connection)
891+
if inspect.isawaitable(read_value):
892+
value = await read_value
893+
else:
894+
value = read_value
885895
except ATT_Error as error:
886896
raise ATT_Error(
887897
error_code=error.error_code, att_handle=self.handle
888898
) from error
889899
else:
890900
value = self.value
891901

892-
self.emit('read', connection, value)
902+
self.emit('read', connection, b'' if value is None else value)
893903

894-
return self.encode_value(value)
904+
return b'' if value is None else self.encode_value(value)
895905

896-
async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
906+
async def write_value(self, connection: Optional[Connection], value: bytes) -> None:
897907
if (
898-
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
899-
) and not connection.encryption:
908+
(self.permissions & self.WRITE_REQUIRES_ENCRYPTION)
909+
and connection is not None
910+
and not connection.encryption
911+
):
900912
raise ATT_Error(
901913
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
902914
)
903915
if (
904-
self.permissions & self.WRITE_REQUIRES_AUTHENTICATION
905-
) and not connection.authenticated:
916+
(self.permissions & self.WRITE_REQUIRES_AUTHENTICATION)
917+
and connection is not None
918+
and not connection.authenticated
919+
):
906920
raise ATT_Error(
907921
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
908922
)
@@ -912,21 +926,21 @@ async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
912926
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
913927
)
914928

915-
value = self.decode_value(value_bytes)
929+
decoded_value = self.decode_value(value)
916930

917-
if hasattr(self.value, 'write'):
931+
if isinstance(self.value, AttributeValue):
918932
try:
919-
result = self.value.write(connection, value)
933+
result = self.value.write(connection, decoded_value)
920934
if inspect.isawaitable(result):
921935
await result
922936
except ATT_Error as error:
923937
raise ATT_Error(
924938
error_code=error.error_code, att_handle=self.handle
925939
) from error
926940
else:
927-
self.value = value
941+
self.value = decoded_value
928942

929-
self.emit('write', connection, value)
943+
self.emit('write', connection, decoded_value)
930944

931945
def __repr__(self):
932946
if isinstance(self.value, bytes):

bumble/device.py

+74-6
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353

5454
from .colors import color
5555
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
56-
from .gatt import Characteristic, Descriptor, Service
56+
from .gatt import Attribute, Characteristic, Descriptor, Service
5757
from .host import DataPacketQueue, Host
5858
from .profiles.gap import GenericAccessService
5959
from .core import (
@@ -2223,7 +2223,7 @@ def __init__(
22232223
permissions=descriptor["permissions"],
22242224
)
22252225
descriptors.append(new_descriptor)
2226-
new_characteristic = Characteristic(
2226+
new_characteristic: Characteristic[bytes] = Characteristic(
22272227
uuid=characteristic["uuid"],
22282228
properties=Characteristic.Properties.from_string(
22292229
characteristic["properties"]
@@ -4923,16 +4923,84 @@ def add_default_services(
49234923
self.gatt_service = gatt_service.GenericAttributeProfileService()
49244924
self.gatt_server.add_service(self.gatt_service)
49254925

4926-
async def notify_subscriber(self, connection, attribute, value=None, force=False):
4926+
async def notify_subscriber(
4927+
self,
4928+
connection: Connection,
4929+
attribute: Attribute,
4930+
value: Optional[Any] = None,
4931+
force: bool = False,
4932+
) -> None:
4933+
"""
4934+
Send a notification to an attribute subscriber.
4935+
4936+
Args:
4937+
connection:
4938+
The connection of the subscriber.
4939+
attribute:
4940+
The attribute whose value is notified.
4941+
value:
4942+
The value of the attribute (if None, the value is read from the attribute)
4943+
force:
4944+
If True, send a notification even if there is no subscriber.
4945+
"""
49274946
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
49284947

4929-
async def notify_subscribers(self, attribute, value=None, force=False):
4948+
async def notify_subscribers(
4949+
self, attribute: Attribute, value=None, force=False
4950+
) -> None:
4951+
"""
4952+
Send a notification to all the subscribers of an attribute.
4953+
4954+
Args:
4955+
attribute:
4956+
The attribute whose value is notified.
4957+
value:
4958+
The value of the attribute (if None, the value is read from the attribute)
4959+
force:
4960+
If True, send a notification for every connection even if there is no
4961+
subscriber.
4962+
"""
49304963
await self.gatt_server.notify_subscribers(attribute, value, force)
49314964

4932-
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
4965+
async def indicate_subscriber(
4966+
self,
4967+
connection: Connection,
4968+
attribute: Attribute,
4969+
value: Optional[Any] = None,
4970+
force: bool = False,
4971+
):
4972+
"""
4973+
Send an indication to an attribute subscriber.
4974+
4975+
This method returns when the response to the indication has been received.
4976+
4977+
Args:
4978+
connection:
4979+
The connection of the subscriber.
4980+
attribute:
4981+
The attribute whose value is indicated.
4982+
value:
4983+
The value of the attribute (if None, the value is read from the attribute)
4984+
force:
4985+
If True, send an indication even if there is no subscriber.
4986+
"""
49334987
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
49344988

4935-
async def indicate_subscribers(self, attribute, value=None, force=False):
4989+
async def indicate_subscribers(
4990+
self, attribute: Attribute, value: Optional[Any] = None, force: bool = False
4991+
):
4992+
"""
4993+
Send an indication to all the subscribers of an attribute.
4994+
4995+
Args:
4996+
attribute:
4997+
The attribute whose value is notified.
4998+
value:
4999+
The value of the attribute (if None, the value is read from the attribute)
5000+
force:
5001+
If True, send an indication for every connection even if there is no
5002+
subscriber.
5003+
"""
49365004
await self.gatt_server.indicate_subscribers(attribute, value, force)
49375005

49385006
@host_event_handler

0 commit comments

Comments
 (0)