diff --git a/bumble/device.py b/bumble/device.py index 031c0714..6ab0a12a 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -86,6 +86,7 @@ HCI_LE_Extended_Create_Connection_Command, HCI_LE_Rand_Command, HCI_LE_Read_PHY_Command, + HCI_LE_Set_Address_Resolution_Enable_Command, HCI_LE_Set_Advertising_Data_Command, HCI_LE_Set_Advertising_Enable_Command, HCI_LE_Set_Advertising_Parameters_Command, @@ -778,6 +779,7 @@ def __init__(self) -> None: self.irk = bytes(16) # This really must be changed for any level of security self.keystore = None self.gatt_services: List[Dict[str, Any]] = [] + self.address_resolution_offload = False def load_from_dict(self, config: Dict[str, Any]) -> None: # Load simple properties @@ -1029,6 +1031,7 @@ def __init__( self.discoverable = config.discoverable self.connectable = config.connectable self.classic_accept_any = config.classic_accept_any + self.address_resolution_offload = config.address_resolution_offload for service in config.gatt_services: characteristics = [] @@ -1256,31 +1259,16 @@ async def power_on(self) -> None: ) # Load the address resolving list - if self.keystore and self.host.supports_command( - HCI_LE_CLEAR_RESOLVING_LIST_COMMAND - ): - await self.send_command(HCI_LE_Clear_Resolving_List_Command()) # type: ignore[call-arg] - - resolving_keys = await self.keystore.get_resolving_keys() - for irk, address in resolving_keys: - await self.send_command( - HCI_LE_Add_Device_To_Resolving_List_Command( - peer_identity_address_type=address.address_type, - peer_identity_address=address, - peer_irk=irk, - local_irk=self.irk, - ) # type: ignore[call-arg] - ) - - # Enable address resolution - # await self.send_command( - # HCI_LE_Set_Address_Resolution_Enable_Command( - # address_resolution_enable=1) - # ) - # ) + if self.keystore: + await self.refresh_resolving_list() - # Create a host-side address resolver - self.address_resolver = smp.AddressResolver(resolving_keys) + # Enable address resolution + if self.address_resolution_offload: + await self.send_command( + HCI_LE_Set_Address_Resolution_Enable_Command( + address_resolution_enable=1 + ) # type: ignore[call-arg] + ) if self.classic_enabled: await self.send_command( @@ -1310,6 +1298,26 @@ async def power_off(self) -> None: await self.host.flush() self.powered_on = False + async def refresh_resolving_list(self) -> None: + assert self.keystore is not None + + resolving_keys = await self.keystore.get_resolving_keys() + # Create a host-side address resolver + self.address_resolver = smp.AddressResolver(resolving_keys) + + if self.address_resolution_offload: + await self.send_command(HCI_LE_Clear_Resolving_List_Command()) # type: ignore[call-arg] + + for irk, address in resolving_keys: + await self.send_command( + HCI_LE_Add_Device_To_Resolving_List_Command( + peer_identity_address_type=address.address_type, + peer_identity_address=address, + peer_irk=irk, + local_irk=self.irk, + ) # type: ignore[call-arg] + ) + def supports_le_feature(self, feature): return self.host.supports_le_feature(feature) diff --git a/bumble/smp.py b/bumble/smp.py index c93ee9c8..9588a5ac 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -1272,7 +1272,7 @@ async def on_pairing(self) -> None: keys.link_key = PairingKeys.Key( value=self.link_key, authenticated=authenticated ) - self.manager.on_pairing(self, peer_address, keys) + await self.manager.on_pairing(self, peer_address, keys) def on_pairing_failure(self, reason: int) -> None: logger.warning(f'pairing failure ({error_name(reason)})') @@ -1827,20 +1827,13 @@ def request_pairing(self, connection: Connection) -> None: def on_session_start(self, session: Session) -> None: self.device.on_pairing_start(session.connection) - def on_pairing( + async def on_pairing( self, session: Session, identity_address: Optional[Address], keys: PairingKeys ) -> None: # Store the keys in the key store if self.device.keystore and identity_address is not None: - - async def store_keys(): - try: - assert self.device.keystore - await self.device.keystore.update(str(identity_address), keys) - except Exception as error: - logger.warning(f'!!! error while storing keys: {error}') - - self.device.abort_on('flush', store_keys()) + await self.device.keystore.update(str(identity_address), keys) + await self.device.refresh_resolving_list() # Notify the device self.device.on_pairing(session.connection, identity_address, keys, session.sc) diff --git a/tests/self_test.py b/tests/self_test.py index 4c350457..98ce5e80 100644 --- a/tests/self_test.py +++ b/tests/self_test.py @@ -68,13 +68,16 @@ def __init__(self): ), ] - self.paired = [None, None] + self.paired = [ + asyncio.get_event_loop().create_future(), + asyncio.get_event_loop().create_future(), + ] def on_connection(self, which, connection): self.connections[which] = connection - def on_paired(self, which, keys): - self.paired[which] = keys + def on_paired(self, which: int, keys: PairingKeys): + self.paired[which].set_result(keys) # ----------------------------------------------------------------------------- @@ -323,8 +326,8 @@ async def _test_self_smp_with_configs(pairing_config1, pairing_config2): # Pair await two_devices.devices[0].pair(connection) assert connection.is_encrypted - assert two_devices.paired[0] is not None - assert two_devices.paired[1] is not None + assert await two_devices.paired[0] is not None + assert await two_devices.paired[1] is not None # ----------------------------------------------------------------------------- @@ -527,16 +530,12 @@ async def test_self_smp_over_classic(): two_devices.connections[0].encryption = 1 two_devices.connections[1].encryption = 1 - paired = [ - asyncio.get_event_loop().create_future(), - asyncio.get_event_loop().create_future(), - ] - - def on_pairing(which: int, keys: PairingKeys): - paired[which].set_result(keys) - - two_devices.connections[0].on('pairing', lambda keys: on_pairing(0, keys)) - two_devices.connections[1].on('pairing', lambda keys: on_pairing(1, keys)) + two_devices.connections[0].on( + 'pairing', lambda keys: two_devices.on_paired(0, keys) + ) + two_devices.connections[1].on( + 'pairing', lambda keys: two_devices.on_paired(1, keys) + ) # Mock SMP with patch('bumble.smp.Session', spec=True) as MockSmpSession: @@ -547,7 +546,7 @@ def on_pairing(which: int, keys: PairingKeys): # Start CTKD await two_devices.connections[0].pair() - await asyncio.gather(*paired) + await asyncio.gather(*two_devices.paired) # Phase 2 commands should not be invoked MockSmpSession.send_pairing_confirm_command.assert_not_called()