Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions app/api/health/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import { NextResponse } from "next/server";
/** Liveness probe for the Next.js app: always answers with the JSON body `{ ok: true }`. */
export async function GET() {
  const payload = { ok: true };
  return NextResponse.json(payload);
}
38 changes: 38 additions & 0 deletions app/api/search/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// app/api/search/route.ts
import { NextRequest, NextResponse } from "next/server";
import { vectorSearch } from "@/app/lib/vectorIndex";
import { embedById } from "@/app/lib/embeddings";

// JSON body accepted by the POST handler below. `vector` and `topK` are typed
// `unknown` because they come straight from the request and are validated at runtime.
type Body = { imageId?: string; vector?: unknown; topK?: unknown };

/**
 * Type guard: true only when `x` is an array whose every element is a finite number.
 *
 * NaN and ±Infinity are rejected on purpose: `JSON.parse` can produce Infinity
 * (e.g. `1e999`), and `JSON.stringify` would later turn such values into `null`,
 * so they can never form a usable query vector.
 */
function isNumberArray(x: unknown): x is number[] {
  // Number.isFinite returns false for every non-number, so no separate typeof check is needed.
  return Array.isArray(x) && x.every((n) => Number.isFinite(n));
}

/**
 * Nearest-neighbour search for an image id or a raw query vector.
 *
 * Body fields: `imageId` (non-empty string, embedded via `embedById`) takes
 * precedence over `vector` (array of numbers). `topK` is optional, defaults
 * to 10, and is floored and clamped to [1, 100].
 *
 * Responses:
 * - 200 `{ results }` on success
 * - 400 `{ error }` when the body is not valid JSON, is not a JSON object,
 *   or carries neither a usable `imageId` nor a usable `vector`
 * - 500 `{ error }` when embedding or searching throws
 */
export async function POST(req: NextRequest) {
  // Parse on its own: a body that is not valid JSON is a client error, not a 500.
  let parsed: unknown;
  try {
    parsed = await req.json();
  } catch {
    return NextResponse.json({ error: "invalid JSON body" }, { status: 400 });
  }

  // A JSON `null` would throw on property access below; primitives and arrays
  // can never carry the expected fields, so reject them all up front.
  if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) {
    return NextResponse.json({ error: "imageId or vector required" }, { status: 400 });
  }
  const body = parsed as Body;

  try {
    // topK: default 10, clamp to [1, 100]
    let topK = 10;
    if (typeof body.topK === "number" && !Number.isNaN(body.topK)) {
      topK = Math.min(Math.max(Math.floor(body.topK), 1), 100);
    }

    // Build the query vector: an image id wins over an explicit vector.
    let query: number[];
    if (typeof body.imageId === "string" && body.imageId.length > 0) {
      query = await embedById(body.imageId);
    } else if (isNumberArray(body.vector)) {
      query = body.vector;
    } else {
      return NextResponse.json({ error: "imageId or vector required" }, { status: 400 });
    }

    const results = await vectorSearch(query, topK);
    return NextResponse.json({ results });
  } catch (err: unknown) {
    const message = err instanceof Error ? err.message : "search_failed";
    return NextResponse.json({ error: message }, { status: 500 });
  }
}
23 changes: 23 additions & 0 deletions app/lib/embeddings.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Small helper around Universal Model Adapter
// Ensures the vector is L2-normalized (cosine works correctly)

import { getImageEmbeddings } from "@/app/lib/modelClient";

/**
 * Gets the embedding vector for an image (string reference or Blob).
 *
 * The model output is flattened to a 1-D array and L2-normalized, so an inner
 * product between two results equals their cosine similarity. A vector with
 * zero norm (including an empty one) is returned unchanged.
 *
 * @param image - value handed to `getImageEmbeddings` as-is
 * @returns flat, unit-length array of finite numbers
 * @throws Error when the model output is not an array, or when any flattened
 *   element is not a finite number
 */
export async function embedById(image: Blob | string): Promise<number[]> {
  // NOTE(review): the modelClient.ts changes visible in this same diff appear to
  // drop `getImageEmbeddings` — confirm the import above still resolves.
  const raw: unknown = await getImageEmbeddings(image);

  if (!Array.isArray(raw)) {
    throw new Error("embedding result is not array");
  }

  // Flatten nested arrays if the pipeline returned [ [ [ ... ] ] ].
  const flat: unknown[] = raw.flat(Infinity);

  // Validate each element instead of casting, and accumulate the squared norm in the same pass.
  const v: number[] = [];
  let sumSq = 0;
  for (const x of flat) {
    if (typeof x !== "number" || !Number.isFinite(x)) {
      throw new Error("embedding result contains a non-numeric value");
    }
    v.push(x);
    sumSq += x * x;
  }

  const norm = Math.sqrt(sumSq);
  return norm > 0 ? v.map((x) => x / norm) : v;
}
62 changes: 33 additions & 29 deletions app/lib/modelClient.ts
Original file line number Diff line number Diff line change
@@ -1,33 +1,37 @@
// // lib/modelClient.ts
// "use client";
// app/lib/modelClient.ts

// import { pipeline } from "@xenova/transformers";
/** Body sent to the embed API: an image is referenced by id or supplied inline as a data URL. */
export type EmbeddingRequest = { imageId?: string; dataUrl?: string };
/** Body expected back from the embed API. */
export type EmbeddingResponse = { vector: number[] };

/**
 * Runtime guard for embed API responses: passes only when `x` has a `vector`
 * property that is an array containing nothing but numbers.
 */
function isEmbeddingResponse(x: unknown): x is EmbeddingResponse {
  const candidate = (x as { vector?: unknown } | null | undefined)?.vector;
  if (!Array.isArray(candidate)) {
    return false;
  }
  const hasNonNumber = candidate.some((item) => typeof item !== "number");
  return !hasNonNumber;
}

// /**
// * loading a CLIP embedding pipeline for image search (example).
// * this will make it easy to use other models
// */
// export async function loadEmbeddingPipeline() {
// if (!imageEmbedder) {
// // Using CLIP for image embeddings, just as an example
// imageEmbedder = await pipeline(
// "feature-extraction",
// "Xenova/clip-vit-base-patch32"
// );
// }
// return imageEmbedder;
// }
/**
 * Gets an embedding for an image by its ID (e.g., S3 key) via the embed API.
 *
 * @param id - image identifier sent as `imageId` in the request body
 * @returns the `vector` array from the response
 * @throws Error when the HTTP status is not 2xx, when the body is not valid
 *   JSON, or when the JSON lacks a numeric `vector` array
 */
export async function getEmbeddingForImageId(id: string): Promise<number[]> {
  const payload: EmbeddingRequest = { imageId: id };

  // NOTE(review): a relative URL presumably only resolves in a browser context —
  // confirm this helper is never called from server-side code.
  const res = await fetch("/api/model/embed", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(payload),
    cache: "no-store",
  });

  // Surface HTTP failures with their status instead of a JSON parse error or a vague shape error.
  if (!res.ok) {
    const detail = await res.text().catch(() => "");
    throw new Error(`Embed request failed (${res.status}) ${detail}`.trim());
  }

  const json: unknown = await res.json();
  if (!isEmbeddingResponse(json)) throw new Error("Bad embed response (imageId)");
  return json.vector;
}

/**
 * Gets an embedding for raw image data (e.g., a data URL) via the embed API.
 *
 * @param dataUrl - image payload sent as `dataUrl` in the request body
 * @returns the `vector` array from the response
 * @throws Error when the HTTP status is not 2xx, when the body is not valid
 *   JSON, or when the JSON lacks a numeric `vector` array
 */
export async function getEmbeddingForImageData(dataUrl: string): Promise<number[]> {
  const payload: EmbeddingRequest = { dataUrl };

  // NOTE(review): a relative URL presumably only resolves in a browser context —
  // confirm this helper is never called from server-side code.
  const res = await fetch("/api/model/embed", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(payload),
    cache: "no-store",
  });

  // Surface HTTP failures with their status instead of a JSON parse error or a vague shape error.
  if (!res.ok) {
    const detail = await res.text().catch(() => "");
    throw new Error(`Embed request failed (${res.status}) ${detail}`.trim());
  }

  const json: unknown = await res.json();
  if (!isEmbeddingResponse(json)) throw new Error("Bad embed response (dataUrl)");
  return json.vector;
}
96 changes: 96 additions & 0 deletions app/lib/vectorIndex.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// FAISS-only client used by Next.js API routes
// Endpoints expected on the FAISS service: /health, /upsert, /delete, /search

// One search result; vectorSearch returns an array of these.
export type SearchHit = {
id: string; // image id (e.g., S3 key)
score: number; // similarity score (cosine via inner product)
metadata?: Record<string, unknown>; // optional free-form data associated with the id
};

// One entry of a vectorUpsert batch.
type UpsertItem = {
id: string; // identifier the vector is stored under
vector: number[]; // raw embedding; vectorUpsert requires its length to equal EMBEDDING_DIM
metadata?: Record<string, unknown>; // optional free-form data sent along with the vector
};

// Config
/**
 * Parses a numeric setting from an environment string.
 * Returns `fallback` unless the value converts to a finite number greater than zero.
 */
function num(env: string | undefined, fallback: number): number {
  const parsed = Number(env);
  if (!Number.isFinite(parsed)) {
    return fallback;
  }
  if (parsed <= 0) {
    return fallback;
  }
  return parsed;
}

// Reading from env
// FAISS_URL: base URL of the FAISS service; falls back to the local default when
// the variable is unset or empty, and trailing slashes are stripped so paths can be appended.
const FAISS_URL = (process.env.FAISS_URL || "http://127.0.0.1:8000").replace(/\/+$/, "");
// EMBEDDING_DIM: vector length the helpers below pre-validate against (default 384).
const EMBEDDING_DIM = num(process.env.EMBEDDING_DIM, 384);
// MAX_TOPK: upper bound vectorSearch applies to topK (default 100).
const MAX_TOPK = num(process.env.MAX_TOPK, 100);

/**
 * Passes a successful response through untouched; for any non-2xx response it
 * throws an Error of the form "<label> failed (<status>) <body text>".
 * A body that cannot be read is treated as empty.
 */
async function ok(res: Response, label: string) {
  if (res.ok) {
    return res;
  }
  const bodyText = await res.text().catch(() => "");
  const message = `${label} failed (${res.status}) ${bodyText}`.trim();
  throw new Error(message);
}

/**
 * Sends a batch of vectors to the FAISS service's /upsert endpoint.
 * An empty or missing batch is a no-op. Every vector must be an array of
 * exactly EMBEDDING_DIM entries; the first offender aborts the call before
 * any request is made.
 *
 * @throws Error on a dimension mismatch or a non-2xx response from the service
 */
export async function vectorUpsert(items: UpsertItem[]): Promise<void> {
  if (!items?.length) {
    return;
  }

  // Check dimensions locally so the service never has to answer with a 422.
  for (const { vector } of items) {
    const dimOk = Array.isArray(vector) && vector.length === EMBEDDING_DIM;
    if (!dimOk) {
      throw new Error(`vectorUpsert: vector dim ${vector?.length} != EMBEDDING_DIM ${EMBEDDING_DIM}`);
    }
  }

  const endpoint = `${FAISS_URL}/upsert`;
  const payload = JSON.stringify({ items });
  const response = await fetch(endpoint, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: payload,
    cache: "no-store",
  });
  await ok(response, "faiss upsert");
}

/**
 * Searches the FAISS service for the nearest neighbours of a query vector.
 *
 * @param query - embedding with exactly EMBEDDING_DIM entries; an empty or
 *   non-array query short-circuits to `[]` without contacting the service
 * @param topK - desired number of hits; non-numeric values fall back to 10 and
 *   the value is floored and clamped to 1..MAX_TOPK
 * @returns hits in the order the service returned them
 * @throws Error on a dimension mismatch, a non-2xx response, or a response
 *   body that is not a JSON array
 */
export async function vectorSearch(query: number[], topK = 10): Promise<SearchHit[]> {
  if (!Array.isArray(query) || query.length === 0) return [];

  // Pre-validate dims to avoid server 422s
  if (query.length !== EMBEDDING_DIM) {
    throw new Error(`vectorSearch: query dim ${query.length} != EMBEDDING_DIM ${EMBEDDING_DIM}`);
  }

  // Clamp to 1..MAX_TOPK and force an integer: the service declares `top_k` as an
  // int, so a fractional value (from topK or from MAX_TOPK) must not be forwarded.
  const requested = Number(topK) || 10;
  const k = Math.max(1, Math.floor(Math.min(requested, MAX_TOPK)));

  const r = await fetch(`${FAISS_URL}/search`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ query, top_k: k }),
    cache: "no-store",
  });
  await ok(r, "faiss search");

  // Check the outer shape before casting so an unexpected payload fails here, not in a caller.
  const hits: unknown = await r.json();
  if (!Array.isArray(hits)) {
    throw new Error("faiss search: response is not an array");
  }
  return hits as SearchHit[]; // [{ id, score, metadata }]
}

/**
 * Asks the FAISS service to delete the given ids via its /delete endpoint.
 * An empty or missing list is a no-op.
 *
 * @throws Error on a non-2xx response from the service
 */
export async function vectorDelete(ids: string[]): Promise<void> {
  if (!ids?.length) {
    return;
  }

  const endpoint = `${FAISS_URL}/delete`;
  const payload = JSON.stringify({ ids });
  const response = await fetch(endpoint, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: payload,
    cache: "no-store",
  });
  await ok(response, "faiss delete");
}

/**
 * Readiness probe: resolves to true when GET /health on the FAISS service
 * answers with a 2xx status, and to false on any other status or any error.
 */
export async function vectorPing(): Promise<boolean> {
  const endpoint = `${FAISS_URL}/health`;
  let healthy = false;
  try {
    const response = await fetch(endpoint, { cache: "no-store" });
    healthy = response.ok;
  } catch {
    healthy = false;
  }
  return healthy;
}
Binary file added faiss_service/__pycache__/app.cpython-311.pyc
Binary file not shown.
154 changes: 154 additions & 0 deletions faiss_service/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Small FastAPI app that wraps a FAISS index
# - Stores vectors in memory (ok for local dev)
# - Uses cosine similarity by normalizing vectors and IndexFlatIP
# - Endpoints: /health, /upsert, /delete, /compact, /search

import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any, Set
import numpy as np
import faiss

# FastAPI application object; the route decorators below register handlers on it.
app = FastAPI()

# Vector dimensionality and result cap, read once at import time from the
# environment (defaults: 384 and 100).
DIM = int(os.getenv("EMBEDDING_DIM", "384"))
MAX_TOPK = int(os.getenv("MAX_TOPK", "100"))
index = faiss.IndexFlatIP(DIM) # Inner product -> cosine if vectors are L2-normalized
id_map: List[str] = [] # Keeps ids parallel to FAISS rows
meta_map: Dict[str, Dict[str, Any]] = {} # id -> metadata
# Ids marked as deleted by /delete; /search skips them and /compact drops their rows.
tombstones: Set[str] = set()

# Payload schemas
# One vector to store: caller-chosen id, raw embedding values, optional metadata.
class UpsertItem(BaseModel):
    id: str
    vector: List[float]
    metadata: Dict[str, Any] | None = None

# Request body of /upsert: the batch of items to store.
class UpsertPayload(BaseModel):
    items: List[UpsertItem]

# Request body of /delete: the ids to mark as deleted.
class DeletePayload(BaseModel):
    ids: List[str]

# Request body of /search: the query embedding and how many hits are wanted (default 10).
class SearchPayload(BaseModel):
    query: List[float]
    top_k: int = 10

# Helpers
def l2_normalize(v: np.ndarray) -> np.ndarray:
    # Scale v to unit Euclidean length; a vector whose norm is not positive
    # is handed back unchanged (same object) to avoid dividing by zero.
    magnitude = np.linalg.norm(v)
    if magnitude > 0:
        return v / magnitude
    return v

def normalize_matrix(X: np.ndarray) -> np.ndarray:
    # Scale every row of X to unit length. Rows with zero norm are divided by 1
    # instead, so they come back as all zeros rather than NaN.
    row_norms = np.linalg.norm(X, axis=1, keepdims=True)
    zero_rows = row_norms == 0
    row_norms[zero_rows] = 1.0
    return X / row_norms

# Routes
# Health/diagnostics endpoint: reports the number of stored rows, the configured
# dimension and top-k cap, and how many ids are currently tombstoned.
@app.get("/health")
def health():
    stats = {"ok": True}
    stats["count"] = int(index.ntotal)
    stats["dim"] = DIM
    stats["max_topk"] = MAX_TOPK
    stats["tombstones"] = len(tombstones)
    return stats

# Adds or replaces a batch of vectors.
# - Every item is validated before any state changes, so a 422 leaves the
#   index, id_map and meta_map untouched.
# - If an id occurs more than once in the batch, the last occurrence wins.
# - Rows already stored under an incoming id are removed first, so an id never
#   maps to more than one row.
# - A previously tombstoned id becomes live again when it is re-upserted.
# Returns {"added": n}, where n is the number of distinct ids stored.
# Raises HTTPException(422) when any vector's length differs from DIM.
@app.post("/upsert")
def upsert(p: UpsertPayload):
    global index, id_map
    if not p.items:
        return {"added": 0}

    # Validate the whole batch up front so a bad item cannot leave state half-updated.
    for it in p.items:
        if len(it.vector) != DIM:
            raise HTTPException(
                status_code=422,
                detail=f"Vector for id '{it.id}' has dim={len(it.vector)} but EMBEDDING_DIM={DIM}",
            )

    # Dict insertion keeps first-seen order while later duplicates overwrite the value.
    latest = {it.id: it for it in p.items}

    # Drop rows stored under the incoming ids. Removing from a flat index shifts
    # the remaining rows down in order, so filtering id_map the same way keeps
    # both structures aligned.
    stale_rows = [i for i, _id in enumerate(id_map) if _id in latest]
    if stale_rows:
        index.remove_ids(np.asarray(stale_rows, dtype="int64"))
        id_map = [_id for _id in id_map if _id not in latest]

    # Build the matrix of new vectors, each row normalized for cosine similarity.
    rows = [l2_normalize(np.asarray(it.vector, dtype="float32")) for it in latest.values()]
    X = np.vstack(rows).astype("float32")
    index.add(X)  # Appending to FAISS

    for it in latest.values():
        id_map.append(it.id)
        meta_map[it.id] = it.metadata or {}
        tombstones.discard(it.id)  # a fresh upsert revives a soft-deleted id
    return {"added": len(latest)}

# Soft delete: marks ids as tombstoned; the physical rebuild happens in /compact.
# Only ids that are currently stored are tombstoned. An unknown id is ignored,
# so it cannot hide a vector upserted under that id later, and it is not counted.
# Returns {"deleted": n}, where n is the number of ids newly tombstoned by this call.
@app.post("/delete")
def delete(p: DeletePayload):
    if not p.ids:
        return {"deleted": 0}
    known = set(id_map)
    # A set also collapses ids repeated within the request.
    newly_deleted = {_id for _id in p.ids if _id in known and _id not in tombstones}
    tombstones.update(newly_deleted)
    return {"deleted": len(newly_deleted)}

@app.post("/compact")
def compact():
    """
    Physically rebuilds the index, dropping the rows of tombstoned ids.
    Run this occasionally.

    Returns {"compacted": rows actually removed, "remaining": rows left}.
    Tombstones are always cleared, even when the index is empty or none of
    them matched a stored row.
    """
    global index, id_map
    if not tombstones:
        return {"compacted": 0, "remaining": int(index.ntotal)}

    total_before = int(index.ntotal)
    keep_idx = [i for i, _id in enumerate(id_map) if _id not in tombstones]
    # Count rows rather than tombstoned ids: an id may match no row at all.
    removed = total_before - len(keep_idx)

    if removed > 0:
        # Reconstruct the stored vectors and copy the survivors into a fresh index.
        X = index.reconstruct_n(0, total_before)  # (N, DIM) float32
        X_keep = X[keep_idx].astype("float32") if keep_idx else np.empty((0, DIM), dtype="float32")

        new_index = faiss.IndexFlatIP(DIM)
        if X_keep.shape[0] > 0:
            new_index.add(X_keep)

        id_map = [id_map[i] for i in keep_idx]
        index = new_index

    # Dropping metadata for removed ids
    for _id in list(tombstones):
        meta_map.pop(_id, None)
    tombstones.clear()
    return {"compacted": removed, "remaining": int(index.ntotal)}

# Returns up to top_k live (non-tombstoned) nearest neighbours of the query,
# best score first, as a list of {"id", "score", "metadata"} dicts.
# - top_k is clamped to 1..MAX_TOPK.
# - An empty index yields [].
# Raises HTTPException(422) when the query length differs from DIM.
@app.post("/search")
def search(p: SearchPayload):
    if index.ntotal == 0:
        return []

    if len(p.query) != DIM:
        raise HTTPException(
            status_code=422,
            detail=f"Query vector has dim={len(p.query)} but EMBEDDING_DIM={DIM}"
        )

    top_k = max(1, min(int(p.top_k), MAX_TOPK))

    q = np.asarray(p.query, dtype="float32")
    q = l2_normalize(q).reshape(1, -1)

    # Over-fetch by the number of tombstoned rows: in the worst case they all
    # rank first, and this still leaves top_k live candidates when that many
    # exist. Capping the fetch at MAX_TOPK would leave no headroom for large top_k.
    hidden_rows = sum(1 for _id in id_map if _id in tombstones) if tombstones else 0
    k_raw = min(int(index.ntotal), top_k + hidden_rows)
    D, I = index.search(q, k_raw)  # D: scores, I: indices

    out = []
    for score, idx in zip(D[0], I[0]):
        if idx < 0:  # FAISS pads missing results with -1
            continue
        _id = id_map[int(idx)]
        if _id in tombstones:
            continue
        out.append({
            "id": _id,
            "score": float(score),
            "metadata": meta_map.get(_id, {})
        })
        if len(out) >= top_k:
            break
    return out
4 changes: 4 additions & 0 deletions faiss_service/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
fastapi==0.112.0
uvicorn==0.30.0
faiss-cpu==1.8.0
numpy==1.26.4