Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions apps/memos-local-openclaw/src/embedding/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import { embedMistral } from "./providers/mistral";
import { embedLocal } from "./local";
import { modelHealth } from "../ingest/providers";

/**
 * Kind of text being embedded. `resolveInputType` maps "query" to the
 * configured `queryInputType` and "document" to `documentInputType`,
 * each falling back to the shared `inputType`.
 */
type EmbeddingInputKind = "document" | "query";

export class Embedder {
constructor(
private cfg: EmbeddingConfig | undefined,
Expand Down Expand Up @@ -43,13 +45,17 @@ export class Embedder {
if (this.provider === "cohere" && this.cfg) {
return embedCohereQuery(text, this.cfg, this.log);
}
const vecs = await this.embedBatch([text]);
const vecs = await this.embedBatch([text], "query");
return vecs[0];
}

private async embedBatch(texts: string[]): Promise<number[][]> {
private async embedBatch(
texts: string[],
inputKind: EmbeddingInputKind = "document",
): Promise<number[][]> {
const provider = this.provider;
const cfg = this.cfg;
const inputType = this.resolveInputType(inputKind);

const modelInfo = `${provider}/${cfg?.model ?? "default"}`;
try {
Expand All @@ -61,7 +67,9 @@ export class Embedder {
case "zhipu":
case "siliconflow":
case "bailian":
result = await embedOpenAI(texts, cfg!, this.log); break;
result = await embedOpenAI(texts, cfg!, this.log, inputType); break;
case "openclaw":
result = await this.embedOpenClaw(texts, inputType); break;
case "gemini":
result = await embedGemini(texts, cfg!, this.log); break;
case "cohere":
Expand All @@ -86,7 +94,13 @@ export class Embedder {
}
}

private async embedOpenClaw(texts: string[]): Promise<number[][]> {
/**
 * Picks the configured input type for the given kind of embedding input.
 *
 * A kind-specific setting (`queryInputType` for queries, `documentInputType`
 * for everything else) takes precedence; when it is null or undefined the
 * shared `inputType` is used instead.
 *
 * @param inputKind Whether the text is a search query or a document.
 * @returns The configured input type, or `undefined` when there is no config
 *   or none of the relevant fields are set.
 */
private resolveInputType(inputKind: EmbeddingInputKind): string | undefined {
  const config = this.cfg;
  if (!config) {
    return undefined;
  }

  const kindSpecific =
    inputKind === "query" ? config.queryInputType : config.documentInputType;
  return kindSpecific ?? config.inputType;
}

private async embedOpenClaw(texts: string[], inputType?: string): Promise<number[][]> {
if (!this.openclawAPI) {
throw new Error(
"OpenClaw API not available. Ensure sharing.capabilities.hostEmbedding is enabled in config."
Expand All @@ -97,6 +111,7 @@ export class Embedder {
const response = await this.openclawAPI.embed({
texts,
model: this.cfg?.model,
inputType,
});

return response.embeddings;
Expand Down
7 changes: 6 additions & 1 deletion apps/memos-local-openclaw/src/embedding/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export async function embedOpenAI(
texts: string[],
cfg: EmbeddingConfig,
log: Logger,
inputType?: string,
): Promise<number[][]> {
const endpoint = normalizeEmbeddingEndpoint(cfg.endpoint ?? "https://api.openai.com/v1/embeddings");
const model = cfg.model ?? "text-embedding-3-small";
Expand All @@ -16,7 +17,11 @@ export async function embedOpenAI(
const resp = await fetch(endpoint, {
method: "POST",
headers,
body: JSON.stringify({ input: texts, model }),
body: JSON.stringify({
input: texts,
model,
...(inputType ? { input_type: inputType } : {}),
}),
signal: AbortSignal.timeout(cfg.timeoutMs ?? 30_000),
});

Expand Down
7 changes: 6 additions & 1 deletion apps/memos-local-openclaw/src/openclaw-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import type { Logger, OpenClawAPI } from "./types";
/** Shape of an embed request: the texts plus an optional model and input type. */
export interface OpenClawEmbedRequest {
/** Texts to embed; sent as the `input` field of the JSON request body. */
texts: string[];
/** Optional model name. NOTE(review): how this maps to the `model` sent on the wire is not visible here — confirm. */
model?: string;
/** Optional input type; when truthy it is sent as `input_type` in the JSON request body, otherwise the field is omitted. */
inputType?: string;
}

export interface OpenClawEmbedResponse {
Expand Down Expand Up @@ -98,7 +99,11 @@ export class OpenClawAPIClient implements OpenClawAPI {
const resp = await fetch(endpoint, {
method: "POST",
headers: buildHeaders(provider),
body: JSON.stringify({ input: request.texts, model }),
body: JSON.stringify({
input: request.texts,
model,
...(request.inputType ? { input_type: request.inputType } : {}),
}),
signal: AbortSignal.timeout(30_000),
});

Expand Down
16 changes: 14 additions & 2 deletions apps/memos-local-openclaw/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ export interface EmbeddingConfig extends ProviderConfig {
batchSize?: number;
dimensions?: number;
retry?: number;
inputType?: string;
queryInputType?: string;
documentInputType?: string;
}

// ─── Skill ───
Expand Down Expand Up @@ -373,8 +376,17 @@ export interface PluginContext {
}

export interface OpenClawAPI {
embed(request: { texts: string[]; model?: string }): Promise<{ embeddings: number[][]; dimensions: number }>;
complete(request: { prompt: string; maxTokens?: number; temperature?: number; model?: string }): Promise<{ text: string }>;
embed(request: {
texts: string[];
model?: string;
inputType?: string;
}): Promise<{ embeddings: number[][]; dimensions: number }>;
complete(request: {
prompt: string;
maxTokens?: number;
temperature?: number;
model?: string;
}): Promise<{ text: string }>;
}

export interface Logger {
Expand Down
24 changes: 22 additions & 2 deletions apps/memos-local-openclaw/src/viewer/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3830,14 +3830,34 @@ export class ViewerServer {
}
return vec.length;
}
const resp = await fetch(embUrl, {
const requestBody = { input: ["test embedding vector"], model: model || "text-embedding-3-small" };
let resp = await fetch(embUrl, {
method: "POST",
headers,
body: JSON.stringify({ input: ["test embedding vector"], model: model || "text-embedding-3-small" }),
body: JSON.stringify(requestBody),
signal: AbortSignal.timeout(15_000),
});
if (!resp.ok) {
const txt = await resp.text();
if (/input[_ -]?type/i.test(txt) && /required/i.test(txt)) {
resp = await fetch(embUrl, {
method: "POST",
headers,
body: JSON.stringify({ ...requestBody, input_type: "query" }),
signal: AbortSignal.timeout(15_000),
});
if (resp.ok) {
const json = await resp.json() as any;
const data = json?.data;
const vec = Array.isArray(data) && data.length > 0 ? data[0]?.embedding : undefined;
if (!Array.isArray(vec) || vec.length === 0) {
throw new Error(
`API returned empty embedding vector (got ${JSON.stringify(vec)?.slice(0, 100)})`,
);
}
return vec.length;
}
}
throw new Error(`${resp.status}: ${txt}`);
}
const json = await resp.json() as any;
Expand Down
69 changes: 69 additions & 0 deletions apps/memos-local-openclaw/tests/embedding-input-type.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import { afterEach, describe, expect, it, vi } from "vitest";

import { Embedder } from "../src/embedding";
import type { EmbeddingConfig, Logger } from "../src/types";

/** Logger stub for tests: every level accepts the call and discards it. */
const noopLog: Logger = {
  debug() {},
  info() {},
  warn() {},
  error() {},
};

/**
 * Replaces the global `fetch` with a stub that always reports success and
 * answers with a single three-element embedding.
 *
 * @returns A live array that receives the parsed JSON body of every request
 *   sent through the stub, in call order.
 */
function mockEmbeddingFetch() {
  const capturedBodies: Array<Record<string, unknown>> = [];

  vi.stubGlobal(
    "fetch",
    vi.fn(async (_url: string, init?: RequestInit) => {
      // Record what the caller sent before answering.
      const rawBody = String(init?.body);
      capturedBodies.push(JSON.parse(rawBody));

      const cannedResponse = {
        ok: true,
        json: async () => ({ data: [{ embedding: [0.1, 0.2, 0.3] }] }),
      };
      return cannedResponse as Response;
    }),
  );

  return capturedBodies;
}

/**
 * Builds an OpenAI-compatible embedding config for the tests.
 *
 * @param overrides Fields layered on top of the defaults; an override wins
 *   over the default with the same key.
 * @returns A new config object combining the defaults and the overrides.
 */
function openAiConfig(overrides: Partial<EmbeddingConfig> = {}): EmbeddingConfig {
  const defaults: EmbeddingConfig = {
    provider: "openai_compatible",
    endpoint: "https://embeddings.example.test/v1",
    apiKey: "test-key",
    model: "asymmetric-model",
  };
  return { ...defaults, ...overrides };
}

// Remove the global `fetch` stub after every test so cases stay isolated.
afterEach(() => {
  vi.unstubAllGlobals();
});

describe("embedding input_type routing", () => {
  it("uses documentInputType for document embeddings and queryInputType for query embeddings", async () => {
    const requests = mockEmbeddingFetch();
    const config = openAiConfig({
      inputType: "passage",
      documentInputType: "document",
      queryInputType: "query",
    });
    const embedder = new Embedder(config, noopLog);

    // First request comes from the document path, second from the query path.
    await embedder.embed(["stored memory"]);
    await embedder.embedQuery("search terms");

    const [documentRequest, queryRequest] = requests;
    expect(documentRequest).toMatchObject({ input_type: "document" });
    expect(queryRequest).toMatchObject({ input_type: "query" });
  });

  it("falls back to inputType when a specific query or document input type is absent", async () => {
    const requests = mockEmbeddingFetch();
    const config = openAiConfig({ inputType: "passage" });
    const embedder = new Embedder(config, noopLog);

    await embedder.embed(["stored memory"]);
    await embedder.embedQuery("search terms");

    // Only the shared setting is configured, so both requests carry it.
    const [documentRequest, queryRequest] = requests;
    expect(documentRequest).toMatchObject({ input_type: "passage" });
    expect(queryRequest).toMatchObject({ input_type: "passage" });
  });
});