Skip to content

Commit 318ea9a

Browse files
Added recursive to_dict support to AttrDict (#1892)
Fixes #1520
1 parent 8cc2ed2 commit 318ea9a

File tree

3 files changed

+57
-2
lines changed

3 files changed

+57
-2
lines changed

elasticsearch_dsl/utils.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,17 @@ def _wrap(val: Any, obj_wrapper: Optional[Callable[[Any], Any]] = None) -> Any:
8686
return val
8787

8888

89+
def _recursive_to_dict(value: Any) -> Any:
90+
if hasattr(value, "to_dict"):
91+
return value.to_dict()
92+
elif isinstance(value, dict) or isinstance(value, AttrDict):
93+
return {k: _recursive_to_dict(v) for k, v in value.items()}
94+
elif isinstance(value, list) or isinstance(value, AttrList):
95+
return [recursive_to_dict(elem) for elem in value]
96+
else:
97+
return value
98+
99+
89100
class AttrList(Generic[_ValT]):
90101
def __init__(
91102
self, l: List[_ValT], obj_wrapper: Optional[Callable[[_ValT], Any]] = None
@@ -228,8 +239,10 @@ def __setattr__(self, name: str, value: _ValT) -> None:
228239
def __iter__(self) -> Iterator[str]:
229240
return iter(self._d_)
230241

231-
def to_dict(self) -> Dict[str, _ValT]:
232-
return self._d_
242+
def to_dict(self, recursive: bool = False) -> Dict[str, _ValT]:
243+
return cast(
244+
Dict[str, _ValT], _recursive_to_dict(self._d_) if recursive else self._d_
245+
)
233246

234247
def keys(self) -> Iterable[str]:
235248
return self._d_.keys()

tests/test_integration/_async/test_search.py

+21
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,27 @@ async def test_inner_hits_are_wrapped_in_response(
112112
)
113113

114114

115+
@pytest.mark.asyncio
116+
async def test_inner_hits_are_serialized_to_dict(
117+
async_data_client: AsyncElasticsearch,
118+
) -> None:
119+
s = AsyncSearch(index="git")[0:1].query(
120+
"has_parent", parent_type="repo", inner_hits={}, query=Q("match_all")
121+
)
122+
response = await s.execute()
123+
d = response.to_dict(recursive=True)
124+
assert isinstance(d, dict)
125+
assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict)
126+
127+
# iterating over the results changes the format of the internal AttrDict
128+
for hit in response:
129+
pass
130+
131+
d = response.to_dict(recursive=True)
132+
assert isinstance(d, dict)
133+
assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict)
134+
135+
115136
@pytest.mark.asyncio
116137
async def test_scan_respects_doc_types(async_data_client: AsyncElasticsearch) -> None:
117138
repos = [repo async for repo in Repository.search().scan()]

tests/test_integration/_sync/test_search.py

+21
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,27 @@ def test_inner_hits_are_wrapped_in_response(
104104
)
105105

106106

107+
@pytest.mark.sync
108+
def test_inner_hits_are_serialized_to_dict(
109+
data_client: Elasticsearch,
110+
) -> None:
111+
s = Search(index="git")[0:1].query(
112+
"has_parent", parent_type="repo", inner_hits={}, query=Q("match_all")
113+
)
114+
response = s.execute()
115+
d = response.to_dict(recursive=True)
116+
assert isinstance(d, dict)
117+
assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict)
118+
119+
# iterating over the results changes the format of the internal AttrDict
120+
for hit in response:
121+
pass
122+
123+
d = response.to_dict(recursive=True)
124+
assert isinstance(d, dict)
125+
assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict)
126+
127+
107128
@pytest.mark.sync
108129
def test_scan_respects_doc_types(data_client: Elasticsearch) -> None:
109130
repos = [repo for repo in Repository.search().scan()]

0 commit comments

Comments
 (0)