Skip to content

Commit 0fd8574

Browse files
committed
fix: allow extra filtering fields
1 parent 7dd2f2c commit 0fd8574

File tree

8 files changed

+98
-4
lines changed

8 files changed

+98
-4
lines changed

docs/index.md

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Add querystring filters to your api endpoints and show them in the swagger UI.
44

55
The supported backends are [SQLAlchemy](https://github.com/sqlalchemy/sqlalchemy),
6-
[MongoEngine](https://github.com/MongoEngine/mongoengine) and [beanie](https://github.com/BeanieODM/beanie).
6+
[MongoEngine](https://github.com/MongoEngine/mongoengine) and [beanie](https://github.com/BeanieODM/beanie).
77

88
## Example
99

@@ -133,6 +133,34 @@ class UserFilter(Filter):
133133
address__country: Optional[str]
134134
```
135135

136+
### Extra fields
137+
138+
Sometimes, you may need to add extra fields to your filter for custom logic. To do this, follow these steps:
139+
140+
1. Define the extra fields in your filter class.
141+
2. List these fields in the `extra_fields` constant.
142+
3. Access these fields in your endpoint to implement custom filtering logic.
143+
144+
```python
145+
class UserFilter(Filter):
146+
name: Optional[str]
147+
is_not_active: Optional[bool]
148+
149+
class Constants(Filter.Constants):
150+
model = User
151+
extra_fields = ["is_not_active"]
152+
153+
@app.get("/users", response_model=list[UserOut])
154+
async def get_users(user_filter: UserFilter = FilterDepends(UserFilter), db: AsyncSession = Depends(get_db)) -> Any:
155+
query = user_filter.filter(select(User))
156+
157+
if user_filter.is_not_active is not None:
158+
query = query.where(User.is_active.is_(not user_filter.is_not_active))
159+
160+
result = await db.execute(query)
161+
return result.scalars().all()
162+
```
163+
136164
## Order by
137165

138166
There is a specific field on the filter class that can be used for ordering. The default name is `order_by` and it

fastapi_filter/base/filter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class Constants: # pragma: no cover
4949
ordering_field_name: str = "order_by"
5050
search_model_fields: list[str]
5151
search_field_name: str = "search"
52+
extra_fields: list[str] = []
5253
prefix: str
5354
original_filter: type["BaseFilterModel"]
5455

@@ -59,6 +60,8 @@ def filter(self, query): # pragma: no cover
5960
def filtering_fields(self):
6061
fields = self.model_dump(exclude_none=True, exclude_unset=True)
6162
fields.pop(self.Constants.ordering_field_name, None)
63+
for field_name in self.Constants.extra_fields:
64+
fields.pop(field_name, None)
6265
return fields.items()
6366

6467
def sort(self, query): # pragma: no cover

tests/test_beanie/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class User(Document):
3131
name: Optional[str] = None
3232
email: Optional[EmailStr] = None
3333
age: int
34+
is_active: bool = True
3435
address: Optional[Link[Address]] = None
3536
favorite_sports: Optional[list[Link[Sport]]] = []
3637

@@ -86,12 +87,14 @@ async def users(
8687
await User(
8788
name=None,
8889
age=21,
90+
is_active=False,
8991
created_at=datetime(2021, 12, 1),
9092
favorite_sports=sports,
9193
).save(link_rule=WriteRules.WRITE),
9294
await User(
9395
name="Mr Praline",
9496
age=33,
97+
is_active=False,
9598
created_at=datetime(2021, 12, 1),
9699
address=Address(street="22 rue Bellier", city="Nantes", country="France"),
97100
favorite_sports=[sports[0]],
@@ -154,6 +157,8 @@ class UserFilter(Filter): # type: ignore[misc, valid-type]
154157
age__gt: Optional[int] = None
155158
age__gte: Optional[int] = None
156159
age__in: Optional[list[int]] = None
160+
gender: Optional[str] = None
161+
is_not_active: Optional[bool] = None
157162
address: Optional[AddressFilter] = FilterDepends( # type: ignore[valid-type]
158163
with_prefix("address", AddressFilter),
159164
)
@@ -164,6 +169,7 @@ class Constants(MongoFilter.Constants): # type: ignore[name-defined]
164169
search_model_fields = ["name", "email"] # noqa: RUF012
165170
search_field_name = "search"
166171
ordering_field_name = "order_by"
172+
extra_fields = ["is_not_active"]
167173

168174
yield UserFilter
169175

tests/test_beanie/test_filter.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,22 @@
2222
[{"address": {"city": "San Francisco"}}, 1],
2323
[{"search": "Mr"}, 2],
2424
[{"search": "mr"}, 2],
25+
[{"is_not_active": True}, 2],
26+
[{"is_not_active": False}, 4],
27+
[{"is_not_active": None}, 6],
28+
[{"gender": "O"}, 0],
2529
],
2630
)
2731
@pytest.mark.usefixtures("sports", "users")
2832
@pytest.mark.asyncio
2933
async def test_basic_filter(User, UserFilter, AddressFilter, filter_, expected_count):
30-
query = UserFilter(**filter_).filter(User.find({}))
34+
query = User.find({})
35+
user_filter = UserFilter(**filter_)
36+
37+
if user_filter.is_not_active is not None:
38+
query = query.find({"is_active": {"$ne": user_filter.is_not_active}})
39+
40+
query = user_filter.filter(query)
3141
assert await query.count() == expected_count
3242

3343

tests/test_mongoengine/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class User(Document):
5353
name = fields.StringField(null=True)
5454
email = fields.EmailField()
5555
age = fields.IntField()
56+
is_active = fields.BooleanField(default=True)
5657
address = fields.ReferenceField(Address)
5758
favorite_sports = fields.ListField(fields.ReferenceField(Sport))
5859

@@ -94,12 +95,14 @@ def users(User, Address, sports):
9495
User(
9596
name=None,
9697
age=21,
98+
is_active=False,
9799
created_at=datetime(2021, 12, 1),
98100
favorite_sports=sports,
99101
).save(),
100102
User(
101103
name="Mr Praline",
102104
age=33,
105+
is_active=False,
103106
created_at=datetime(2021, 12, 1),
104107
address=Address(street="22 rue Bellier", city="Nantes", country="France").save(),
105108
favorite_sports=[sports[0]],
@@ -162,6 +165,8 @@ class UserFilter(Filter): # type: ignore[misc, valid-type]
162165
age__gt: Optional[int] = None
163166
age__gte: Optional[int] = None
164167
age__in: Optional[list[int]] = None
168+
is_not_active: Optional[bool] = None
169+
gender: Optional[str] = None
165170
address: Optional[AddressFilter] = FilterDepends( # type: ignore[valid-type]
166171
with_prefix("address", AddressFilter)
167172
)
@@ -172,6 +177,7 @@ class Constants(Filter.Constants): # type: ignore[name-defined]
172177
search_model_fields = ["name", "email"]
173178
search_field_name = "search"
174179
ordering_field_name = "order_by"
180+
extra_fields = ["is_not_active"]
175181

176182
yield UserFilter
177183

tests/test_mongoengine/test_filter.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from urllib.parse import urlencode
22

3+
import mongoengine
34
import pytest
45
from fastapi import status
56

@@ -22,11 +23,20 @@
2223
[{"address": {"city": "San Francisco"}}, 1],
2324
[{"search": "Mr"}, 2],
2425
[{"search": "mr"}, 2],
26+
[{"is_not_active": True}, 2],
27+
[{"is_not_active": False}, 4],
28+
[{"is_not_active": None}, 6],
2529
],
2630
)
2731
@pytest.mark.usefixtures("Address", "users")
2832
def test_basic_filter(User, UserFilter, filter_, expected_count):
29-
query = UserFilter(**filter_).filter(User.objects())
33+
query = User.objects()
34+
user_filter = UserFilter(**filter_)
35+
36+
if user_filter.is_not_active is not None:
37+
query = query.filter(is_active__ne=user_filter.is_not_active)
38+
39+
query = user_filter.filter(query)
3040
assert query.count() == expected_count
3141

3242

@@ -76,3 +86,9 @@ async def test_required_filter(test_client, filter_, expected_status_code):
7686
error_json = response.json()
7787
assert "detail" in error_json
7888
assert isinstance(error_json["detail"], list)
89+
90+
91+
@pytest.mark.usefixtures("users")
92+
def test_raise_invalid_query_error(User, UserFilter):
93+
with pytest.raises(mongoengine.errors.InvalidQueryError, match='Cannot resolve field "gender"'):
94+
UserFilter(gender="F").filter(User.objects()).count()

tests/test_sqlalchemy/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class User(Base): # type: ignore[misc, valid-type]
6363
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)
6464
name = Column(String)
6565
age = Column(Integer, nullable=False)
66+
is_active = Column(Boolean, nullable=False, default=True)
6667
address_id = Column(Integer, ForeignKey("addresses.id"))
6768
address: Mapped[Address] = relationship(Address, backref="users", lazy="joined") # type: ignore[valid-type]
6869
favorite_sports: Mapped[Sport] = relationship( # type: ignore[valid-type]
@@ -117,11 +118,13 @@ async def users(session, User, Address):
117118
User(
118119
name=None,
119120
age=21,
121+
is_active=False,
120122
created_at=datetime(2021, 12, 1),
121123
),
122124
User(
123125
name="Mr Praline",
124126
age=33,
127+
is_active=False,
125128
created_at=datetime(2021, 12, 1),
126129
address=Address(street="22 rue Bellier", city="Nantes", country="France"),
127130
),
@@ -274,6 +277,8 @@ class UserFilter(Filter): # type: ignore[misc, valid-type]
274277
age__gt: Optional[int] = None
275278
age__gte: Optional[int] = None
276279
age__in: Optional[list[int]] = None
280+
gender: Optional[str] = None
281+
is_not_active: Optional[bool] = None
277282
address: Optional[AddressFilter] = FilterDepends( # type: ignore[valid-type]
278283
with_prefix("address", AddressFilter), by_alias=True
279284
)
@@ -284,6 +289,7 @@ class Constants(Filter.Constants): # type: ignore[name-defined]
284289
model = User
285290
search_model_fields = ["name"]
286291
search_field_name = "search"
292+
extra_fields = ["is_not_active"]
287293

288294
yield UserFilter
289295

tests/test_sqlalchemy/test_filter.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,21 @@
3535
[{"address_id__isnull": True}, 1],
3636
[{"search": "Mr"}, 2],
3737
[{"search": "mr"}, 2],
38+
[{"is_not_active": True}, 2],
39+
[{"is_not_active": False}, 4],
40+
[{"is_not_active": None}, 6],
3841
],
3942
)
4043
@pytest.mark.usefixtures("users")
4144
@pytest.mark.asyncio
4245
async def test_filter(session, Address, User, UserFilter, filter_, expected_count):
4346
query = select(User).outerjoin(Address)
44-
query = UserFilter(**filter_).filter(query)
47+
user_filter = UserFilter(**filter_)
48+
49+
if user_filter.is_not_active is not None:
50+
query = query.filter(User.is_active.is_(not user_filter.is_not_active))
51+
52+
query = user_filter.filter(query)
4553
result = await session.execute(query)
4654
assert len(result.scalars().unique().all()) == expected_count
4755

@@ -111,3 +119,14 @@ async def test_required_filter(test_client, filter_, expected_status_code):
111119
error_json = response.json()
112120
assert "detail" in error_json
113121
assert isinstance(error_json["detail"], list)
122+
123+
124+
@pytest.mark.usefixtures("users")
125+
@pytest.mark.asyncio
126+
async def test_raise_attribute_error(session, User, UserFilter):
127+
with pytest.raises(AttributeError, match="type object 'User' has no attribute 'gender'"):
128+
query = select(User)
129+
user_filter = UserFilter(gender="M")
130+
131+
query = user_filter.filter(query)
132+
await session.execute(query)

0 commit comments

Comments
 (0)