|
| 1 | +# ------------------------------------------------------------------------- |
| 2 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 3 | +# Licensed under the MIT License. See License.txt in the project root for |
| 4 | +# license information. |
| 5 | +# -------------------------------------------------------------------------- |
| 6 | +from typing import Any, Optional, Mapping, Union, Dict, Callable, cast |
| 7 | +from datetime import datetime, timezone |
| 8 | +from urllib.parse import quote |
| 9 | +from uuid import UUID |
| 10 | + |
| 11 | +from ._common_conversion import _decode_base64_to_bytes |
| 12 | +from ._entity import EntityProperty, EdmType, TableEntity, EntityMetadata |
| 13 | + |
| 14 | +DecoderMapType = Dict[EdmType, Callable[[Union[str, bool, int, float]], Any]] |
| 15 | + |
| 16 | + |
| 17 | +class TablesEntityDatetime(datetime): |
| 18 | + _service_value: str |
| 19 | + |
| 20 | + @property |
| 21 | + def tables_service_value(self) -> str: |
| 22 | + try: |
| 23 | + return self._service_value |
| 24 | + except AttributeError: |
| 25 | + return "" |
| 26 | + |
| 27 | + |
| 28 | +NO_ODATA = { |
| 29 | + int: EdmType.INT32, |
| 30 | + str: EdmType.STRING, |
| 31 | + bool: EdmType.BOOLEAN, |
| 32 | + float: EdmType.DOUBLE, |
| 33 | +} |
| 34 | + |
| 35 | + |
| 36 | +class TableEntityDecoder: |
| 37 | + def __init__( |
| 38 | + self, |
| 39 | + *, |
| 40 | + flatten_result_entity: bool = False, |
| 41 | + convert_map: Optional[DecoderMapType] = None, |
| 42 | + ) -> None: |
| 43 | + self.convert_map = convert_map |
| 44 | + self.flatten_result_entity = flatten_result_entity |
| 45 | + |
| 46 | + def __call__( # pylint: disable=too-many-branches, too-many-statements |
| 47 | + self, response_data: Mapping[str, Any] |
| 48 | + ) -> TableEntity: |
| 49 | + """Convert json response to entity. |
| 50 | + The entity format is: |
| 51 | + { |
| 52 | + "Address":"Mountain View", |
| 53 | + "Age":23, |
| 54 | + "AmountDue":200.23, |
| 55 | + |
| 56 | + "CustomerCode":"c9da6455-213d-42c9-9a79-3e9149a57833", |
| 57 | + "[email protected]":"Edm.DateTime", |
| 58 | + "CustomerSince":"2008-07-10T00:00:00", |
| 59 | + "IsActive":true, |
| 60 | + |
| 61 | + "NumberOfOrders":"255", |
| 62 | + "PartitionKey":"my_partition_key", |
| 63 | + "RowKey":"my_row_key" |
| 64 | + } |
| 65 | +
|
| 66 | + :param response_data: The entity in response. |
| 67 | + :type response_data: Mapping[str, Any] |
| 68 | + :return: An entity dict with additional metadata. |
| 69 | + :rtype: dict[str, Any] |
| 70 | + """ |
| 71 | + entity = TableEntity() |
| 72 | + |
| 73 | + properties = {} |
| 74 | + edmtypes = {} |
| 75 | + odata = {} |
| 76 | + |
| 77 | + for name, value in response_data.items(): |
| 78 | + if name.startswith("odata."): |
| 79 | + odata[name[6:]] = value |
| 80 | + elif name.endswith("@odata.type"): |
| 81 | + edmtypes[name[:-11]] = value |
| 82 | + else: |
| 83 | + properties[name] = value |
| 84 | + |
| 85 | + # Partitionkey is a known property |
| 86 | + partition_key = properties.pop("PartitionKey", None) |
| 87 | + if partition_key is not None: |
| 88 | + entity["PartitionKey"] = partition_key |
| 89 | + |
| 90 | + # Timestamp is a known property |
| 91 | + timestamp = properties.pop("Timestamp", None) |
| 92 | + |
| 93 | + for name, value in properties.items(): |
| 94 | + mtype = edmtypes.get(name) |
| 95 | + |
| 96 | + if not mtype: |
| 97 | + mtype = NO_ODATA[type(value)] |
| 98 | + |
| 99 | + convert = None |
| 100 | + default_convert = None |
| 101 | + if self.convert_map: |
| 102 | + try: |
| 103 | + convert = self.convert_map[mtype] |
| 104 | + except KeyError: |
| 105 | + pass |
| 106 | + if convert: |
| 107 | + new_property = convert(value) |
| 108 | + else: |
| 109 | + try: |
| 110 | + default_convert = _ENTITY_TO_PYTHON_CONVERSIONS[mtype] |
| 111 | + except KeyError as e: |
| 112 | + raise TypeError(f"Unsupported edm type: {mtype}") from e |
| 113 | + if default_convert is not None: |
| 114 | + new_property = default_convert(self, value) |
| 115 | + else: |
| 116 | + new_property = EntityProperty(mtype, value) |
| 117 | + entity[name] = new_property |
| 118 | + |
| 119 | + # extract etag from entry |
| 120 | + etag = odata.pop("etag", None) |
| 121 | + odata.pop("metadata", None) |
| 122 | + if timestamp: |
| 123 | + if not etag: |
| 124 | + etag = "W/\"datetime'" + quote(timestamp) + "'\"" |
| 125 | + timestamp = self.from_entity_datetime(timestamp) |
| 126 | + odata.update({"etag": etag, "timestamp": timestamp}) |
| 127 | + if self.flatten_result_entity: |
| 128 | + for name, value in odata.items(): |
| 129 | + entity[name] = value |
| 130 | + entity._metadata = cast(EntityMetadata, odata) # pylint: disable=protected-access |
| 131 | + return entity |
| 132 | + |
| 133 | + def from_entity_binary(self, value: str) -> bytes: |
| 134 | + return _decode_base64_to_bytes(value) |
| 135 | + |
| 136 | + def from_entity_int32(self, value: Union[int, str]) -> int: |
| 137 | + return int(value) |
| 138 | + |
| 139 | + def from_entity_int64(self, value: str) -> EntityProperty: |
| 140 | + return EntityProperty(int(value), EdmType.INT64) |
| 141 | + |
| 142 | + def from_entity_datetime(self, value: str) -> Optional[TablesEntityDatetime]: |
| 143 | + return deserialize_iso(value) |
| 144 | + |
| 145 | + def from_entity_guid(self, value: str) -> UUID: |
| 146 | + return UUID(value) |
| 147 | + |
| 148 | + def from_entity_str(self, value: Union[str, bytes]) -> str: |
| 149 | + if isinstance(value, bytes): |
| 150 | + return value.decode("utf-8") |
| 151 | + return value |
| 152 | + |
| 153 | + |
| 154 | +_ENTITY_TO_PYTHON_CONVERSIONS = { |
| 155 | + EdmType.BINARY: TableEntityDecoder.from_entity_binary, |
| 156 | + EdmType.INT32: TableEntityDecoder.from_entity_int32, |
| 157 | + EdmType.INT64: TableEntityDecoder.from_entity_int64, |
| 158 | + EdmType.DOUBLE: lambda _, v: float(v), |
| 159 | + EdmType.DATETIME: TableEntityDecoder.from_entity_datetime, |
| 160 | + EdmType.GUID: TableEntityDecoder.from_entity_guid, |
| 161 | + EdmType.STRING: TableEntityDecoder.from_entity_str, |
| 162 | + EdmType.BOOLEAN: lambda _, v: v, |
| 163 | +} |
| 164 | + |
| 165 | + |
| 166 | +def deserialize_iso(value: Optional[str]) -> Optional[TablesEntityDatetime]: |
| 167 | + if not value: |
| 168 | + return None |
| 169 | + # Cosmos returns this with a decimal point that throws an error on deserialization |
| 170 | + cleaned_value = _clean_up_dotnet_timestamps(value) |
| 171 | + try: |
| 172 | + dt_obj = TablesEntityDatetime.strptime(cleaned_value, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc) |
| 173 | + except ValueError: |
| 174 | + dt_obj = TablesEntityDatetime.strptime(cleaned_value, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) |
| 175 | + dt_obj._service_value = value # pylint:disable=protected-access,assigning-non-slot |
| 176 | + return dt_obj |
| 177 | + |
| 178 | + |
| 179 | +def _clean_up_dotnet_timestamps(value): |
| 180 | + # .NET has more decimal places than Python supports in datetime objects, this truncates |
| 181 | + # values after 6 decimal places. |
| 182 | + value = value.split(".") |
| 183 | + ms = "" |
| 184 | + if len(value) == 2: |
| 185 | + ms = value[-1].replace("Z", "") |
| 186 | + if len(ms) > 6: |
| 187 | + ms = ms[:6] |
| 188 | + ms = ms + "Z" |
| 189 | + return ".".join([value[0], ms]) |
| 190 | + return value[0] |
0 commit comments