Skip to content

feat: allow extra filtering fields #619

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Add querystring filters to your api endpoints and show them in the swagger UI.

The supported backends are [SQLAlchemy](https://github.com/sqlalchemy/sqlalchemy),
[MongoEngine](https://github.com/MongoEngine/mongoengine) and [beanie](https://github.com/BeanieODM/beanie).
[MongoEngine](https://github.com/MongoEngine/mongoengine) and [beanie](https://github.com/BeanieODM/beanie).

## Example

Expand Down Expand Up @@ -133,6 +133,34 @@ class UserFilter(Filter):
address__country: Optional[str]
```

### Extra fields

Sometimes, you may need to add extra fields to your filter for custom logic. To do this, follow these steps:

1. Define the extra fields in your filter class.
2. List these fields in the `extra_fields` constant.
3. Access these fields in your endpoint to implement custom filtering logic.

```python
class UserFilter(Filter):
name: Optional[str]
is_not_active: Optional[bool]

class Constants(Filter.Constants):
model = User
extra_fields = ["is_not_active"]

@app.get("/users", response_model=list[UserOut])
async def get_users(user_filter: UserFilter = FilterDepends(UserFilter), db: AsyncSession = Depends(get_db)) -> Any:
query = user_filter.filter(select(User))

if user_filter.is_not_active is not None:
query = query.where(User.is_active.is_(not user_filter.is_not_active))

result = await db.execute(query)
return result.scalars().all()
```

## Order by

There is a specific field on the filter class that can be used for ordering. The default name is `order_by` and it
Expand Down
3 changes: 3 additions & 0 deletions fastapi_filter/base/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Constants: # pragma: no cover
ordering_field_name: str = "order_by"
search_model_fields: list[str]
search_field_name: str = "search"
extra_fields: list[str] = []
prefix: str
original_filter: type["BaseFilterModel"]

Expand All @@ -59,6 +60,8 @@ def filter(self, query): # pragma: no cover
def filtering_fields(self):
fields = self.model_dump(exclude_none=True, exclude_unset=True)
fields.pop(self.Constants.ordering_field_name, None)
for field_name in self.Constants.extra_fields:
fields.pop(field_name, None)
return fields.items()

def sort(self, query): # pragma: no cover
Expand Down
6 changes: 6 additions & 0 deletions tests/test_beanie/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class User(Document):
name: Optional[str] = None
email: Optional[EmailStr] = None
age: int
is_active: bool = True
address: Optional[Link[Address]] = None
favorite_sports: Optional[list[Link[Sport]]] = []

Expand Down Expand Up @@ -86,12 +87,14 @@ async def users(
await User(
name=None,
age=21,
is_active=False,
created_at=datetime(2021, 12, 1),
favorite_sports=sports,
).save(link_rule=WriteRules.WRITE),
await User(
name="Mr Praline",
age=33,
is_active=False,
created_at=datetime(2021, 12, 1),
address=Address(street="22 rue Bellier", city="Nantes", country="France"),
favorite_sports=[sports[0]],
Expand Down Expand Up @@ -154,6 +157,8 @@ class UserFilter(Filter): # type: ignore[misc, valid-type]
age__gt: Optional[int] = None
age__gte: Optional[int] = None
age__in: Optional[list[int]] = None
gender: Optional[str] = None
is_not_active: Optional[bool] = None
address: Optional[AddressFilter] = FilterDepends( # type: ignore[valid-type]
with_prefix("address", AddressFilter),
)
Expand All @@ -164,6 +169,7 @@ class Constants(MongoFilter.Constants): # type: ignore[name-defined]
search_model_fields = ["name", "email"] # noqa: RUF012
search_field_name = "search"
ordering_field_name = "order_by"
extra_fields = ["is_not_active"]

yield UserFilter

Expand Down
12 changes: 11 additions & 1 deletion tests/test_beanie/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,22 @@
[{"address": {"city": "San Francisco"}}, 1],
[{"search": "Mr"}, 2],
[{"search": "mr"}, 2],
[{"is_not_active": True}, 2],
[{"is_not_active": False}, 4],
[{"is_not_active": None}, 6],
[{"gender": "O"}, 0],
],
)
@pytest.mark.usefixtures("sports", "users")
@pytest.mark.asyncio
async def test_basic_filter(User, UserFilter, AddressFilter, filter_, expected_count):
query = UserFilter(**filter_).filter(User.find({}))
query = User.find({})
user_filter = UserFilter(**filter_)

if user_filter.is_not_active is not None:
query = query.find({"is_active": {"$ne": user_filter.is_not_active}})

query = user_filter.filter(query)
assert await query.count() == expected_count


Expand Down
6 changes: 6 additions & 0 deletions tests/test_mongoengine/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class User(Document):
name = fields.StringField(null=True)
email = fields.EmailField()
age = fields.IntField()
is_active = fields.BooleanField(default=True)
address = fields.ReferenceField(Address)
favorite_sports = fields.ListField(fields.ReferenceField(Sport))

Expand Down Expand Up @@ -94,12 +95,14 @@ def users(User, Address, sports):
User(
name=None,
age=21,
is_active=False,
created_at=datetime(2021, 12, 1),
favorite_sports=sports,
).save(),
User(
name="Mr Praline",
age=33,
is_active=False,
created_at=datetime(2021, 12, 1),
address=Address(street="22 rue Bellier", city="Nantes", country="France").save(),
favorite_sports=[sports[0]],
Expand Down Expand Up @@ -162,6 +165,8 @@ class UserFilter(Filter): # type: ignore[misc, valid-type]
age__gt: Optional[int] = None
age__gte: Optional[int] = None
age__in: Optional[list[int]] = None
is_not_active: Optional[bool] = None
gender: Optional[str] = None
address: Optional[AddressFilter] = FilterDepends( # type: ignore[valid-type]
with_prefix("address", AddressFilter)
)
Expand All @@ -172,6 +177,7 @@ class Constants(Filter.Constants): # type: ignore[name-defined]
search_model_fields = ["name", "email"]
search_field_name = "search"
ordering_field_name = "order_by"
extra_fields = ["is_not_active"]

yield UserFilter

Expand Down
18 changes: 17 additions & 1 deletion tests/test_mongoengine/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from urllib.parse import urlencode

import mongoengine
import pytest
from fastapi import status

Expand All @@ -22,11 +23,20 @@
[{"address": {"city": "San Francisco"}}, 1],
[{"search": "Mr"}, 2],
[{"search": "mr"}, 2],
[{"is_not_active": True}, 2],
[{"is_not_active": False}, 4],
[{"is_not_active": None}, 6],
],
)
@pytest.mark.usefixtures("Address", "users")
def test_basic_filter(User, UserFilter, filter_, expected_count):
query = UserFilter(**filter_).filter(User.objects())
query = User.objects()
user_filter = UserFilter(**filter_)

if user_filter.is_not_active is not None:
query = query.filter(is_active__ne=user_filter.is_not_active)

query = user_filter.filter(query)
assert query.count() == expected_count


Expand Down Expand Up @@ -76,3 +86,9 @@ async def test_required_filter(test_client, filter_, expected_status_code):
error_json = response.json()
assert "detail" in error_json
assert isinstance(error_json["detail"], list)


@pytest.mark.usefixtures("users")
def test_raise_invalid_query_error(User, UserFilter):
with pytest.raises(mongoengine.errors.InvalidQueryError, match='Cannot resolve field "gender"'):
UserFilter(gender="F").filter(User.objects()).count()
6 changes: 6 additions & 0 deletions tests/test_sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class User(Base): # type: ignore[misc, valid-type]
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)
name = Column(String)
age = Column(Integer, nullable=False)
is_active = Column(Boolean, nullable=False, default=True)
address_id = Column(Integer, ForeignKey("addresses.id"))
address: Mapped[Address] = relationship(Address, backref="users", lazy="joined") # type: ignore[valid-type]
favorite_sports: Mapped[Sport] = relationship( # type: ignore[valid-type]
Expand Down Expand Up @@ -117,11 +118,13 @@ async def users(session, User, Address):
User(
name=None,
age=21,
is_active=False,
created_at=datetime(2021, 12, 1),
),
User(
name="Mr Praline",
age=33,
is_active=False,
created_at=datetime(2021, 12, 1),
address=Address(street="22 rue Bellier", city="Nantes", country="France"),
),
Expand Down Expand Up @@ -274,6 +277,8 @@ class UserFilter(Filter): # type: ignore[misc, valid-type]
age__gt: Optional[int] = None
age__gte: Optional[int] = None
age__in: Optional[list[int]] = None
gender: Optional[str] = None
is_not_active: Optional[bool] = None
address: Optional[AddressFilter] = FilterDepends( # type: ignore[valid-type]
with_prefix("address", AddressFilter), by_alias=True
)
Expand All @@ -284,6 +289,7 @@ class Constants(Filter.Constants): # type: ignore[name-defined]
model = User
search_model_fields = ["name"]
search_field_name = "search"
extra_fields = ["is_not_active"]

yield UserFilter

Expand Down
21 changes: 20 additions & 1 deletion tests/test_sqlalchemy/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,21 @@
[{"address_id__isnull": True}, 1],
[{"search": "Mr"}, 2],
[{"search": "mr"}, 2],
[{"is_not_active": True}, 2],
[{"is_not_active": False}, 4],
[{"is_not_active": None}, 6],
],
)
@pytest.mark.usefixtures("users")
@pytest.mark.asyncio
async def test_filter(session, Address, User, UserFilter, filter_, expected_count):
query = select(User).outerjoin(Address)
query = UserFilter(**filter_).filter(query)
user_filter = UserFilter(**filter_)

if user_filter.is_not_active is not None:
query = query.filter(User.is_active.is_(not user_filter.is_not_active))

query = user_filter.filter(query)
result = await session.execute(query)
assert len(result.scalars().unique().all()) == expected_count

Expand Down Expand Up @@ -111,3 +119,14 @@ async def test_required_filter(test_client, filter_, expected_status_code):
error_json = response.json()
assert "detail" in error_json
assert isinstance(error_json["detail"], list)


@pytest.mark.usefixtures("users")
@pytest.mark.asyncio
async def test_raise_attribute_error(session, User, UserFilter):
with pytest.raises(AttributeError, match="type object 'User' has no attribute 'gender'"):
query = select(User)
user_filter = UserFilter(gender="M")

query = user_filter.filter(query)
await session.execute(query)
Loading