diff --git a/tests/test_application_helpers.py b/tests/test_application_helpers.py index 1faa9f9f1..f1af4060b 100644 --- a/tests/test_application_helpers.py +++ b/tests/test_application_helpers.py @@ -1,8 +1,12 @@ """Test zha application helpers.""" +from typing import Any + +import pytest from zigpy.device import Device as ZigpyDevice from zigpy.profiles import zha -from zigpy.zcl.clusters.general import Basic, OnOff +import zigpy.types as t +from zigpy.zcl.clusters.general import Basic, Identify, OnOff from zigpy.zcl.clusters.security import IasZone from tests.common import ( @@ -14,7 +18,12 @@ join_zigpy_device, ) from zha.application.gateway import Gateway -from zha.application.helpers import async_is_bindable_target, get_matched_clusters +from zha.application.helpers import ( + async_is_bindable_target, + convert_to_zcl_values, + convert_zcl_value, + get_matched_clusters, +) IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8" IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8" @@ -105,3 +114,76 @@ async def test_get_matched_clusters( assert matches[0].target_ep_id == 1 assert not await get_matched_clusters(not_bindable_zha_device, remote_zha_device) + + +class SomeEnum(t.enum8): + """Some enum.""" + + value_1 = 0x12 + value_2 = 0x34 + value_3 = 0x56 + + +class SomeFlag(t.bitmap8): + """Some bitmap.""" + + flag_1 = 0b00000001 + flag_2 = 0b00000010 + flag_3 = 0b00000100 + + +@pytest.mark.parametrize( + ("text", "field_type", "result"), + [ + # Bytes + ( + "b'Some data\\x00\\x01'", + t.SerializableBytes, + t.SerializableBytes(b"Some data\x00\x01"), + ), + ( + 'b"Some data\\x00\\x01"', + t.SerializableBytes, + t.SerializableBytes(b"Some data\x00\x01"), + ), + ( + b"Some data\x00\x01".hex(), + t.SerializableBytes, + t.SerializableBytes(b"Some data\x00\x01"), + ), + # Enum + ("value 1", SomeEnum, SomeEnum.value_1), + ("value_1", SomeEnum, SomeEnum.value_1), + ("SomeEnum.value_1", SomeEnum, SomeEnum.value_1), + (0x12, SomeEnum, SomeEnum.value_1), + # Flag + ("flag 1", SomeFlag, SomeFlag.flag_1), + ("flag_1", SomeFlag, SomeFlag.flag_1), + ("SomeFlag.flag_1", SomeFlag, SomeFlag.flag_1), + ("SomeFlag.flag_1|flag_2", SomeFlag, SomeFlag.flag_1 | SomeFlag.flag_2), + (0b00000001, SomeFlag, SomeFlag.flag_1), + ([0b00000001], SomeFlag, SomeFlag.flag_1), + ([0b00000001, 0b00000010], SomeFlag, SomeFlag.flag_1 | SomeFlag.flag_2), + (["flag_1", "flag_2"], SomeFlag, SomeFlag.flag_1 | SomeFlag.flag_2), + # Int + (0x1234, t.uint16_t, 0x1234), + ("0x1234", t.uint16_t, 0x1234), + ("4660", t.uint16_t, 0x1234), + # Some fallthrough type + (1.000, t.Single, t.Single(1.000)), + ("1.000", t.Single, t.Single(1.000)), + ], +) +def test_convert_zcl_value(text: Any, field_type: Any, result: Any) -> None: + """Test converting ZCL values.""" + assert convert_zcl_value(text, field_type) == result + + +def test_convert_to_zcl_values() -> None: + """Test converting ZCL values.""" + + identify_schema = Identify.ServerCommandDefs.identify.schema + assert convert_to_zcl_values( + fields={"identify_time": "1"}, + schema=identify_schema, + ) == {"identify_time": 1} diff --git a/zha/application/helpers.py b/zha/application/helpers.py index 300de0078..f6ffa7638 100644 --- a/zha/application/helpers.py +++ b/zha/application/helpers.py @@ -2,10 +2,12 @@ from __future__ import annotations +import ast import asyncio import binascii import collections from collections.abc import Callable +import contextlib import dataclasses from dataclasses import dataclass import datetime @@ -126,6 +128,53 @@ async def get_matched_clusters( return clusters_to_bind +def convert_zcl_value(value: Any, field_type: Any) -> Any: + """Convert user input to ZCL value.""" + if issubclass(field_type, enum.Flag): + if isinstance(value, str): + with contextlib.suppress(ValueError): + value = int(value) + + if isinstance(value, int): + value = field_type(value) + elif isinstance(value, str): + # List of flags: `SomeFlag.field1 | field2` + value = [v.strip() for v in value.split(".", 1)[-1].split("|")] + + if isinstance(value, list): + new_value = 0 + + for flag in value: + if isinstance(flag, str): + new_value |= field_type[flag.replace(" ", "_")] + else: + new_value |= flag + + value = field_type(new_value) + elif issubclass(field_type, enum.Enum): + value = ( + field_type[value.replace(" ", "_").split(".", 1)[-1]] + if isinstance(value, str) + else field_type(value) + ) + elif issubclass(field_type, zigpy.types.SerializableBytes): + if value.startswith(("b'", 'b"')): + value = ast.literal_eval(value) + else: + value = bytes.fromhex(value) + + value = field_type(value) + elif issubclass(field_type, int): + if isinstance(value, str) and value.startswith("0x"): + value = int(value, 16) + + value = field_type(value) + else: + value = field_type(value) + + return value + + def convert_to_zcl_values( fields: dict[str, Any], schema: CommandSchema ) -> dict[str, Any]: @@ -134,32 +183,17 @@ def convert_to_zcl_values( for field in schema.fields: if field.name not in fields: continue - value = fields[field.name] - if issubclass(field.type, enum.Flag) and isinstance(value, list): - new_value = 0 - for flag in value: - if isinstance(flag, str): - new_value |= field.type[flag.replace(" ", "_")] - else: - new_value |= flag + value = fields[field.name] + new_value = converted_fields[field.name] = convert_zcl_value(value, field.type) - value = field.type(new_value) - elif issubclass(field.type, enum.Enum): - value = ( - field.type[value.replace(" ", "_")] - if isinstance(value, str) - else field.type(value) - ) - else: - value = field.type(value) _LOGGER.debug( "Converted ZCL schema field(%s) value from: %s to: %s", field.name, - fields[field.name], value, + new_value, ) - converted_fields[field.name] = value + return converted_fields diff --git a/zha/zigbee/device.py b/zha/zigbee/device.py index ef2a51d67..c1a2c3b80 100644 --- a/zha/zigbee/device.py +++ b/zha/zigbee/device.py @@ -60,7 +60,7 @@ ZHA_CLUSTER_HANDLER_MSG, ZHA_EVENT, ) -from zha.application.helpers import convert_to_zcl_values +from zha.application.helpers import convert_to_zcl_values, convert_zcl_value from zha.application.platforms import BaseEntityInfo, PlatformEntity from zha.event import EventBase from zha.exceptions import ZHAException @@ -874,6 +874,9 @@ async def write_zigbee_attribute( f" writing attribute {attribute} with value {value}" ) from exc + attr_def = cluster.find_attribute(attribute) + value = convert_zcl_value(value, attr_def.type) + try: response = await cluster.write_attributes( {attribute: value}, manufacturer=manufacturer