-
Notifications
You must be signed in to change notification settings - Fork 8
FAISS/vector search #68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
dhyana6466
wants to merge
4
commits into
oss-slu:main
Choose a base branch
from
dhyana6466:feat/vector-search-faiss
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
import { NextResponse } from "next/server"; | ||
export async function GET() { | ||
return NextResponse.json({ ok: true }); | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// app/api/search/route.ts | ||
import { NextRequest, NextResponse } from "next/server"; | ||
import { vectorSearch } from "@/app/lib/vectorIndex"; | ||
import { embedById } from "@/app/lib/embeddings"; | ||
|
||
type Body = { imageId?: string; vector?: unknown; topK?: unknown }; | ||
|
||
function isNumberArray(x: unknown): x is number[] { | ||
return Array.isArray(x) && x.every((n) => typeof n === "number"); | ||
} | ||
|
||
export async function POST(req: NextRequest) { | ||
try { | ||
const body: Body = await req.json(); | ||
|
||
// topK: default 10, clamp to [1, 100] | ||
let topK = 10; | ||
if (typeof body.topK === "number") { | ||
topK = Math.min(Math.max(Math.floor(body.topK), 1), 100); | ||
} | ||
|
||
// Building the query vector | ||
let query: number[]; | ||
if (typeof body.imageId === "string" && body.imageId.length > 0) { | ||
query = await embedById(body.imageId); // normalized inside | ||
} else if (isNumberArray(body.vector)) { | ||
query = body.vector; | ||
} else { | ||
return NextResponse.json({ error: "imageId or vector required" }, { status: 400 }); | ||
} | ||
|
||
const results = await vectorSearch(query, topK); | ||
return NextResponse.json({ results }); | ||
} catch (err: unknown) { | ||
const message = err instanceof Error ? err.message : "search_failed"; | ||
return NextResponse.json({ error: message }, { status: 500 }); | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// Small helper around Universal Model Adapter | ||
// Ensures the vector is L2-normalized (cosine works correctly) | ||
|
||
import { getImageEmbeddings } from "@/app/lib/modelClient"; | ||
|
||
/** | ||
* get embedding vector for an image id (file path, url, or blob) | ||
* ensures result is 1-D array of floats, L2-normalized | ||
*/ | ||
export async function embedById(image: Blob | string): Promise<number[]> { | ||
// Run through modelClient | ||
const raw = await getImageEmbeddings(image); | ||
|
||
// Flattern nested arrays if pipeline returned [ [ [ ... ] ] ] | ||
let v: number[] = []; | ||
if (Array.isArray(raw)) { | ||
v = raw.flat(Infinity) as number[]; | ||
} else { | ||
throw new Error("embedding result is not array"); | ||
} | ||
|
||
return v; | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,37 @@ | ||
// // lib/modelClient.ts | ||
// "use client"; | ||
// app/lib/modelClient.ts | ||
|
||
// import { pipeline } from "@xenova/transformers"; | ||
export type EmbeddingRequest = { imageId?: string; dataUrl?: string }; | ||
export type EmbeddingResponse = { vector: number[] }; | ||
|
||
// /** Singleton references to loaded pipelines */ | ||
// let imageEmbedder: any = null; | ||
function isEmbeddingResponse(x: unknown): x is EmbeddingResponse { | ||
const r = x as { vector?: unknown }; | ||
return Array.isArray(r?.vector) && r.vector.every((n) => typeof n === "number"); | ||
} | ||
|
||
// /** | ||
// * loading a CLIP embedding pipeline for image search (example). | ||
// * this will make it easy to use other models | ||
// */ | ||
// export async function loadEmbeddingPipeline() { | ||
// if (!imageEmbedder) { | ||
// // Using CLIP for image embeddings, just as an example | ||
// imageEmbedder = await pipeline( | ||
// "feature-extraction", | ||
// "Xenova/clip-vit-base-patch32" | ||
// ); | ||
// } | ||
// return imageEmbedder; | ||
// } | ||
/** Getying an embedding for an image by its ID (e.g., S3 key) via your embed API. */ | ||
export async function getEmbeddingForImageId(id: string): Promise<number[]> { | ||
const res = await fetch("/api/model/embed", { | ||
method: "POST", | ||
headers: { "Content-Type": "application/json" }, | ||
body: JSON.stringify({ imageId: id } as EmbeddingRequest), | ||
cache: "no-store", | ||
}); | ||
|
||
// /** | ||
// * Extract embeddings from an image. | ||
// * Returns a vector (array of floats) we can then compare with our dataset. | ||
// */ | ||
// export async function getImageEmbeddings(image: Blob | string) { | ||
// const embedder = await loadEmbeddingPipeline(); | ||
// // The pipeline returns a nested array. We'll flatten or keep it nested as needed. | ||
// const result = await embedder(image); | ||
// return result; | ||
// } | ||
const json: unknown = await res.json(); | ||
if (!isEmbeddingResponse(json)) throw new Error("Bad embed response (imageId)"); | ||
return json.vector; | ||
} | ||
|
||
/** Getting an embedding for raw image data (e.g., data URL) via your embed API. */ | ||
export async function getEmbeddingForImageData(dataUrl: string): Promise<number[]> { | ||
const res = await fetch("/api/model/embed", { | ||
method: "POST", | ||
headers: { "Content-Type": "application/json" }, | ||
body: JSON.stringify({ dataUrl } as EmbeddingRequest), | ||
cache: "no-store", | ||
}); | ||
|
||
const json: unknown = await res.json(); | ||
if (!isEmbeddingResponse(json)) throw new Error("Bad embed response (dataUrl)"); | ||
return json.vector; | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
// FAISS-only client used by Next.js API routes | ||
// Endpoints expected on the FAISS service: /health, /upsert, /delete, /search | ||
|
||
export type SearchHit = { | ||
id: string; // image id (e.g., S3 key) | ||
score: number; // similarity score (cosine via inner product) | ||
metadata?: Record<string, unknown>; | ||
}; | ||
|
||
type UpsertItem = { | ||
id: string; | ||
vector: number[]; // raw embedding | ||
metadata?: Record<string, unknown>; | ||
}; | ||
|
||
// Config | ||
function num(env: string | undefined, fallback: number): number { | ||
const n = Number(env); | ||
return Number.isFinite(n) && n > 0 ? n : fallback; | ||
} | ||
|
||
// Reading from env | ||
const FAISS_URL = (process.env.FAISS_URL || "http://127.0.0.1:8000").replace(/\/+$/, ""); | ||
const EMBEDDING_DIM = num(process.env.EMBEDDING_DIM, 384); | ||
const MAX_TOPK = num(process.env.MAX_TOPK, 100); | ||
|
||
// Small helper to fail fast with readable message | ||
async function ok(res: Response, label: string) { | ||
if (!res.ok) { | ||
const txt = await res.text().catch(() => ""); | ||
throw new Error(`${label} failed (${res.status}) ${txt}`.trim()); | ||
} | ||
return res; | ||
} | ||
|
||
// Adding or replacing a batch of vectors | ||
export async function vectorUpsert(items: UpsertItem[]): Promise<void> { | ||
if (!items?.length) return; // Nothing to do | ||
|
||
// Pre-validate dims to avoid server 422s | ||
for (const it of items) { | ||
if (!Array.isArray(it.vector) || it.vector.length !== EMBEDDING_DIM) { | ||
throw new Error(`vectorUpsert: vector dim ${it.vector?.length} != EMBEDDING_DIM ${EMBEDDING_DIM}`); | ||
} | ||
} | ||
|
||
const r = await fetch(`${FAISS_URL}/upsert`, { | ||
method: "POST", | ||
headers: { "Content-Type": "application/json" }, | ||
body: JSON.stringify({ items }), | ||
cache: "no-store", | ||
}); | ||
await ok(r, "faiss upsert"); | ||
} | ||
|
||
// Searching topK nearest neighbors for a query vector | ||
export async function vectorSearch(query: number[], topK = 10): Promise<SearchHit[]> { | ||
if (!Array.isArray(query) || query.length === 0) return []; | ||
|
||
// Pre-validate dims to avoid server 422s | ||
if (query.length !== EMBEDDING_DIM) { | ||
throw new Error(`vectorSearch: query dim ${query.length} != EMBEDDING_DIM ${EMBEDDING_DIM}`); | ||
} | ||
|
||
const k = Math.max(1, Math.min(Number(topK) || 10, MAX_TOPK)); // Clamp 1..MAX_TOPK | ||
const r = await fetch(`${FAISS_URL}/search`, { | ||
method: "POST", | ||
headers: { "Content-Type": "application/json" }, | ||
body: JSON.stringify({ query, top_k: k }), | ||
cache: "no-store", | ||
}); | ||
await ok(r, "faiss search"); | ||
return (await r.json()) as SearchHit[]; // [{ id, score, metadata }] | ||
} | ||
|
||
// Deleting by ids (dev service rebuilds index internally) | ||
export async function vectorDelete(ids: string[]): Promise<void> { | ||
if (!ids?.length) return; | ||
const r = await fetch(`${FAISS_URL}/delete`, { | ||
method: "POST", | ||
headers: { "Content-Type": "application/json" }, | ||
body: JSON.stringify({ ids }), | ||
cache: "no-store", | ||
}); | ||
await ok(r, "faiss delete"); | ||
} | ||
|
||
// Quick readiness check | ||
export async function vectorPing(): Promise<boolean> { | ||
try { | ||
const r = await fetch(`${FAISS_URL}/health`, { cache: "no-store" }); | ||
return r.ok; | ||
} catch { | ||
return false; | ||
} | ||
} |
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# Small FastAPI app that wraps a FAISS index | ||
# - Stores vectors in memory (ok for local dev) | ||
# - Uses cosine similarity by normalizing vectors and IndexFlatIP | ||
# - Endpoints: /health, /upsert, /delete, /search | ||
|
||
import os | ||
from fastapi import FastAPI, HTTPException | ||
from pydantic import BaseModel | ||
from typing import List, Dict, Any, Set | ||
import numpy as np | ||
import faiss | ||
|
||
app = FastAPI() | ||
|
||
DIM = int(os.getenv("EMBEDDING_DIM", "384")) | ||
MAX_TOPK = int(os.getenv("MAX_TOPK", "100")) | ||
index = faiss.IndexFlatIP(DIM) # Inner product -> cosine if vectors are L2-normalized | ||
id_map: List[str] = [] # Keeps ids parallel to FAISS rows | ||
meta_map: Dict[str, Dict[str, Any]] = {} # id -> metadata | ||
tombstones: Set[str] = set() | ||
|
||
# Payloading schemas | ||
class UpsertItem(BaseModel): | ||
id: str | ||
vector: List[float] | ||
metadata: Dict[str, Any] | None = None | ||
|
||
class UpsertPayload(BaseModel): | ||
items: List[UpsertItem] | ||
|
||
class DeletePayload(BaseModel): | ||
ids: List[str] | ||
|
||
class SearchPayload(BaseModel): | ||
query: list[float] | ||
top_k: int = 10 | ||
|
||
# Helpers | ||
def l2_normalize(v: np.ndarray) -> np.ndarray: | ||
n = np.linalg.norm(v) | ||
return v / n if n > 0 else v | ||
|
||
def normalize_matrix(X: np.ndarray) -> np.ndarray: | ||
# Normalizing each row to unit length | ||
norms = np.linalg.norm(X, axis=1, keepdims=True) | ||
norms[norms == 0] = 1.0 | ||
return X / norms | ||
|
||
# Routes | ||
@app.get("/health") | ||
def health(): | ||
return { | ||
"ok": True, | ||
"count": int(index.ntotal), | ||
"dim": DIM, | ||
"max_topk": MAX_TOPK, | ||
"tombstones": len(tombstones), | ||
} | ||
|
||
@app.post("/upsert") | ||
def upsert(p: UpsertPayload): | ||
global index, id_map | ||
if not p.items: | ||
return {"added": 0} | ||
|
||
# Building matrix of vectors, normalize rows for cosine | ||
vecs = [] | ||
for it in p.items: | ||
if len(it.vector) != DIM: | ||
raise HTTPException(status_code=422, | ||
detail=f"Vector for id '{it.id}' has dim={len(it.vector)} but EMBEDDING_DIM={DIM}") | ||
v = np.asarray(it.vector, dtype="float32") | ||
v = l2_normalize(v) | ||
dhyana6466 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
vecs.append(v) | ||
meta_map[it.id] = it.metadata or {} | ||
|
||
X = np.vstack(vecs).astype("float32") | ||
index.add(X) # Appending to FAISS | ||
id_map.extend([it.id for it in p.items]) | ||
return {"added": len(p.items)} | ||
|
||
@app.post("/delete") | ||
def delete(p: DeletePayload): | ||
dhyana6466 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Soft delete: mark IDs as tombstoned; rebuild happens in /compact | ||
if not p.ids: | ||
return {"deleted": 0} | ||
before = len(tombstones) | ||
tombstones.update(p.ids) | ||
return {"deleted": len(tombstones) - before} | ||
|
||
@app.post("/compact") | ||
def compact(): | ||
""" | ||
Physically rebuilding the index, dropping tombstoned ids | ||
Run this occassionally | ||
""" | ||
global index, id_map | ||
if index.ntotal == 0 or not tombstones: | ||
return {"compacted": 0, "remaining": int(index.ntotal)} | ||
|
||
# Reconstructing all vectors | ||
X = index.reconstruct_n(0, index.ntotal) # (N, DIM) float32 | ||
keep_idx = [i for i, _id in enumerate(id_map) if _id not in tombstones] | ||
|
||
X_keep = X[keep_idx].astype("float32") if keep_idx else np.empty((0, DIM), dtype="float32") | ||
ids_keep = [id_map[i] for i in keep_idx] | ||
|
||
new_index = faiss.IndexFlatIP(DIM) | ||
if X_keep.shape[0] > 0: | ||
new_index.add(X_keep) | ||
|
||
index = new_index | ||
id_map = ids_keep | ||
removed = len(tombstones) | ||
# Dropping metadata for removed ids | ||
for _id in list(tombstones): | ||
meta_map.pop(_id, None) | ||
tombstones.clear() | ||
return {"compacted": removed, "remaining": int(index.ntotal)} | ||
|
||
@app.post("/search") | ||
def search(p: SearchPayload): | ||
if index.ntotal == 0: | ||
return [] | ||
|
||
if len(p.query) != DIM: | ||
raise HTTPException( | ||
status_code=422, | ||
detail=f"Query vector has dim={len(p.query)} but EMBEDDING_DIM={DIM}" | ||
) | ||
|
||
q = np.asarray(p.query, dtype="float32") | ||
q = l2_normalize(q).reshape(1, -1) | ||
|
||
# Asking FAISS for extra results to account for tombstoned items we will filter out | ||
k_cap = min(MAX_TOPK, int(index.ntotal)) | ||
k_raw = int(max(1, min(max(p.top_k * 2, p.top_k + 10), k_cap))) | ||
D, I = index.search(q, k_raw) # D: scores, I: indices | ||
|
||
out = [] | ||
for score, idx in zip(D[0], I[0]): | ||
if idx < 0: | ||
continue | ||
_id = id_map[int(idx)] | ||
if _id in tombstones: | ||
continue | ||
out.append({ | ||
"id": _id, | ||
"score": float(score), | ||
"metadata": meta_map.get(_id, {}) | ||
}) | ||
if len(out) >= p.top_k: | ||
break | ||
return out |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
fastapi==0.112.0 | ||
uvicorn==0.30.0 | ||
faiss-cpu==1.8.0 | ||
numpy==1.26.4 |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.