Skip to content

Commit 8e5ebfa

Browse files
fix type for advanced freetext and allow free-text for Item search (#263)
* fix type for advanced freetext * add tests and remove free-text from method annotations * add advanced tests * add failing tests * fix and enable free-text for items
1 parent 25f5fa7 commit 8e5ebfa

File tree

9 files changed

+223
-20
lines changed

9 files changed

+223
-20
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
### Fixed
66

77
- fix root-path handling when setting via env var or on app instance
8+
- Allow `q` parameter to be a `str` not a `list[str]` for Advanced Free-Text extension
89

910
### Changed
1011

stac_fastapi/pgstac/app.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
CollectionSearchExtension,
2525
CollectionSearchFilterExtension,
2626
FieldsExtension,
27-
FreeTextExtension,
2827
ItemCollectionFilterExtension,
2928
OffsetPaginationExtension,
3029
SearchFilterExtension,
@@ -42,7 +41,7 @@
4241
from stac_fastapi.pgstac.config import Settings
4342
from stac_fastapi.pgstac.core import CoreCrudClient, health_check
4443
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
45-
from stac_fastapi.pgstac.extensions import QueryExtension
44+
from stac_fastapi.pgstac.extensions import FreeTextExtension, QueryExtension
4645
from stac_fastapi.pgstac.extensions.filter import FiltersClient
4746
from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient
4847
from stac_fastapi.pgstac.types.search import PgstacSearch

stac_fastapi/pgstac/core.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ async def all_collections( # noqa: C901
5454
sortby: Optional[str] = None,
5555
filter_expr: Optional[str] = None,
5656
filter_lang: Optional[str] = None,
57-
q: Optional[List[str]] = None,
58-
**kwargs,
57+
**kwargs: Any,
5958
) -> Collections:
6059
"""Cross catalog search (GET).
6160
@@ -86,9 +85,15 @@ async def all_collections( # noqa: C901
8685
sortby=sortby,
8786
filter_query=filter_expr,
8887
filter_lang=filter_lang,
89-
q=q,
88+
**kwargs,
9089
)
9190

91+
# NOTE: `FreeTextExtension` - pgstac will only accept `str` so we need to
92+
# join the list[str] with ` OR `
93+
# ref: https://github.com/stac-utils/stac-fastapi-pgstac/pull/263
94+
if q := clean_args.pop("q", None):
95+
clean_args["q"] = " OR ".join(q) if isinstance(q, list) else q
96+
9297
async with request.app.state.get_connection(request, "r") as conn:
9398
q, p = render(
9499
"""
@@ -157,7 +162,10 @@ async def all_collections( # noqa: C901
157162
)
158163

159164
async def get_collection(
160-
self, collection_id: str, request: Request, **kwargs
165+
self,
166+
collection_id: str,
167+
request: Request,
168+
**kwargs: Any,
161169
) -> Collection:
162170
"""Get collection by id.
163171
@@ -202,7 +210,9 @@ async def get_collection(
202210
return Collection(**collection)
203211

204212
async def _get_base_item(
205-
self, collection_id: str, request: Request
213+
self,
214+
collection_id: str,
215+
request: Request,
206216
) -> Dict[str, Any]:
207217
"""Get the base item of a collection for use in rehydrating full item collection properties.
208218
@@ -359,7 +369,7 @@ async def item_collection(
359369
filter_expr: Optional[str] = None,
360370
filter_lang: Optional[str] = None,
361371
token: Optional[str] = None,
362-
**kwargs,
372+
**kwargs: Any,
363373
) -> ItemCollection:
364374
"""Get all items from a specific collection.
365375
@@ -391,6 +401,7 @@ async def item_collection(
391401
filter_lang=filter_lang,
392402
fields=fields,
393403
sortby=sortby,
404+
**kwargs,
394405
)
395406

396407
try:
@@ -417,7 +428,11 @@ async def item_collection(
417428
return ItemCollection(**item_collection)
418429

419430
async def get_item(
420-
self, item_id: str, collection_id: str, request: Request, **kwargs
431+
self,
432+
item_id: str,
433+
collection_id: str,
434+
request: Request,
435+
**kwargs: Any,
421436
) -> Item:
422437
"""Get item by id.
423438
@@ -445,7 +460,10 @@ async def get_item(
445460
return Item(**item_collection["features"][0])
446461

447462
async def post_search(
448-
self, search_request: PgstacSearch, request: Request, **kwargs
463+
self,
464+
search_request: PgstacSearch,
465+
request: Request,
466+
**kwargs: Any,
449467
) -> ItemCollection:
450468
"""Cross catalog search (POST).
451469
@@ -489,7 +507,7 @@ async def get_search(
489507
filter_expr: Optional[str] = None,
490508
filter_lang: Optional[str] = None,
491509
token: Optional[str] = None,
492-
**kwargs,
510+
**kwargs: Any,
493511
) -> ItemCollection:
494512
"""Cross catalog search (GET).
495513
@@ -516,6 +534,7 @@ async def get_search(
516534
sortby=sortby,
517535
filter_query=filter_expr,
518536
filter_lang=filter_lang,
537+
**kwargs,
519538
)
520539

521540
try:
@@ -550,7 +569,8 @@ def _clean_search_args( # noqa: C901
550569
sortby: Optional[str] = None,
551570
filter_query: Optional[str] = None,
552571
filter_lang: Optional[str] = None,
553-
q: Optional[List[str]] = None,
572+
q: Optional[Union[str, List[str]]] = None,
573+
**kwargs: Any,
554574
) -> Dict[str, Any]:
555575
"""Clean up search arguments to match format expected by pgstac"""
556576
if filter_query:
@@ -596,7 +616,7 @@ def _clean_search_args( # noqa: C901
596616
base_args["fields"] = {"include": includes, "exclude": excludes}
597617

598618
if q:
599-
base_args["q"] = " OR ".join(q)
619+
base_args["q"] = q
600620

601621
# Remove None values from dict
602622
clean = {}
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""pgstac extension customisations."""
22

33
from .filter import FiltersClient
4+
from .free_text import FreeTextExtension
45
from .query import QueryExtension
56

6-
__all__ = ["QueryExtension", "FiltersClient"]
7+
__all__ = ["QueryExtension", "FiltersClient", "FreeTextExtension"]
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""Free-Text model for PgSTAC."""
2+
3+
from typing import List, Optional
4+
5+
from pydantic import BaseModel, Field
6+
from pydantic.functional_serializers import PlainSerializer
7+
from stac_fastapi.extensions.core.free_text import (
8+
FreeTextExtension as FreeTextExtensionBase,
9+
)
10+
from typing_extensions import Annotated
11+
12+
13+
class FreeTextExtensionPostRequest(BaseModel):
14+
"""Free-text Extension POST request model."""
15+
16+
q: Annotated[
17+
Optional[List[str]],
18+
PlainSerializer(lambda x: " OR ".join(x), return_type=str, when_used="json"),
19+
] = Field(
20+
None,
21+
description="Parameter to perform free-text queries against STAC metadata",
22+
)
23+
24+
25+
class FreeTextExtension(FreeTextExtensionBase):
26+
"""FreeText Extension.
27+
28+
Override the POST request model to add custom serialization
29+
"""
30+
31+
POST = FreeTextExtensionPostRequest

tests/conftest.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import json
22
import logging
33
import os
4-
import time
54
from typing import Callable, Dict
65
from urllib.parse import quote_plus as quote
76
from urllib.parse import urljoin
@@ -26,7 +25,7 @@
2625
CollectionSearchExtension,
2726
CollectionSearchFilterExtension,
2827
FieldsExtension,
29-
FreeTextExtension,
28+
FreeTextAdvancedExtension,
3029
ItemCollectionFilterExtension,
3130
OffsetPaginationExtension,
3231
SearchFilterExtension,
@@ -44,7 +43,7 @@
4443
from stac_fastapi.pgstac.config import PostgresSettings, Settings
4544
from stac_fastapi.pgstac.core import CoreCrudClient, health_check
4645
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
47-
from stac_fastapi.pgstac.extensions import QueryExtension
46+
from stac_fastapi.pgstac.extensions import FreeTextExtension, QueryExtension
4847
from stac_fastapi.pgstac.extensions.filter import FiltersClient
4948
from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient
5049
from stac_fastapi.pgstac.types.search import PgstacSearch
@@ -139,6 +138,7 @@ def api_client(request):
139138
FieldsExtension(),
140139
SearchFilterExtension(client=FiltersClient()),
141140
TokenPaginationExtension(),
141+
FreeTextExtension(), # not recommended by PgSTAC
142142
]
143143
application_extensions.extend(search_extensions)
144144

@@ -167,6 +167,7 @@ def api_client(request):
167167
FieldsExtension(conformance_classes=[FieldsConformanceClasses.ITEMS]),
168168
ItemCollectionFilterExtension(client=FiltersClient()),
169169
TokenPaginationExtension(),
170+
FreeTextExtension(), # not recommended by PgSTAC
170171
]
171172
application_extensions.extend(item_collection_extensions)
172173

@@ -207,7 +208,6 @@ async def app(api_client, database):
207208
pgdatabase=database.dbname,
208209
)
209210
logger.info("Creating app Fixture")
210-
time.time()
211211
app = api_client.app
212212
await connect_to_db(
213213
app,
@@ -314,7 +314,6 @@ async def app_no_ext(database):
314314
pgdatabase=database.dbname,
315315
)
316316
logger.info("Creating app Fixture")
317-
time.time()
318317
await connect_to_db(
319318
api_client_no_ext.app,
320319
postgres_settings=postgres_settings,
@@ -354,7 +353,6 @@ async def app_no_transaction(database):
354353
pgdatabase=database.dbname,
355354
)
356355
logger.info("Creating app Fixture")
357-
time.time()
358356
await connect_to_db(
359357
api.app,
360358
postgres_settings=postgres_settings,
@@ -402,3 +400,57 @@ async def default_client(default_app):
402400
transport=ASGITransport(app=default_app), base_url="http://test"
403401
) as c:
404402
yield c
403+
404+
405+
@pytest.fixture(scope="function")
406+
async def app_advanced_freetext(database):
407+
"""Default stac-fastapi-pgstac application without only the transaction extensions."""
408+
api_settings = Settings(testing=True)
409+
410+
application_extensions = [
411+
TransactionExtension(client=TransactionsClient(), settings=api_settings)
412+
]
413+
414+
collection_extensions = [
415+
FreeTextAdvancedExtension(),
416+
OffsetPaginationExtension(),
417+
]
418+
collection_search_extension = CollectionSearchExtension.from_extensions(
419+
collection_extensions
420+
)
421+
application_extensions.append(collection_search_extension)
422+
423+
app = StacApi(
424+
settings=api_settings,
425+
extensions=application_extensions,
426+
client=CoreCrudClient(),
427+
health_check=health_check,
428+
collections_get_request_model=collection_search_extension.GET,
429+
)
430+
431+
postgres_settings = PostgresSettings(
432+
pguser=database.user,
433+
pgpassword=database.password,
434+
pghost=database.host,
435+
pgport=database.port,
436+
pgdatabase=database.dbname,
437+
)
438+
logger.info("Creating app Fixture")
439+
await connect_to_db(
440+
app.app,
441+
postgres_settings=postgres_settings,
442+
add_write_connection_pool=True,
443+
)
444+
yield app.app
445+
await close_db_connection(app.app)
446+
447+
logger.info("Closed Pools.")
448+
449+
450+
@pytest.fixture(scope="function")
451+
async def app_client_advanced_freetext(app_advanced_freetext):
452+
logger.info("creating app_client")
453+
async with AsyncClient(
454+
transport=ASGITransport(app=app_advanced_freetext), base_url="http://test"
455+
) as c:
456+
yield c

tests/data/test_item.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"type": "Polygon"
3535
},
3636
"properties": {
37+
"description": "Landat 8 imagery radiometrically calibrated and orthorectified using gound points and Digital Elevation Model (DEM) data to correct relief displacement.",
3738
"datetime": "2020-02-12T12:30:22Z",
3839
"landsat:scene_id": "LC82081612020043LGN00",
3940
"landsat:row": "161",

tests/resources/test_collection.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,71 @@ async def test_collection_search_freetext(
365365
assert resp.json()["collections"][0]["id"] == load_test2_collection.id
366366

367367
resp = await app_client.get(
368+
"/collections",
369+
params={"q": "temperature,calibrated"},
370+
)
371+
assert resp.json()["numberReturned"] == 2
372+
assert resp.json()["numberMatched"] == 2
373+
assert len(resp.json()["collections"]) == 2
374+
375+
resp = await app_client.get(
376+
"/collections",
377+
params={"q": "temperature,yo"},
378+
)
379+
assert resp.json()["numberReturned"] == 1
380+
assert resp.json()["numberMatched"] == 1
381+
assert len(resp.json()["collections"]) == 1
382+
assert resp.json()["collections"][0]["id"] == load_test2_collection.id
383+
384+
resp = await app_client.get(
385+
"/collections",
386+
params={"q": "nosuchthing"},
387+
)
388+
assert len(resp.json()["collections"]) == 0
389+
390+
391+
@requires_pgstac_0_9_2
392+
@pytest.mark.asyncio
393+
async def test_collection_search_freetext_advanced(
394+
app_client_advanced_freetext, load_test_collection, load_test2_collection
395+
):
396+
# free-text
397+
resp = await app_client_advanced_freetext.get(
398+
"/collections",
399+
params={"q": "temperature"},
400+
)
401+
assert resp.json()["numberReturned"] == 1
402+
assert resp.json()["numberMatched"] == 1
403+
assert len(resp.json()["collections"]) == 1
404+
assert resp.json()["collections"][0]["id"] == load_test2_collection.id
405+
406+
resp = await app_client_advanced_freetext.get(
407+
"/collections",
408+
params={"q": "temperature,calibrated"},
409+
)
410+
assert resp.json()["numberReturned"] == 2
411+
assert resp.json()["numberMatched"] == 2
412+
assert len(resp.json()["collections"]) == 2
413+
414+
resp = await app_client_advanced_freetext.get(
415+
"/collections",
416+
params={"q": "temperature,yo"},
417+
)
418+
assert resp.json()["numberReturned"] == 1
419+
assert resp.json()["numberMatched"] == 1
420+
assert len(resp.json()["collections"]) == 1
421+
assert resp.json()["collections"][0]["id"] == load_test2_collection.id
422+
423+
resp = await app_client_advanced_freetext.get(
424+
"/collections",
425+
params={"q": "temperature OR yo"},
426+
)
427+
assert resp.json()["numberReturned"] == 1
428+
assert resp.json()["numberMatched"] == 1
429+
assert len(resp.json()["collections"]) == 1
430+
assert resp.json()["collections"][0]["id"] == load_test2_collection.id
431+
432+
resp = await app_client_advanced_freetext.get(
368433
"/collections",
369434
params={"q": "nosuchthing"},
370435
)

0 commit comments

Comments
 (0)