Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added pagination for the Azure AI Search retriever (#29525)

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 61 additions & 19 deletions libs/community/langchain_community/retrievers/azure_ai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
)
return values

def _build_search_url(self, query: str, skip: int = 0) -> str:
    """Build the Azure AI Search REST URL for *query*.

    Args:
        query: The search text to send to the service.
        skip: Number of already-fetched documents to skip ($skip offset),
            used by the pagination loop in ``_search``/``_asearch``.

    Returns:
        The fully assembled request URL.
    """
    url_suffix = get_from_env("", "AZURE_AI_SEARCH_URL_SUFFIX", DEFAULT_URL_SUFFIX)
    if url_suffix in self.service_name and "https://" in self.service_name:
        base_url = f"{self.service_name}/"
Expand All @@ -139,9 +139,13 @@
# pass to Azure to throw a specific error
base_url = self.service_name
endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}"
top_param = f"&$top={self.top_k}" if self.top_k else ""
batch_size = self.top_k if self.top_k is not None else 1000

top_param = f"&$top={batch_size}"
filter_param = f"&$filter={self.filter}" if self.filter else ""
return base_url + endpoint_path + f"&search={query}" + top_param + filter_param
skip_param = f"&$skip={skip}"
count_param = "&$count=true"
return base_url + endpoint_path + f"&search={query}" + top_param + skip_param + filter_param + count_param

Check failure on line 148 in libs/community/langchain_community/retrievers/azure_ai_search.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (E501)

langchain_community/retrievers/azure_ai_search.py:148:89: E501 Line too long (114 > 88)

Check failure on line 148 in libs/community/langchain_community/retrievers/azure_ai_search.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (E501)

langchain_community/retrievers/azure_ai_search.py:148:89: E501 Line too long (114 > 88)

@property
def _headers(self) -> Dict[str, str]:
Expand All @@ -151,26 +155,64 @@
}

def _search(self, query: str) -> List[dict]:
    """Run the search query, following $skip pagination until every
    matching document — or at most ``top_k`` documents, when set — has
    been collected.

    Args:
        query: The search text to send to Azure AI Search.

    Returns:
        The accumulated list of raw result documents (the service's
        ``value`` entries).

    Raises:
        Exception: If any page request returns a non-200 status code.
    """
    all_results: List[dict] = []
    skip = 0

    while True:
        search_url = self._build_search_url(query, skip)
        response = requests.get(search_url, headers=self._headers)
        if response.status_code != 200:
            raise Exception(f"Error in search request: {response}")

        response_json = json.loads(response.text)
        current_results = response_json.get("value", [])
        all_results.extend(current_results)

        # top_k is the caller-requested cap on TOTAL results. Without this
        # check the loop would keep paging until @odata.count is reached,
        # silently returning more than top_k documents.
        if self.top_k is not None and len(all_results) >= self.top_k:
            return all_results[: self.top_k]

        # "@odata.count" is present because the URL includes $count=true.
        total_results = response_json.get("@odata.count", 0)
        if not current_results or len(all_results) >= total_results:
            break

        skip += len(current_results)

    return all_results

async def _asearch(self, query: str) -> List[dict]:
    """Async variant of ``_search``: page through results with $skip.

    Uses ``self.aiosession`` when one was provided, otherwise a temporary
    ``aiohttp.ClientSession`` scoped to this call.

    Args:
        query: The search text to send to Azure AI Search.

    Returns:
        The accumulated list of raw result documents.
    """
    if not self.aiosession:
        async with aiohttp.ClientSession() as session:
            return await self._apaginate(query, session)

    # NOTE: the caller owns self.aiosession — do NOT wrap it in
    # ``async with``, which would close it and break any later request
    # made through the same session.
    return await self._apaginate(query, self.aiosession)

async def _apaginate(self, query: str, session: Any) -> List[dict]:
    """Fetch successive result pages via *session* until exhausted.

    Shared pagination loop for both session branches of ``_asearch``;
    mirrors the synchronous logic in ``_search``.
    """
    all_results: List[dict] = []
    skip = 0

    while True:
        search_url = self._build_search_url(query, skip)
        async with session.get(search_url, headers=self._headers) as response:
            response_json = await response.json()

        current_results = response_json.get("value", [])
        all_results.extend(current_results)

        # Respect the caller-requested cap on total results (see _search).
        if self.top_k is not None and len(all_results) >= self.top_k:
            return all_results[: self.top_k]

        # "@odata.count" is present because the URL includes $count=true.
        total_results = response_json.get("@odata.count", 0)
        if not current_results or len(all_results) >= total_results:
            break

        skip += len(current_results)

    return all_results

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
Expand Down
Loading