80 changes: 43 additions & 37 deletions src/bioclip_vector_db/client/nearest_neighbor_client.py
@@ -6,7 +6,6 @@

logger = logging.getLogger(__name__)


class NearestNeighborClient:
"""A client for querying multiple LocalIndexServer instances."""

@@ -31,8 +30,18 @@ def _post_request(self, url: str, json_data: Dict[str, Any]) -> Dict[str, Any]:
logger.error(f"Request to {url} failed: {e}")
return {"status": "error", "error": {"message": str(e)}}

def _get_request(self, url: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
"""Sends a GET request to a given URL and returns the JSON response."""
try:
response = requests.get(url, json=params)
response.raise_for_status() # Raise an exception for bad status codes
return response.json()
except requests.exceptions.RequestException as e:
logger.error(f"Request to {url} failed: {e}")
return {"status": "error", "error": {"message": str(e)}}

def search(
self, query_vector: List[float], top_n: int = 10, nprobe: int = 1
self, query_vector: List[float], top_n: int = 10, nprobe: int = 1, fetch_metadata: bool = True
) -> List[Dict[str, Any]]:
"""
Queries all configured servers for the nearest neighbors.
@@ -56,8 +65,38 @@ def search(
results.append({"server": url, "response": result})
except Exception as e:
logger.error(f"Search for {url} failed: {e}")

merged_results = self._merge_results(results)

if fetch_metadata:
for result in merged_results:
image_id = result.get("id", None)
metadata_response = self.get_metadata(image_id, None)
if metadata_response and metadata_response.get("status") == "success":
result["metadata"] = metadata_response.get("data")

return merged_results

return self._merge_results(results)
def get_metadata(self, image_id: str, server_url: str = None) -> Dict[str, Any]:
"""
Retrieves metadata for a given image_id from one of the servers.

:param image_id: The image_id to retrieve metadata for.
:param server_url: The specific server to query. If None, a random server is chosen.
:return: The metadata response from the server.
"""
if server_url and server_url not in self._server_urls:
raise ValueError(f"Provided server_url '{server_url}' is not in the configured list of servers.")

get_payload = {"image_id": image_id}
target_server = server_url if server_url else random.choice(self._server_urls)
get_url = f"{target_server}/get"

try:
return self._get_request(get_url, params=get_payload)
except Exception as e:
logger.error(f"Get metadata for {image_id} from {get_url} failed: {e}")
return {"status": "error", "error": {"message": str(e)}}

def _merge_results(self, results):
"""
@@ -92,37 +131,4 @@ def health(self) -> List[Dict[str, Any]]:
"response": {"status": "error", "error": {"message": str(e)}},
}
)
return health_statuses


if __name__ == "__main__":
# Configure logging
logging.basicConfig(
level=logging.INFO, format="[%(asctime)s] [%(levelname)s] %(message)s"
)

# List of server URLs to query
SERVER_URLS = [f"http://0.0.0.0:{port}" for port in range(5001, 5004)]

# Initialize the client
client = NearestNeighborClient(SERVER_URLS)

# 1. Check the health of the servers
print("--- Checking server health ---")
health_results = client.health()
print(json.dumps(health_results, indent=2))

# 2. Perform a search
print("\n--- Performing search ---")

# Create a dummy query vector.
# IMPORTANT: The dimension of this vector must match the dimension of the vectors in the FAISS index.
# For BioCLIP models, this is often 512 or 768. We'll use 512 as an example.
DUMMY_VECTOR_DIM = 512
dummy_query_vector = [random.random() for _ in range(DUMMY_VECTOR_DIM)]

# Perform the search
search_results = client.search(query_vector=dummy_query_vector, top_n=1, nprobe=10)

# Print the results
print(json.dumps(search_results, indent=2))
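For orientation, a minimal usage sketch of the updated client API. It mirrors the `__main__` demo removed above; the server URLs, vector dimension, and image id are placeholders, and `fetch_metadata` / `get_metadata` follow the signatures added in this diff.

```python
import json
import random

from bioclip_vector_db.client.nearest_neighbor_client import NearestNeighborClient

# Placeholder deployment: three local servers, as in the removed demo block.
SERVER_URLS = [f"http://0.0.0.0:{port}" for port in range(5001, 5004)]
client = NearestNeighborClient(SERVER_URLS)

# Dummy query vector; its dimension must match the FAISS index (512 is only an example).
query_vector = [random.random() for _ in range(512)]

# With fetch_metadata=True the client also attaches metadata to each merged result.
results = client.search(query_vector=query_vector, top_n=5, nprobe=10, fetch_metadata=True)
print(json.dumps(results, indent=2))

# Metadata can also be requested directly; with server_url=None a random server is chosen.
metadata = client.get_metadata(image_id="hypothetical-image-id", server_url=None)
print(json.dumps(metadata, indent=2))
```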
1 change: 1 addition & 0 deletions src/bioclip_vector_db/query/README.md
@@ -69,6 +69,7 @@ export LEADER_INDEX=leader.index
export PARTITIONS="1,2,5-10"
export NPROBE=10
export PORT=5001
export USE_CACHE=1

gunicorn --workers ${WORKERS:-4} --bind 0.0.0.0:${PORT} --chdir src bioclip_vector_db.query.wsgi:app
```
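`USE_CACHE` appears here only as an environment variable; how the query server consumes it is not shown in this diff. One plausible reading, stated purely as an assumption:

```python
import os

# Assumption: a truthy string enables caching; the actual parsing lives in the
# server code, which is outside this diff.
use_cache = os.environ.get("USE_CACHE", "0").strip().lower() in ("1", "true", "yes")
```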
35 changes: 33 additions & 2 deletions src/bioclip_vector_db/query/neighborhood_server.py
@@ -186,8 +186,7 @@ def _search_leader(self, query_vector: list, top_n: int) -> np.ndarray:

@timer
def search(
self, query_vector: list, top_n: int, nprobe: int = 1
) -> Dict[int, tuple[np.ndarray, np.ndarray]]:
self, query_vector: list, top_n: int, nprobe: int = 1) -> Dict[int, tuple[np.ndarray, np.ndarray]]:
"""
Performs a search on the loaded FAISS index.
This method first queries the leader index to identify the most relevant partitions,
@@ -243,6 +242,10 @@ def dimensions(self) -> int:
def get_nprobe(self) -> int:
return self._nprobe

@timer
def get(self, original_id: str):
return self._metadata_db.get_metadata(original_id)


class LocalIndexServer:
"""A Flask server class to handle search and health check requests."""
@@ -258,6 +261,7 @@ def _register_routes(self):
"/search", "search", self.handle_search, methods=["POST"]
)
self._app.add_url_rule("/health", "health", self.handle_health, methods=["GET"])
self._app.add_url_rule("/get", "get", self.handle_get, methods=["GET"])

def _success_response(self, data, status_code=200):
"""Generates a structured success JSON response."""
@@ -294,6 +298,33 @@ def handle_health(self):
# Use 503 Service Unavailable when the service is not ready
return self._error_response("Index not loaded or trained", 503)

def handle_get(self):
"""
Handler for the /get endpoint.
Retrieves metadata for a given image_id.
"""
data = request.get_json()
if not data or "image_id" not in data:
return self._error_response("Missing 'image_id' in request parameters", 400)
image_id = data["image_id"]
try:
metadata = self._service.get(image_id)
if metadata:
return self._success_response(metadata)
else:
return self._error_response(
f"Metadata not found for image_id: {image_id}", 404
)
except Exception as e:
logger.error(f"An error occurred during get: {e}", exc_info=True)
return self._error_response(
"An internal server error occurred during get", 500
)

def handle_get_random(self):
"""Handler for the /get_random endpoint."""
pass

def _handle_merging(self, results):
all_matches = [
match for matches_dict in results for match in matches_dict["matches"]
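A sketch of exercising the new `/get` endpoint directly. The handler reads `image_id` from a JSON body (matching `_get_request` on the client side); the host, port, and image id below are assumptions.

```python
import requests

# Assumed local server; the image id is a placeholder.
resp = requests.get(
    "http://localhost:5001/get",
    json={"image_id": "hypothetical-image-id"},  # handle_get() expects 'image_id' in the JSON body
)

if resp.status_code == 200:
    print(resp.json())  # structured success payload carrying the stored metadata
elif resp.status_code == 404:
    print("No metadata stored for that image_id")
else:
    print(f"Request failed with status {resp.status_code}")
```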
15 changes: 10 additions & 5 deletions src/bioclip_vector_db/storage/metadata_storage.py
@@ -55,6 +55,11 @@ def create_table(self):
)
"""
)
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_original_id ON id_mapping (original_id)
"""
)
logger.info("SQLITE: Create table successful.")
except sqlite3.Error as e:
logger.error(f"Error creating table: {e}")
@@ -139,7 +144,7 @@ def get_metadata(self, partition_id: int, faiss_id: int) -> Optional[Dict[str, A
try:
cursor = conn.cursor()
cursor.execute(
"SELECT original_id, metadata FROM id_mapping WHERE partition_id = ? AND faiss_id = ?",
"SELECT metadata FROM id_mapping WHERE partition_id = ? AND faiss_id = ?",
(int(partition_id), int(faiss_id)),
)
result = cursor.fetchone()
@@ -150,7 +155,7 @@ def get_metadata(self, partition_id: int, faiss_id: int) -> Optional[Dict[str, A
logger.error(f"Error getting metadata: {e}")
raise

def get_metadata(self, original_id: int) -> Optional[Dict[str, Any]]:
def get_metadata(self, original_id: str) -> Optional[Dict[str, Any]]:
"""
Retrieves the metadata for a given original ID.

@@ -161,12 +166,12 @@ def get_metadata(self, original_id: int) -> Optional[Dict[str, Any]]:
try:
cursor = conn.cursor()
cursor.execute(
"SELECT original_id, metadata FROM id_mapping WHERE original_id = ?",
(str(original_id)),
"SELECT metadata FROM id_mapping WHERE original_id = ?",
(original_id,),
)
result = cursor.fetchone()
if result and result[0]:
return json.loads(result[0])
return json.loads(result[0].decode("utf-8"))
return None
except sqlite3.Error as e:
logger.error(f"Error getting metadata: {e}")
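For illustration, a self-contained sketch of the lookup pattern the new index and the reworked `get_metadata(original_id)` rely on. Only the table and column names come from this diff; the column types and the sample row are assumptions.

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")

# Assumed schema: only table and column names appear in this diff; types are guesses.
conn.execute(
    """
    CREATE TABLE IF NOT EXISTS id_mapping (
        partition_id INTEGER,
        faiss_id INTEGER,
        original_id TEXT,
        metadata BLOB
    )
    """
)
# The PR adds this index so lookups by original_id avoid a full table scan.
conn.execute("CREATE INDEX IF NOT EXISTS idx_original_id ON id_mapping (original_id)")

conn.execute(
    "INSERT INTO id_mapping VALUES (?, ?, ?, ?)",
    (1, 42, "img-0001", json.dumps({"species": "example"}).encode("utf-8")),
)

# Same query shape as the updated get_metadata(original_id): metadata is stored as
# UTF-8 encoded JSON, hence the decode before json.loads.
row = conn.execute(
    "SELECT metadata FROM id_mapping WHERE original_id = ?", ("img-0001",)
).fetchone()
if row and row[0]:
    print(json.loads(row[0].decode("utf-8")))
```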