Skip to content

Commit e432e99

Browse files
committed
Add $eq and $ne for list and dict
1 parent e14f41f commit e432e99

File tree

2 files changed

+117
-20
lines changed

2 files changed

+117
-20
lines changed

libs/oracledb/langchain_oracledb/vectorstores/oraclevs.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import functools
1313
import hashlib
1414
import inspect
15+
import json
1516
import logging
1617
import os
1718
import re
@@ -93,6 +94,9 @@
9394
def _get_comparison_string(
9495
oper: str, value: Any, bind_variables: List[str]
9596
) -> tuple[str, str]:
97+
if oper not in COMPARISON_MAP:
98+
raise ValueError(f"Invalid operator: {oper}")
99+
96100
# usual two sided operator case
97101
if COMPARISON_MAP[oper] != "":
98102
bind_l = len(bind_variables)
@@ -194,26 +198,30 @@ def _generate_condition(
194198

195199
if not isinstance(value, (dict, list, tuple)):
196200
# scalar-equality Clause
197-
bind_l = f":value{len(bind_variables)}"
201+
bind = f":value{len(bind_variables)}"
198202
bind_variables.append(value)
199203

200-
return SINGLE_MASK.format(key=metadata_key, oper="==", value_bind=bind_l)
204+
return SINGLE_MASK.format(key=metadata_key, oper="==", value_bind=bind)
201205

202206
elif isinstance(value, dict):
203207
# all values are filters
204208
result: str
205209
passings: str
206210

207211
# comparison operator keys
208-
if len(value.keys() - COMPARISON_MAP.keys()) == 0:
212+
if all(value_key.startswith("$") for value_key in value.keys()):
209213
not_dict = {}
210214

211215
passing_values = []
212216
comparison_values = []
213217

214218
for k, v in value.items():
215219
# if need to negate, cannot combine in single JSON_EXISTS
216-
if k in NOT_OPERS:
220+
if (
221+
k in NOT_OPERS
222+
or (k == "$eq" and isinstance(v, (list, dict)))
223+
or (k == "$ne" and isinstance(v, (list, dict)))
224+
):
217225
not_dict[k] = v
218226
continue
219227

@@ -255,7 +263,7 @@ def _generate_condition(
255263
f"NOT (JSON_EXISTS(metadata, '$.{metadata_key}'))"
256264
)
257265

258-
else: # for now only $nin
266+
elif k == "$nin": # for now only $nin
259267
result, passings = _get_comparison_string(k, v, bind_variables)
260268

261269
all_conditions.append(
@@ -265,6 +273,28 @@ def _generate_condition(
265273
)
266274
)
267275

276+
elif k == "$eq":
277+
bind_l = len(bind_variables)
278+
bind_variables.append(json.dumps(v))
279+
280+
all_conditions.append(
281+
"JSON_EQUAL("
282+
f" JSON_QUERY(metadata, '$.{metadata_key}' ),"
283+
f" JSON(:value{bind_l})"
284+
")"
285+
)
286+
287+
elif k == "$ne":
288+
bind_l = len(bind_variables)
289+
bind_variables.append(json.dumps(v))
290+
291+
all_conditions.append(
292+
"NOT (JSON_EQUAL("
293+
f" JSON_QUERY(metadata, '$.{metadata_key}' ),"
294+
f" JSON(:value{bind_l})"
295+
"))"
296+
)
297+
268298
res = " AND ".join(all_conditions)
269299

270300
if len(all_conditions) > 1:

libs/oracledb/tests/integration_tests/vectorstores/test_oraclevs.py

Lines changed: 82 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2649,30 +2649,52 @@ def model1(_) -> list[float]: # type: ignore[no-untyped-def]
26492649

26502650
vs.add_texts(texts, metadatas)
26512651

2652-
"""SELECT json_exists('{
2653-
"id": "1",
2654-
"name": "Jason",
2655-
"age": 45,
2656-
"address": [
2657-
{
2658-
"street": "25 A street",
2659-
"city": "Mono Vista",
2660-
"zip": 94088,
2661-
"state": "CA",
2662-
}
2663-
],
2664-
"drinks": "tea",
2665-
}', '$.name?(@ in ("json","b"))');"""
2666-
26672652
filter_res: list[tuple[dict, list[str]]] = [
26682653
({"drinks": {"$exists": True}}, ["1", "3"]),
26692654
({"address.zip": 94088}, ["1"]),
26702655
({"name": {"$eq": "Jason"}}, ["1"]),
26712656
({"drinks": {"$ne": "tea"}}, ["3"]), # exists and not equal
2657+
({"drinks": {"$eq": ["soda", "tea"]}}, ["3"]),
2658+
({"drinks": {"$ne": ["soda", "tea"]}}, ["1"]),
2659+
(
2660+
{
2661+
"address[0]": {
2662+
"$eq": {
2663+
"street": "25 A street",
2664+
"city": "Mono Vista",
2665+
"zip": 94088,
2666+
"state": "CA",
2667+
}
2668+
}
2669+
},
2670+
["1"],
2671+
),
2672+
(
2673+
{
2674+
"address[0]": {
2675+
"$ne": {
2676+
"street": "25 A street",
2677+
"city": "Mono Vista",
2678+
"zip": 94088,
2679+
"state": "CA",
2680+
}
2681+
}
2682+
},
2683+
["2"],
2684+
),
26722685
(
26732686
{"$or": [{"drinks": {"$exists": False}}, {"drinks": {"$ne": "tea"}}]},
26742687
["2", "3"],
26752688
),
2689+
(
2690+
{
2691+
"$or": [
2692+
{"drinks": {"$exists": False}},
2693+
{"drinks": {"$ne": ["soda", "tea"]}},
2694+
]
2695+
},
2696+
["1", "2"],
2697+
),
26762698
({"age": {"$gt": 45, "$lt": 55}}, ["2"]),
26772699
({"age": {"$gt": 45}}, ["2", "3"]),
26782700
({"age": {"$lt": 55}}, ["1", "2"]),
@@ -2754,6 +2776,10 @@ def model1(_) -> list[float]: # type: ignore[no-untyped-def]
27542776
_f = {"ss')--": "HELLOE"}
27552777
result = vs.similarity_search("Hello", k=3, db_filter=_f)
27562778

2779+
with pytest.raises(ValueError, match="Invalid operator"):
2780+
_f = {"drinks": {"$neq": ["soda", "tea"]}}
2781+
result = vs.similarity_search("Hello", k=3, db_filter=_f)
2782+
27572783
drop_table_purge(connection, "TB10")
27582784

27592785

@@ -2821,10 +2847,47 @@ def model1(_) -> list[float]: # type: ignore[no-untyped-def]
28212847
({"address.zip": 94088}, ["1"]),
28222848
({"name": {"$eq": "Jason"}}, ["1"]),
28232849
({"drinks": {"$ne": "tea"}}, ["3"]), # exists and not equal
2850+
({"drinks": {"$eq": ["soda", "tea"]}}, ["3"]),
2851+
({"drinks": {"$ne": ["soda", "tea"]}}, ["1"]),
2852+
(
2853+
{
2854+
"address[0]": {
2855+
"$eq": {
2856+
"street": "25 A street",
2857+
"city": "Mono Vista",
2858+
"zip": 94088,
2859+
"state": "CA",
2860+
}
2861+
}
2862+
},
2863+
["1"],
2864+
),
2865+
(
2866+
{
2867+
"address[0]": {
2868+
"$ne": {
2869+
"street": "25 A street",
2870+
"city": "Mono Vista",
2871+
"zip": 94088,
2872+
"state": "CA",
2873+
}
2874+
}
2875+
},
2876+
["2"],
2877+
),
28242878
(
28252879
{"$or": [{"drinks": {"$exists": False}}, {"drinks": {"$ne": "tea"}}]},
28262880
["2", "3"],
28272881
),
2882+
(
2883+
{
2884+
"$or": [
2885+
{"drinks": {"$exists": False}},
2886+
{"drinks": {"$ne": ["soda", "tea"]}},
2887+
]
2888+
},
2889+
["1", "2"],
2890+
),
28282891
({"age": {"$gt": 45, "$lt": 55}}, ["2"]),
28292892
({"age": {"$gt": 45}}, ["2", "3"]),
28302893
({"age": {"$lt": 55}}, ["1", "2"]),
@@ -2906,4 +2969,8 @@ def model1(_) -> list[float]: # type: ignore[no-untyped-def]
29062969
_f = {"ss')--": "HELLOE"}
29072970
result = await vs.asimilarity_search("Hello", k=3, db_filter=_f)
29082971

2972+
with pytest.raises(ValueError, match="Invalid operator"):
2973+
_f = {"drinks": {"$neq": ["soda", "tea"]}}
2974+
result = await vs.asimilarity_search("Hello", k=3, db_filter=_f)
2975+
29092976
await adrop_table_purge(connection, "TB10")

0 commit comments

Comments
 (0)