Skip to content

Commit

Permalink
Merge pull request #24 from epsilla-cloud/hybrid_search
Browse files Browse the repository at this point in the history
Support hybrid search
  • Loading branch information
richard-epsilla authored Feb 19, 2024
2 parents 26b5094 + 23fd987 commit 18dc8f9
Show file tree
Hide file tree
Showing 6 changed files with 347 additions and 4 deletions.
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "epsillajs",
"version": "0.3.3",
"version": "0.3.4",
"description": "A JS library to connect Epsilla vector database",
"main": "dist/index.js",
"types": "dist/index.d.ts",
Expand Down
6 changes: 6 additions & 0 deletions src/cloud.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import axios, { AxiosError } from 'axios';
import { DeleteRecordsConfig, EpsillaResponse, Index, PreviewConfig, QueryConfig, TableField } from './models';

import SearchEngine from './searchengine';

export interface CloudClientConfig {
projectID: string;
apiKey: string;
Expand Down Expand Up @@ -198,4 +200,8 @@ export class VectorDB {
return (err as AxiosError).response?.data as EpsillaResponse;
}
}

asSearchEngine(): SearchEngine {
return new SearchEngine(this);
}
}
29 changes: 28 additions & 1 deletion src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,31 @@ export interface EpsillaQueryResult {
export interface EpsillaResponse extends EpsillaBaseResponse {
result?: EpsillaQueryResult[] | string[];
time?: number;
}
}

export interface RetrieverConfig {
table: string;
primaryKeyField?: string;
queryIndex?: string;
queryField?: string;
queryVector?: number[] | SparseVector;
limit: number;
response?: string[];
filter?: string;
}

export interface RerankerConfig {
weights?: number[];
scaleRanges?: number[][];
k: number;
limit?: number;
}

export interface SearchEngineCandidate extends EpsillaQueryResult {
'@id': string | number;
'@distance': number;
}

export interface Reranker {
rerank(candidates: SearchEngineCandidate[][]): Promise<SearchEngineCandidate[]>;
}
304 changes: 304 additions & 0 deletions src/searchengine.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
// VectorRetriever.ts
import {
QueryPayload,
SearchEngineCandidate,
Reranker,
RetrieverConfig,
RerankerConfig,
} from './models';
import EpsillaDB from './vectordb';
import { VectorDB } from './cloud';

export class VectorRetriever {
private dbClient: EpsillaDB | VectorDB;
private table: string;
private primaryKeyField: string;
private queryIndex?: string;
private queryField?: string;
private queryVector?: any; // Adjust based on your vector type
private response?: string[];
private limit: number;
private filter: string;

constructor(
dbClient: EpsillaDB | VectorDB,
table: string,
primaryKeyField?: string,
queryIndex?: string,
queryField?: string,
queryVector?: any,
response?: string[],
limit: number = 2,
filter: string = '',
) {
this.dbClient = dbClient;
this.table = table;
this.primaryKeyField = primaryKeyField || 'ID';
this.queryIndex = queryIndex;
this.queryField = queryField;
this.queryVector = queryVector;
this.response = response;
this.limit = limit;
this.filter = filter;
}

async retrieve(query: string): Promise<SearchEngineCandidate[]> {
const queryPayload: QueryPayload = {
table: this.table,
query: query,
queryIndex: this.queryIndex,
queryField: this.queryField,
queryVector: this.queryVector,
response: this.response,
limit: this.limit,
filter: this.filter,
withDistance: true,
};
const response = await this.dbClient.query(this.table, queryPayload);
if (response instanceof Error) {
throw new Error(`Failed to retrieve data from table ${this.table}: ${response.message || 'Unknown error'}`);
} else {
// Add @id from the table to each record based on the primaryKeyField
return response.result?.map((record: any) => {
if (!(this.primaryKeyField in record)) {
throw new Error(`Primary key field ${this.primaryKeyField} not found in the response from table ${this.table}`);
}
return { ...record, '@id': record[this.primaryKeyField] };
}) || [];
}
}
}

export class RRFReRanker implements Reranker {
private weights: number[];
private k: number;
private limit?: number;

constructor(weights: number[] = [], k: number = 50, limit?: number) {
this.weights = weights;
this.k = k;
this.limit = limit;
}

async rerank(candidates: SearchEngineCandidate[][]): Promise<SearchEngineCandidate[]> {
// Initialize weights if not provided
if (!this.weights.length) {
this.weights = Array(candidates.length).fill(1);
}

// Validate weights length
if (this.weights.length !== candidates.length) {
throw new Error('The length of weights should be equal to the number of candidate lists');
}

// Calculate RRF scores for each candidate
const rrfScores: { [id: string]: { candidate: SearchEngineCandidate; score: number } } = {};

candidates.forEach((candidateList, i) => {
const weight = this.weights[i];
candidateList.forEach((candidate, rank) => {
const rrfScore = weight / (this.k + rank + 1); // +1 because rank starts at 0 in JS/TS
const id = String(candidate['@id']);
if (rrfScores[id]) {
rrfScores[id].score += rrfScore;
} else {
rrfScores[id] = { candidate, score: rrfScore };
}
});
});

// Sort candidates based on aggregated RRF score
let sortedCandidates = Object.values(rrfScores).sort((a, b) => b.score - a.score).map((item) => item.candidate);

// Apply the limit to the final list if specified
if (this.limit) {
sortedCandidates = sortedCandidates.slice(0, this.limit);
}

return sortedCandidates;
}
}

export class RelativeScoreFusionReranker implements Reranker {
private limit?: number;

constructor(limit?: number) {
this.limit = limit;
}

private normalizeDistances(candidates: SearchEngineCandidate[]): SearchEngineCandidate[] {
if (candidates.length < 2) {
return candidates.map(candidate => ({ ...candidate, normalizedScore: 1 }));
}

const distances = candidates.map(candidate => candidate['@distance'] || 1);
const maxDistance = Math.max(...distances);
const minDistance = Math.min(...distances);

if (maxDistance === minDistance) {
return candidates.map(candidate => ({ ...candidate, normalizedScore: 1 }));
}

return candidates.map(candidate => {
const normalizedScore = ((candidate['@distance'] || maxDistance) - minDistance) / (maxDistance - minDistance);
return { ...candidate, normalizedScore: 1 - normalizedScore }; // Invert so that smaller distances get higher scores
});
}

async rerank(candidatesLists: SearchEngineCandidate[][]): Promise<SearchEngineCandidate[]> {
const normalizedLists = candidatesLists.map(list => this.normalizeDistances(list));

const aggregatedScores: { [id: string]: { candidate: SearchEngineCandidate; score: number } } = {};

normalizedLists.forEach(list => {
list.forEach(({ '@id': id, normalizedScore }) => {
if (aggregatedScores[id]) {
aggregatedScores[id].score += normalizedScore as number;
} else {
aggregatedScores[id] = {
candidate: list.find(candidate => candidate['@id'] === id) as SearchEngineCandidate,
score: normalizedScore as number
};
}
});
});

let sortedCandidates = Object.values(aggregatedScores)
.sort((a, b) => b.score - a.score)
.map(item => item.candidate);

if (this.limit) {
sortedCandidates = sortedCandidates.slice(0, this.limit);
}

return sortedCandidates;
}
}

export class DistributionBasedScoreFusionReranker implements Reranker {
private scaleRanges: number[][];
private limit?: number;

constructor(scaleRanges: number[][] = [], limit?: number) {
this.scaleRanges = scaleRanges;
this.limit = limit;
}

private normalizeDistances(scaleRange: number[], candidates: SearchEngineCandidate[]): SearchEngineCandidate[] {
const [minScale, maxScale] = scaleRange;

return candidates.map(candidate => {
let normalizedScore = 0;
if (maxScale !== minScale) { // Avoid division by zero
normalizedScore = (candidate['@distance'] - minScale) / (maxScale - minScale);
normalizedScore = Math.max(0, Math.min(1, normalizedScore)); // Clamp between 0 and 1
}
return { ...candidate, normalizedScore: 1 - normalizedScore }; // Invert score so smaller distances score higher
});
}

async rerank(candidatesLists: SearchEngineCandidate[][]): Promise<SearchEngineCandidate[]> {
if (this.scaleRanges.length !== candidatesLists.length) {
throw new Error("The length of scaleRanges should be equal to the number of candidates lists.");
}

const normalizedLists = candidatesLists.map((list, index) =>
this.normalizeDistances(this.scaleRanges[index], list)
);

const aggregatedScores: { [id: string]: { candidate: SearchEngineCandidate; score: number } } = {};

normalizedLists.forEach(list => {
list.forEach(({ '@id': id, normalizedScore }) => {
if (aggregatedScores[id]) {
aggregatedScores[id].score += normalizedScore as number;
} else {
const candidate = list.find(candidate => candidate['@id'] === id) as SearchEngineCandidate;
aggregatedScores[id] = { candidate, score: normalizedScore as number };
}
});
});

let sortedCandidates = Object.values(aggregatedScores)
.sort((a, b) => b.score - a.score)
.map(item => item.candidate);

if (this.limit !== undefined) {
sortedCandidates = sortedCandidates.slice(0, this.limit);
}

return sortedCandidates;
}
}

class SearchEngine {
private dbClient: EpsillaDB | VectorDB;
private retrievers: VectorRetriever[] = [];
private reranker?: Reranker;

constructor(dbClient: EpsillaDB | VectorDB) {
this.dbClient = dbClient;
}

addRetriever(config: RetrieverConfig): SearchEngine {
this.reranker = undefined; // Reset reranker when a new retriever is added
this.retrievers.push(
new VectorRetriever(
this.dbClient,
config.table,
config.primaryKeyField,
config.queryIndex,
config.queryField,
config.queryVector,
config.response,
config.limit,
config.filter
)
);
return this;
}

setReranker(type: string, config?: RerankerConfig): SearchEngine {
switch (type) {
case "rrf":
case "reciprocal_rank_fusion":
this.reranker = new RRFReRanker(config?.weights, config?.k || 50, config?.limit);
break;
case "rsf":
case "relative_score_fusion":
this.reranker = new RelativeScoreFusionReranker(config?.limit);
break;
case "dbsf":
case "distribution_based_score_fusion":
this.reranker = new DistributionBasedScoreFusionReranker(config?.scaleRanges, config?.limit);
break;
default:
throw new Error("Invalid reranker type: " + type);
}
return this;
}

async search(query: string): Promise<SearchEngineCandidate[]> {
if (!this.retrievers.length) {
throw new Error("No retriever added to the search engine");
}
if (this.retrievers.length > 1 && !this.reranker) {
throw new Error("More than one retriever added to the search engine, but no reranker is set");
}

// Retrieve candidates concurrently
const retrievePromises = this.retrievers.map(retriever => retriever.retrieve(query));
let candidates = await Promise.all(retrievePromises);

// Rerank candidates if reranker is set
if (this.reranker) {
candidates = [await this.reranker.rerank(candidates)];
}

// Flatten the array of arrays to a single array of candidates
return candidates.flat();
}
}

export default SearchEngine;

Loading

0 comments on commit 18dc8f9

Please sign in to comment.