Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework some device and hci functions #514

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 154 additions & 78 deletions bumble/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,37 +1493,55 @@ async def sustain(self, timeout: Optional[float] = None) -> None:
self.remove_listener('disconnection', abort.set_result)
self.remove_listener('disconnection_failure', abort.set_exception)

async def set_data_length(self, tx_octets, tx_time) -> None:
async def set_data_length(self, tx_octets: int, tx_time: int) -> None:
return await self.device.set_data_length(self, tx_octets, tx_time)

async def update_parameters(
self,
connection_interval_min,
connection_interval_max,
max_latency,
supervision_timeout,
use_l2cap=False,
):
connection_interval_min: int,
connection_interval_max: int,
max_latency: int,
supervision_timeout: int,
use_l2cap: bool = False,
wait_for_complete: bool = False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe wait_for_completion would be a better name for this parameter.

) -> None:
return await self.device.update_connection_parameters(
self,
connection_interval_min,
connection_interval_max,
max_latency,
supervision_timeout,
use_l2cap=use_l2cap,
wait_for_complete=wait_for_complete,
)

async def set_phy(self, tx_phys=None, rx_phys=None, phy_options=None):
return await self.device.set_connection_phy(self, tx_phys, rx_phys, phy_options)
async def set_phy(
self,
tx_phys: Union[Iterable[int], int, None] = None,
rx_phys: Union[Iterable[int], int, None] = None,
phy_options: int = 0,
wait_for_complete: bool = False,
) -> None:
"""Sets PHY preference of this connection.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this docstring is essentially the same as the one for device.set_connection_phy, maybe there's a clever way to just say "See XYZ" and refer to that other location, rather than have a copy of the docstring, which is hard to maintain in both places.


Args:
tx_phys: TX PHY preference. Could be an integer of PHY bit flag, a list of PHY, or None if no request.
rx_phys: RX PHY preference. Could be an integer of PHY bit flag, a list of PHY, or None if no request.
phy_options: Option of Coded PHY, where 0: No Preference, 1: Prefer S=2, 2: S=8.
wait_for_complete: If set, call will wait for connection PHY update event.
"""
return await self.device.set_connection_phy(
self, tx_phys, rx_phys, phy_options, wait_for_complete
)

async def get_rssi(self):
async def get_rssi(self) -> int:
return await self.device.get_connection_rssi(self)

async def get_phy(self):
async def get_phy(self) -> Tuple[int, int]:
return await self.device.get_connection_phy(self)

# [Classic only]
async def request_remote_name(self):
async def request_remote_name(self) -> str:
return await self.device.request_remote_name(self)

async def get_remote_le_features(self) -> LeFeatureMask:
Expand Down Expand Up @@ -3333,14 +3351,19 @@ async def disconnect(
)
self.disconnecting = False

async def set_data_length(self, connection, tx_octets, tx_time) -> None:
async def set_data_length(
self,
connection: Connection,
tx_octets: int,
tx_time: int,
) -> None:
if tx_octets < 0x001B or tx_octets > 0x00FB:
raise InvalidArgumentError('tx_octets must be between 0x001B and 0x00FB')

if tx_time < 0x0148 or tx_time > 0x4290:
raise InvalidArgumentError('tx_time must be between 0x0148 and 0x4290')

return await self.send_command(
await self.send_command(
HCI_LE_Set_Data_Length_Command(
connection_handle=connection.handle,
tx_octets=tx_octets,
Expand All @@ -3351,14 +3374,15 @@ async def set_data_length(self, connection, tx_octets, tx_time) -> None:

async def update_connection_parameters(
self,
connection,
connection_interval_min,
connection_interval_max,
max_latency,
supervision_timeout,
min_ce_length=0,
max_ce_length=0,
use_l2cap=False,
connection: Connection,
connection_interval_min: int,
connection_interval_max: int,
max_latency: int,
supervision_timeout: int,
min_ce_length: int = 0,
max_ce_length: int = 0,
use_l2cap: bool = False,
wait_for_complete: bool = False,
) -> None:
'''
NOTE: the name of the parameters may look odd, but it just follows the names
Expand All @@ -3382,36 +3406,67 @@ async def update_connection_parameters(
if l2cap_result != l2cap.L2CAP_CONNECTION_PARAMETERS_ACCEPTED_RESULT:
raise ConnectionParameterUpdateError(l2cap_result)

result = await self.send_command(
HCI_LE_Connection_Update_Command(
connection_handle=connection.handle,
connection_interval_min=connection_interval_min,
connection_interval_max=connection_interval_max,
max_latency=max_latency,
supervision_timeout=supervision_timeout,
min_ce_length=min_ce_length,
max_ce_length=max_ce_length,
with closing(EventWatcher()) as watcher:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: move the call to self.send_command() to a local function, which the code followibng with closing(...) ...: can invoke, but can also be invoked without the with close when wait_for_completion is False, so that we don't unnecessarily create the context and the future in that case.

connection_update = asyncio.get_running_loop().create_future()
if wait_for_complete:
watcher.once(
connection,
'connection_parameters_update',
lambda: connection_update.set_result(None),
)
watcher.once(
connection,
'connection_parameters_update_failure',
lambda error_code: connection_update.set_exception(
HCI_Error(error_code)
),
)
await self.send_command(
HCI_LE_Connection_Update_Command(
connection_handle=connection.handle,
connection_interval_min=connection_interval_min,
connection_interval_max=connection_interval_max,
max_latency=max_latency,
supervision_timeout=supervision_timeout,
min_ce_length=min_ce_length,
max_ce_length=max_ce_length,
),
check_result=True,
)
)
if result.status != HCI_Command_Status_Event.PENDING:
raise HCI_StatusError(result)

async def get_connection_rssi(self, connection):
if wait_for_complete:
await connection_update

async def get_connection_rssi(self, connection: Connection) -> int:
result = await self.send_command(
HCI_Read_RSSI_Command(handle=connection.handle), check_result=True
)
return result.return_parameters.rssi

async def get_connection_phy(self, connection):
async def get_connection_phy(self, connection: Connection) -> Tuple[int, int]:
result = await self.send_command(
HCI_LE_Read_PHY_Command(connection_handle=connection.handle),
check_result=True,
)
return (result.return_parameters.tx_phy, result.return_parameters.rx_phy)

async def set_connection_phy(
self, connection, tx_phys=None, rx_phys=None, phy_options=None
):
self,
connection: Connection,
tx_phys: Union[Iterable[int], int, None] = None,
rx_phys: Union[Iterable[int], int, None] = None,
phy_options: int = 0,
wait_for_complete: bool = False,
) -> None:
"""Sets PHY preference of given connection.

Args:
connection: Connection to set PHY preference.
tx_phys: TX PHY preference. Could be an integer of PHY bit flag, a list of PHY, or None if no request.
rx_phys: RX PHY preference. Could be an integer of PHY bit flag, a list of PHY, or None if no request.
phy_options: Option of Coded PHY, where 0: No Preference, 1: Prefer S=2, 2: S=8.
wait_for_complete: If set, call will wait for connection PHY update event.
"""
if not self.host.supports_command(HCI_LE_SET_PHY_COMMAND):
logger.warning('ignoring request, command not supported')
return
Expand All @@ -3420,33 +3475,66 @@ async def set_connection_phy(
(1 if rx_phys is None else 0) << 1
)

result = await self.send_command(
HCI_LE_Set_PHY_Command(
connection_handle=connection.handle,
all_phys=all_phys_bits,
tx_phys=phy_list_to_bits(tx_phys),
rx_phys=phy_list_to_bits(rx_phys),
phy_options=0 if phy_options is None else int(phy_options),
)
)
with closing(EventWatcher()) as watcher:
phy_update = asyncio.get_running_loop().create_future()
if wait_for_complete:
watcher.once(
connection,
'connection_phy_update',
lambda: phy_update.set_result(None),
)
watcher.once(
connection,
'connection_phy_update_failure',
lambda error_code: phy_update.set_exception(HCI_Error(error_code)),
)

if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Set_PHY_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
await self.send_command(
HCI_LE_Set_PHY_Command(
connection_handle=connection.handle,
all_phys=all_phys_bits,
tx_phys=(
tx_phys
if isinstance(tx_phys, int)
else phy_list_to_bits(tx_phys)
),
rx_phys=(
rx_phys
if isinstance(rx_phys, int)
else phy_list_to_bits(rx_phys)
),
phy_options=phy_options,
),
check_result=True,
)
raise HCI_StatusError(result)

async def set_default_phy(self, tx_phys=None, rx_phys=None):
if wait_for_complete:
await phy_update

async def set_default_phy(
self,
tx_phys: Union[Iterable[int], int, None] = None,
rx_phys: Union[Iterable[int], int, None] = None,
) -> None:
"""Sets default PHY preference.

Args:
tx_phys: TX PHY preference. Could be an integer of PHY bit flag, a list of PHY, or None if no request.
rx_phys: RX PHY preference. Could be an integer of PHY bit flag, a list of PHY, or None if no request.
"""
all_phys_bits = (1 if tx_phys is None else 0) | (
(1 if rx_phys is None else 0) << 1
)

return await self.send_command(
await self.send_command(
HCI_LE_Set_Default_PHY_Command(
all_phys=all_phys_bits,
tx_phys=phy_list_to_bits(tx_phys),
rx_phys=phy_list_to_bits(rx_phys),
tx_phys=(
tx_phys if isinstance(tx_phys, int) else phy_list_to_bits(tx_phys)
),
rx_phys=(
rx_phys if isinstance(rx_phys, int) else phy_list_to_bits(rx_phys)
),
),
check_result=True,
)
Expand Down Expand Up @@ -3526,10 +3614,10 @@ def smp_session_proxy(self) -> Type[smp.Session]:
def smp_session_proxy(self, session_proxy: Type[smp.Session]) -> None:
self.smp_manager.session_proxy = session_proxy

async def pair(self, connection):
async def pair(self, connection: Connection) -> None:
return await self.smp_manager.pair(connection)

def request_pairing(self, connection):
def request_pairing(self, connection: Connection) -> None:
return self.smp_manager.request_pairing(connection)

async def get_long_term_key(
Expand Down Expand Up @@ -3576,7 +3664,7 @@ async def get_link_key(self, address: Address) -> Optional[bytes]:
return keys.link_key.value

# [Classic only]
async def authenticate(self, connection):
async def authenticate(self, connection: Connection) -> None:
# Set up event handlers
pending_authentication = asyncio.get_running_loop().create_future()

Expand Down Expand Up @@ -3611,7 +3699,7 @@ def on_authentication_failure(error_code):
'connection_authentication_failure', on_authentication_failure
)

async def encrypt(self, connection, enable=True):
async def encrypt(self, connection: Connection, enable: bool = True) -> None:
if not enable and connection.transport == BT_LE_TRANSPORT:
raise InvalidArgumentError('`enable` parameter is classic only.')

Expand Down Expand Up @@ -3652,36 +3740,24 @@ def on_encryption_failure(error_code):
if connection.role != HCI_CENTRAL_ROLE:
raise InvalidStateError('only centrals can start encryption')

result = await self.send_command(
await self.send_command(
HCI_LE_Enable_Encryption_Command(
connection_handle=connection.handle,
random_number=rand,
encrypted_diversifier=ediv,
long_term_key=ltk,
)
),
check_result=True,
)

if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Enable_Encryption_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)
else:
result = await self.send_command(
await self.send_command(
HCI_Set_Connection_Encryption_Command(
connection_handle=connection.handle,
encryption_enable=0x01 if enable else 0x00,
)
),
check_result=True,
)

if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_Set_Connection_Encryption_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)

# Wait for the result
await connection.abort_on('disconnection', pending_encryption)
finally:
Expand Down
9 changes: 4 additions & 5 deletions bumble/hci.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def hci_command_op_code(ogf, ocf):
def hci_command_op_code(ogf: int, ocf: int) -> int:
return ogf << 10 | ocf


def hci_vendor_command_op_code(ocf):
def hci_vendor_command_op_code(ocf: int) -> int:
return hci_command_op_code(HCI_VENDOR_OGF, ocf)


Expand All @@ -65,7 +65,7 @@ def key_with_value(dictionary, target_value):
return None


def indent_lines(string):
def indent_lines(string: str) -> str:
return '\n'.join([' ' + line for line in string.split('\n')])


Expand All @@ -79,7 +79,7 @@ def map_null_terminated_utf8_string(utf8_bytes):
return utf8_bytes


def map_class_of_device(class_of_device):
def map_class_of_device(class_of_device: int) -> str:
(
service_classes,
major_device_class,
Expand Down Expand Up @@ -710,7 +710,6 @@ class PhyBit(enum.IntFlag):
LE_2M = 1 << HCI_LE_2M_PHY_BIT
LE_CODED = 1 << HCI_LE_CODED_PHY_BIT


# Connection Parameters
HCI_CONNECTION_INTERVAL_MS_PER_UNIT = 1.25
HCI_CONNECTION_LATENCY_MS_PER_UNIT = 1.25
Expand Down
Loading