From 6965a2cc9d78e680ea2a46a1266a97a96a40b25e Mon Sep 17 00:00:00 2001
From: Jake
Date: Sat, 7 Feb 2026 10:09:32 +1300
Subject: [PATCH] feat(memory): native Voyage AI support (#7078)

* feat(memory): add native Voyage AI embedding support with batching

  Cherry-picked from PR #2519, resolved conflict in memory-search.ts
  (hasRemote -> hasRemoteConfig rename + added voyage provider)

* fix(memory): optimize voyage batch memory usage with streaming and
  deduplicate code

  Cherry-picked from PR #2519. Fixed lint error: changed
  this.runWithConcurrency to use imported runWithConcurrency function
  after extraction to internal.ts
---
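Reviewer note (free-form notes area; ignored by git am): a minimal sketch of
enabling the new provider through the memorySearch config, using the
MemorySearchConfig type this patch extends. The import path and API key value
are placeholders; leaving remote.apiKey unset falls back to the key that
model-auth resolves for the "voyage" provider, and remote.batch mirrors the
remote.batch.* settings read by mergeConfig.

    import type { MemorySearchConfig } from "./src/config/types.tools.js"; // illustrative path

    const memorySearch: MemorySearchConfig = {
      provider: "voyage",
      model: "voyage-4-large", // the default this patch wires in (DEFAULT_VOYAGE_MODEL)
      fallback: "none",
      remote: {
        apiKey: "<voyage-api-key>", // placeholder
        batch: { enabled: true, wait: true },
      },
    };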
 src/agents/memory-search.ts            |  15 +-
 src/config/schema.ts                   |   3 +-
 src/config/types.tools.ts              |   4 +-
 src/config/zod-schema.agent-runtime.ts |  12 +-
 src/memory/batch-voyage.test.ts        | 170 ++++++++++++
 src/memory/batch-voyage.ts             | 363 +++++++++++++++++++++++++
 src/memory/embeddings-voyage.test.ts   | 100 +++++++
 src/memory/embeddings-voyage.ts        |  86 ++++++
 src/memory/embeddings.ts               |  22 +-
 src/memory/internal.ts                 |  30 ++
 src/memory/manager.ts                  | 132 ++++++---
 11 files changed, 879 insertions(+), 58 deletions(-)
 create mode 100644 src/memory/batch-voyage.test.ts
 create mode 100644 src/memory/batch-voyage.ts
 create mode 100644 src/memory/embeddings-voyage.test.ts
 create mode 100644 src/memory/embeddings-voyage.ts

diff --git a/src/agents/memory-search.ts b/src/agents/memory-search.ts
index 658771a11b..5394b640d0 100644
--- a/src/agents/memory-search.ts
+++ b/src/agents/memory-search.ts
@@ -9,7 +9,7 @@ export type ResolvedMemorySearchConfig = {
   enabled: boolean;
   sources: Array<"memory" | "sessions">;
   extraPaths: string[];
-  provider: "openai" | "local" | "gemini" | "auto";
+  provider: "openai" | "local" | "gemini" | "voyage" | "auto";
   remote?: {
     baseUrl?: string;
     apiKey?: string;
@@ -25,7 +25,7 @@ export type ResolvedMemorySearchConfig = {
   experimental: {
     sessionMemory: boolean;
   };
-  fallback: "openai" | "gemini" | "local" | "none";
+  fallback: "openai" | "gemini" | "local" | "voyage" | "none";
   model: string;
   local: {
     modelPath?: string;
@@ -72,6 +72,7 @@ export type ResolvedMemorySearchConfig = {
 
 const DEFAULT_OPENAI_MODEL = "text-embedding-3-small";
 const DEFAULT_GEMINI_MODEL = "gemini-embedding-001";
+const DEFAULT_VOYAGE_MODEL = "voyage-4-large";
 const DEFAULT_CHUNK_TOKENS = 400;
 const DEFAULT_CHUNK_OVERLAP = 80;
 const DEFAULT_WATCH_DEBOUNCE_MS = 1500;
@@ -136,7 +137,11 @@ function mergeConfig(
     defaultRemote?.headers,
   );
   const includeRemote =
-    hasRemoteConfig || provider === "openai" || provider === "gemini" || provider === "auto";
+    hasRemoteConfig ||
+    provider === "openai" ||
+    provider === "gemini" ||
+    provider === "voyage" ||
+    provider === "auto";
   const batch = {
     enabled: overrideRemote?.batch?.enabled ?? defaultRemote?.batch?.enabled ?? true,
     wait: overrideRemote?.batch?.wait ?? defaultRemote?.batch?.wait ?? true,
@@ -163,7 +168,9 @@
       ? DEFAULT_GEMINI_MODEL
       : provider === "openai"
         ? DEFAULT_OPENAI_MODEL
-        : undefined;
+        : provider === "voyage"
+          ? DEFAULT_VOYAGE_MODEL
+          : undefined;
   const model = overrides?.model ?? defaults?.model ?? modelDefault ?? "";
   const local = {
     modelPath: overrides?.local?.modelPath ?? defaults?.local?.modelPath,
diff --git a/src/config/schema.ts b/src/config/schema.ts
index 175265ac16..a9c177c824 100644
--- a/src/config/schema.ts
+++ b/src/config/schema.ts
@@ -542,7 +542,8 @@ const FIELD_HELP: Record<string, string> = {
     "Extra paths to include in memory search (directories or .md files; relative paths resolved from workspace).",
   "agents.defaults.memorySearch.experimental.sessionMemory":
     "Enable experimental session transcript indexing for memory search (default: false).",
-  "agents.defaults.memorySearch.provider": 'Embedding provider ("openai", "gemini", or "local").',
+  "agents.defaults.memorySearch.provider":
+    'Embedding provider ("openai", "gemini", "voyage", or "local").',
   "agents.defaults.memorySearch.remote.baseUrl":
     "Custom base URL for remote embeddings (OpenAI-compatible proxies or Gemini overrides).",
   "agents.defaults.memorySearch.remote.apiKey": "Custom API key for the remote embedding provider.",
diff --git a/src/config/types.tools.ts b/src/config/types.tools.ts
index b080324277..36700b6ce0 100644
--- a/src/config/types.tools.ts
+++ b/src/config/types.tools.ts
@@ -234,7 +234,7 @@ export type MemorySearchConfig = {
     sessionMemory?: boolean;
   };
   /** Embedding provider mode. */
-  provider?: "openai" | "gemini" | "local";
+  provider?: "openai" | "gemini" | "local" | "voyage";
   remote?: {
     baseUrl?: string;
     apiKey?: string;
@@ -253,7 +253,7 @@ export type MemorySearchConfig = {
     };
   };
   /** Fallback behavior when embeddings fail. */
-  fallback?: "openai" | "gemini" | "local" | "none";
+  fallback?: "openai" | "gemini" | "local" | "voyage" | "none";
   /** Embedding model id (remote) or alias (local). */
   model?: string;
   /** Local embedding settings (node-llama-cpp). */
diff --git a/src/config/zod-schema.agent-runtime.ts b/src/config/zod-schema.agent-runtime.ts
index c2e792f32f..582853ff37 100644
--- a/src/config/zod-schema.agent-runtime.ts
+++ b/src/config/zod-schema.agent-runtime.ts
@@ -318,7 +318,9 @@ export const MemorySearchSchema = z
     })
     .strict()
     .optional(),
-  provider: z.union([z.literal("openai"), z.literal("local"), z.literal("gemini")]).optional(),
+  provider: z
+    .union([z.literal("openai"), z.literal("local"), z.literal("gemini"), z.literal("voyage")])
+    .optional(),
   remote: z
     .object({
       baseUrl: z.string().optional(),
@@ -338,7 +340,13 @@
     .strict()
     .optional(),
   fallback: z
-    .union([z.literal("openai"), z.literal("gemini"), z.literal("local"), z.literal("none")])
+    .union([
+      z.literal("openai"),
+      z.literal("gemini"),
+      z.literal("local"),
+      z.literal("voyage"),
+      z.literal("none"),
+    ])
     .optional(),
   model: z.string().optional(),
   local: z
diff --git a/src/memory/batch-voyage.test.ts b/src/memory/batch-voyage.test.ts
new file mode 100644
index 0000000000..e0e757f19e
--- /dev/null
+++ b/src/memory/batch-voyage.test.ts
@@ -0,0 +1,170 @@
+import { afterEach, describe, expect, it, vi } from "vitest";
+import { ReadableStream } from "node:stream/web";
+import type { VoyageBatchOutputLine, VoyageBatchRequest } from "./batch-voyage.js";
+import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
+
+// runWithConcurrency from internal.js is simple enough to keep real (no mock needed).
+// We do need to mock retryAsync so its retry/backoff delays don't complicate the tests.
+vi.mock("../infra/retry.js", () => ({
+  retryAsync: async <T>(fn: () => Promise<T>) => fn(),
+}));
+
+describe("runVoyageEmbeddingBatches", () => {
+  afterEach(() => {
+    vi.resetAllMocks();
+    vi.unstubAllGlobals();
+  });
+
+  const mockClient: VoyageEmbeddingClient = {
+    baseUrl: "https://api.voyageai.com/v1",
+    headers: { Authorization: "Bearer test-key" },
+    model: "voyage-4-large",
+  };
+
+  const mockRequests: VoyageBatchRequest[] = [
+    { custom_id: "req-1", body: { input: "text1" } },
+    { custom_id: "req-2", body: { input: "text2" } },
+  ];
+
+  it("successfully submits batch, waits, and streams results", async () => {
+    const fetchMock = vi.fn();
+    vi.stubGlobal("fetch", fetchMock);
+
+    // Sequence of fetch calls:
+    // 1. Upload file
+    fetchMock.mockResolvedValueOnce({
+      ok: true,
+      json: async () => ({ id: "file-123" }),
+    });
+
+    // 2. Create batch
+    fetchMock.mockResolvedValueOnce({
+      ok: true,
+      json: async () => ({ id: "batch-abc", status: "pending" }),
+    });
+
+    // The create call above returned "pending", so runVoyageEmbeddingBatches
+    // enters waitForVoyageBatch, which fetches the batch status until it
+    // reaches a terminal state. The next mocked response completes the
+    // batch on the first poll.
+
+    // 3. Poll status (completed)
+    fetchMock.mockResolvedValueOnce({
+      ok: true,
+      json: async () => ({
+        id: "batch-abc",
+        status: "completed",
+        output_file_id: "file-out-999",
+      }),
+    });
+
+    // 4. Download content (Streaming)
+    const outputLines: VoyageBatchOutputLine[] = [
+      {
+        custom_id: "req-1",
+        response: { status_code: 200, body: { data: [{ embedding: [0.1, 0.1] }] } },
+      },
+      {
+        custom_id: "req-2",
+        response: { status_code: 200, body: { data: [{ embedding: [0.2, 0.2] }] } },
+      },
+    ];
+
+    // Create a stream that emits the NDJSON lines
+    const stream = new ReadableStream({
+      start(controller) {
+        const text = outputLines.map((l) => JSON.stringify(l)).join("\n");
+        controller.enqueue(new TextEncoder().encode(text));
+        controller.close();
+      },
+    });
+
+    fetchMock.mockResolvedValueOnce({
+      ok: true,
+      body: stream,
+    });
+
+    const { runVoyageEmbeddingBatches } = await import("./batch-voyage.js");
+
+    const results = await runVoyageEmbeddingBatches({
+      client: mockClient,
+      agentId: "agent-1",
+      requests: mockRequests,
+      wait: true,
+      pollIntervalMs: 1, // fast poll
+      timeoutMs: 1000,
+      concurrency: 1,
+    });
+
+    expect(results.size).toBe(2);
+    expect(results.get("req-1")).toEqual([0.1, 0.1]);
+    expect(results.get("req-2")).toEqual([0.2, 0.2]);
+
+    // Verify calls
+    expect(fetchMock).toHaveBeenCalledTimes(4);
+
+    // Verify File Upload
+    expect(fetchMock.mock.calls[0][0]).toContain("/files");
+    const uploadBody = fetchMock.mock.calls[0][1].body as FormData;
+    expect(uploadBody).toBeInstanceOf(FormData);
+    expect(uploadBody.get("purpose")).toBe("batch");
+
+    // Verify Batch Create
+    expect(fetchMock.mock.calls[1][0]).toContain("/batches");
+    const createBody = JSON.parse(fetchMock.mock.calls[1][1].body);
+    expect(createBody.input_file_id).toBe("file-123");
+    expect(createBody.completion_window).toBe("12h");
+
+    // Verify Content Fetch
+    expect(fetchMock.mock.calls[3][0]).toContain("/files/file-out-999/content");
+  });
+
+  it("handles empty lines and stream chunks correctly", async () => {
+    const fetchMock = vi.fn();
+    vi.stubGlobal("fetch", fetchMock);
+
+    // 1. Upload
+    fetchMock.mockResolvedValueOnce({ ok: true, json: async () => ({ id: "f1" }) });
+    // 2. Create (completed immediately)
+    fetchMock.mockResolvedValueOnce({
+      ok: true,
+      json: async () => ({ id: "b1", status: "completed", output_file_id: "out1" }),
+    });
+    // 3. Download Content (Streaming with chunks and newlines)
+    const stream = new ReadableStream({
+      start(controller) {
+        const line1 = JSON.stringify({
+          custom_id: "req-1",
+          response: { body: { data: [{ embedding: [1] }] } },
+        });
+        const line2 = JSON.stringify({
+          custom_id: "req-2",
+          response: { body: { data: [{ embedding: [2] }] } },
+        });
+
+        // Split across chunks
+        controller.enqueue(new TextEncoder().encode(line1 + "\n"));
+        controller.enqueue(new TextEncoder().encode("\n")); // empty line
+        controller.enqueue(new TextEncoder().encode(line2)); // no newline at EOF
+        controller.close();
+      },
+    });
+
+    fetchMock.mockResolvedValueOnce({ ok: true, body: stream });
+
+    const { runVoyageEmbeddingBatches } = await import("./batch-voyage.js");
+
+    const results = await runVoyageEmbeddingBatches({
+      client: mockClient,
+      agentId: "a1",
+      requests: mockRequests,
+      wait: true,
+      pollIntervalMs: 1,
+      timeoutMs: 1000,
+      concurrency: 1,
+    });
+
+    expect(results.get("req-1")).toEqual([1]);
+    expect(results.get("req-2")).toEqual([2]);
+  });
+});
diff --git a/src/memory/batch-voyage.ts b/src/memory/batch-voyage.ts
new file mode 100644
index 0000000000..5e882738cc
--- /dev/null
+++ b/src/memory/batch-voyage.ts
@@ -0,0 +1,363 @@
+import { createInterface } from "node:readline";
+import { Readable } from "node:stream";
+
+import { retryAsync } from "../infra/retry.js";
+import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
+import { hashText, runWithConcurrency } from "./internal.js";
+
+/**
+ * Voyage Batch API Input Line format.
+ * See: https://docs.voyageai.com/docs/batch-inference
+ */
+export type VoyageBatchRequest = {
+  custom_id: string;
+  body: {
+    input: string | string[];
+  };
+};
+
+export type VoyageBatchStatus = {
+  id?: string;
+  status?: string;
+  output_file_id?: string | null;
+  error_file_id?: string | null;
+};
+
+export type VoyageBatchOutputLine = {
+  custom_id?: string;
+  response?: {
+    status_code?: number;
+    body?: {
+      data?: Array<{ embedding?: number[]; index?: number }>;
+      error?: { message?: string };
+    };
+  };
+  error?: { message?: string };
+};
+
+export const VOYAGE_BATCH_ENDPOINT = "/v1/embeddings";
+const VOYAGE_BATCH_COMPLETION_WINDOW = "12h";
+const VOYAGE_BATCH_MAX_REQUESTS = 50000;
+
+function getVoyageBaseUrl(client: VoyageEmbeddingClient): string {
+  return client.baseUrl?.replace(/\/$/, "") ?? "";
+}
+
+function getVoyageHeaders(
+  client: VoyageEmbeddingClient,
+  params: { json: boolean },
+): Record<string, string> {
+  const headers = client.headers ? { ...client.headers } : {};
+  if (params.json) {
+    if (!headers["Content-Type"] && !headers["content-type"]) {
+      headers["Content-Type"] = "application/json";
+    }
+  } else {
+    delete headers["Content-Type"];
+    delete headers["content-type"];
+  }
+  return headers;
+}
+
+function splitVoyageBatchRequests(requests: VoyageBatchRequest[]): VoyageBatchRequest[][] {
+  if (requests.length <= VOYAGE_BATCH_MAX_REQUESTS) return [requests];
+  const groups: VoyageBatchRequest[][] = [];
+  for (let i = 0; i < requests.length; i += VOYAGE_BATCH_MAX_REQUESTS) {
+    groups.push(requests.slice(i, i + VOYAGE_BATCH_MAX_REQUESTS));
+  }
+  return groups;
+}
+
+async function submitVoyageBatch(params: {
+  client: VoyageEmbeddingClient;
+  requests: VoyageBatchRequest[];
+  agentId: string;
+}): Promise<VoyageBatchStatus> {
+  const baseUrl = getVoyageBaseUrl(params.client);
+  const jsonl = params.requests.map((request) => JSON.stringify(request)).join("\n");
+  const form = new FormData();
+  form.append("purpose", "batch");
+  form.append(
+    "file",
+    new Blob([jsonl], { type: "application/jsonl" }),
+    `memory-embeddings.${hashText(String(Date.now()))}.jsonl`,
+  );
+
+  // 1. Upload file using Voyage Files API
+  const fileRes = await fetch(`${baseUrl}/files`, {
+    method: "POST",
+    headers: getVoyageHeaders(params.client, { json: false }),
+    body: form,
+  });
+  if (!fileRes.ok) {
+    const text = await fileRes.text();
+    throw new Error(`voyage batch file upload failed: ${fileRes.status} ${text}`);
+  }
+  const filePayload = (await fileRes.json()) as { id?: string };
+  if (!filePayload.id) {
+    throw new Error("voyage batch file upload failed: missing file id");
+  }
+
+  // 2. Create batch job using Voyage Batches API
+  const batchRes = await retryAsync(
+    async () => {
+      const res = await fetch(`${baseUrl}/batches`, {
+        method: "POST",
+        headers: getVoyageHeaders(params.client, { json: true }),
+        body: JSON.stringify({
+          input_file_id: filePayload.id,
+          endpoint: VOYAGE_BATCH_ENDPOINT,
+          completion_window: VOYAGE_BATCH_COMPLETION_WINDOW,
+          request_params: {
+            model: params.client.model,
+          },
+          metadata: {
+            source: "clawdbot-memory",
+            agent: params.agentId,
+          },
+        }),
+      });
+      if (!res.ok) {
+        const text = await res.text();
+        const err = new Error(`voyage batch create failed: ${res.status} ${text}`) as Error & {
+          status?: number;
+        };
+        err.status = res.status;
+        throw err;
+      }
+      return res;
+    },
+    {
+      attempts: 3,
+      minDelayMs: 300,
+      maxDelayMs: 2000,
+      jitter: 0.2,
+      shouldRetry: (err) => {
+        const status = (err as { status?: number }).status;
+        return status === 429 || (typeof status === "number" && status >= 500);
+      },
+    },
+  );
+  return (await batchRes.json()) as VoyageBatchStatus;
+}
+
+async function fetchVoyageBatchStatus(params: {
+  client: VoyageEmbeddingClient;
+  batchId: string;
+}): Promise<VoyageBatchStatus> {
+  const baseUrl = getVoyageBaseUrl(params.client);
+  const res = await fetch(`${baseUrl}/batches/${params.batchId}`, {
+    headers: getVoyageHeaders(params.client, { json: true }),
+  });
+  if (!res.ok) {
+    const text = await res.text();
+    throw new Error(`voyage batch status failed: ${res.status} ${text}`);
+  }
+  return (await res.json()) as VoyageBatchStatus;
+}
+
+async function readVoyageBatchError(params: {
+  client: VoyageEmbeddingClient;
+  errorFileId: string;
+}): Promise<string | undefined> {
+  try {
+    const baseUrl = getVoyageBaseUrl(params.client);
+    const res = await fetch(`${baseUrl}/files/${params.errorFileId}/content`, {
+      headers: getVoyageHeaders(params.client, { json: true }),
+    });
+    if (!res.ok) {
+      const text = await res.text();
+      throw new Error(`voyage batch error file content failed: ${res.status} ${text}`);
+    }
+    const text = await res.text();
+    if (!text.trim()) return undefined;
+    const lines = text
+      .split("\n")
+      .map((line) => line.trim())
+      .filter(Boolean)
+      .map((line) => JSON.parse(line) as VoyageBatchOutputLine);
+    const first = lines.find((line) => line.error?.message || line.response?.body?.error);
+    const message =
+      first?.error?.message ??
+      (typeof first?.response?.body?.error?.message === "string"
+        ? first?.response?.body?.error?.message
+        : undefined);
+    return message;
+  } catch (err) {
+    const message = err instanceof Error ? err.message : String(err);
+    return message ? `error file unavailable: ${message}` : undefined;
+  }
+}
+
+async function waitForVoyageBatch(params: {
+  client: VoyageEmbeddingClient;
+  batchId: string;
+  wait: boolean;
+  pollIntervalMs: number;
+  timeoutMs: number;
+  debug?: (message: string, data?: Record<string, unknown>) => void;
+  initial?: VoyageBatchStatus;
+}): Promise<{ outputFileId: string; errorFileId?: string }> {
+  const start = Date.now();
+  let current: VoyageBatchStatus | undefined = params.initial;
+  while (true) {
+    const status =
+      current ??
+      (await fetchVoyageBatchStatus({
+        client: params.client,
+        batchId: params.batchId,
+      }));
+    const state = status.status ?? "unknown";
+    if (state === "completed") {
+      if (!status.output_file_id) {
+        throw new Error(`voyage batch ${params.batchId} completed without output file`);
+      }
+      return {
+        outputFileId: status.output_file_id,
+        errorFileId: status.error_file_id ?? undefined,
+      };
+    }
+    if (["failed", "expired", "cancelled", "canceled"].includes(state)) {
+      const detail = status.error_file_id
+        ? await readVoyageBatchError({ client: params.client, errorFileId: status.error_file_id })
+        : undefined;
+      const suffix = detail ? `: ${detail}` : "";
+      throw new Error(`voyage batch ${params.batchId} ${state}${suffix}`);
+    }
+    if (!params.wait) {
+      throw new Error(`voyage batch ${params.batchId} still ${state}; wait disabled`);
+    }
+    if (Date.now() - start > params.timeoutMs) {
+      throw new Error(`voyage batch ${params.batchId} timed out after ${params.timeoutMs}ms`);
+    }
+    params.debug?.(`voyage batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`);
+    await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs));
+    current = undefined;
+  }
+}
+
+export async function runVoyageEmbeddingBatches(params: {
+  client: VoyageEmbeddingClient;
+  agentId: string;
+  requests: VoyageBatchRequest[];
+  wait: boolean;
+  pollIntervalMs: number;
+  timeoutMs: number;
+  concurrency: number;
+  debug?: (message: string, data?: Record<string, unknown>) => void;
+}): Promise<Map<string, number[]>> {
+  if (params.requests.length === 0) return new Map();
+  const groups = splitVoyageBatchRequests(params.requests);
+  const byCustomId = new Map<string, number[]>();
+
+  const tasks = groups.map((group, groupIndex) => async () => {
+    const batchInfo = await submitVoyageBatch({
+      client: params.client,
+      requests: group,
+      agentId: params.agentId,
+    });
+    if (!batchInfo.id) {
+      throw new Error("voyage batch create failed: missing batch id");
+    }
+
+    params.debug?.("memory embeddings: voyage batch created", {
+      batchId: batchInfo.id,
+      status: batchInfo.status,
+      group: groupIndex + 1,
+      groups: groups.length,
+      requests: group.length,
+    });
+
+    if (!params.wait && batchInfo.status !== "completed") {
+      throw new Error(
+        `voyage batch ${batchInfo.id} submitted; enable remote.batch.wait to await completion`,
+      );
+    }
+
+    const completed =
+      batchInfo.status === "completed"
{ + outputFileId: batchInfo.output_file_id ?? "", + errorFileId: batchInfo.error_file_id ?? undefined, + } + : await waitForVoyageBatch({ + client: params.client, + batchId: batchInfo.id, + wait: params.wait, + pollIntervalMs: params.pollIntervalMs, + timeoutMs: params.timeoutMs, + debug: params.debug, + initial: batchInfo, + }); + if (!completed.outputFileId) { + throw new Error(`voyage batch ${batchInfo.id} completed without output file`); + } + + const baseUrl = getVoyageBaseUrl(params.client); + const contentRes = await fetch(`${baseUrl}/files/${completed.outputFileId}/content`, { + headers: getVoyageHeaders(params.client, { json: true }), + }); + if (!contentRes.ok) { + const text = await contentRes.text(); + throw new Error(`voyage batch file content failed: ${contentRes.status} ${text}`); + } + + const errors: string[] = []; + const remaining = new Set(group.map((request) => request.custom_id)); + + if (contentRes.body) { + const reader = createInterface({ + input: Readable.fromWeb(contentRes.body as any), + terminal: false, + }); + + for await (const rawLine of reader) { + if (!rawLine.trim()) continue; + const line = JSON.parse(rawLine) as VoyageBatchOutputLine; + const customId = line.custom_id; + if (!customId) continue; + remaining.delete(customId); + if (line.error?.message) { + errors.push(`${customId}: ${line.error.message}`); + continue; + } + const response = line.response; + const statusCode = response?.status_code ?? 0; + if (statusCode >= 400) { + const message = + response?.body?.error?.message ?? + (typeof response?.body === "string" ? response.body : undefined) ?? + "unknown error"; + errors.push(`${customId}: ${message}`); + continue; + } + const data = response?.body?.data ?? []; + const embedding = data[0]?.embedding ?? []; + if (embedding.length === 0) { + errors.push(`${customId}: empty embedding`); + continue; + } + byCustomId.set(customId, embedding); + } + } + + if (errors.length > 0) { + throw new Error(`voyage batch ${batchInfo.id} failed: ${errors.join("; ")}`); + } + if (remaining.size > 0) { + throw new Error(`voyage batch ${batchInfo.id} missing ${remaining.size} embedding responses`); + } + }); + + params.debug?.("memory embeddings: voyage batch submit", { + requests: params.requests.length, + groups: groups.length, + wait: params.wait, + concurrency: params.concurrency, + pollIntervalMs: params.pollIntervalMs, + timeoutMs: params.timeoutMs, + }); + + await runWithConcurrency(tasks, params.concurrency); + return byCustomId; +} diff --git a/src/memory/embeddings-voyage.test.ts b/src/memory/embeddings-voyage.test.ts new file mode 100644 index 0000000000..d9cc1d5419 --- /dev/null +++ b/src/memory/embeddings-voyage.test.ts @@ -0,0 +1,100 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; + +vi.mock("../agents/model-auth.js", () => ({ + resolveApiKeyForProvider: vi.fn(), + requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => { + if (auth?.apiKey) return auth.apiKey; + throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth?.mode}).`); + }, +})); + +const createFetchMock = () => + vi.fn(async () => ({ + ok: true, + status: 200, + json: async () => ({ data: [{ embedding: [0.1, 0.2, 0.3] }] }), + })) as unknown as typeof fetch; + +describe("voyage embedding provider", () => { + afterEach(() => { + vi.resetAllMocks(); + vi.resetModules(); + vi.unstubAllGlobals(); + }); + + it("configures client with correct defaults and headers", async () => { + const fetchMock = createFetchMock(); + 
vi.stubGlobal("fetch", fetchMock); + + const { createVoyageEmbeddingProvider } = await import("./embeddings-voyage.js"); + const authModule = await import("../agents/model-auth.js"); + + vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ + apiKey: "voyage-key-123", + mode: "api-key", + source: "test", + }); + + const result = await createVoyageEmbeddingProvider({ + config: {} as never, + provider: "voyage", + model: "voyage-4-large", + fallback: "none", + }); + + await result.provider.embedQuery("test query"); + + expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledWith( + expect.objectContaining({ provider: "voyage" }), + ); + + const [url, init] = fetchMock.mock.calls[0] ?? []; + expect(url).toBe("https://api.voyageai.com/v1/embeddings"); + + const headers = (init?.headers ?? {}) as Record; + expect(headers.Authorization).toBe("Bearer voyage-key-123"); + expect(headers["Content-Type"]).toBe("application/json"); + + const body = JSON.parse(init?.body as string); + expect(body).toEqual({ + model: "voyage-4-large", + input: ["test query"], + }); + }); + + it("respects remote overrides for baseUrl and apiKey", async () => { + const fetchMock = createFetchMock(); + vi.stubGlobal("fetch", fetchMock); + + const { createVoyageEmbeddingProvider } = await import("./embeddings-voyage.js"); + + const result = await createVoyageEmbeddingProvider({ + config: {} as never, + provider: "voyage", + model: "voyage-4-lite", + fallback: "none", + remote: { + baseUrl: "https://proxy.example.com", + apiKey: "remote-override-key", + headers: { "X-Custom": "123" }, + }, + }); + + await result.provider.embedQuery("test"); + + const [url, init] = fetchMock.mock.calls[0] ?? []; + expect(url).toBe("https://proxy.example.com/embeddings"); + + const headers = (init?.headers ?? 
{}) as Record; + expect(headers.Authorization).toBe("Bearer remote-override-key"); + expect(headers["X-Custom"]).toBe("123"); + }); + + it("normalizes model names", async () => { + const { normalizeVoyageModel } = await import("./embeddings-voyage.js"); + expect(normalizeVoyageModel("voyage/voyage-large-2")).toBe("voyage-large-2"); + expect(normalizeVoyageModel("voyage-4-large")).toBe("voyage-4-large"); + expect(normalizeVoyageModel(" voyage-lite ")).toBe("voyage-lite"); + expect(normalizeVoyageModel("")).toBe("voyage-4-large"); // Default + }); +}); diff --git a/src/memory/embeddings-voyage.ts b/src/memory/embeddings-voyage.ts new file mode 100644 index 0000000000..a962b17805 --- /dev/null +++ b/src/memory/embeddings-voyage.ts @@ -0,0 +1,86 @@ +import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js"; +import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; + +export type VoyageEmbeddingClient = { + baseUrl: string; + headers: Record; + model: string; +}; + +export const DEFAULT_VOYAGE_EMBEDDING_MODEL = "voyage-4-large"; +const DEFAULT_VOYAGE_BASE_URL = "https://api.voyageai.com/v1"; + +export function normalizeVoyageModel(model: string): string { + const trimmed = model.trim(); + if (!trimmed) return DEFAULT_VOYAGE_EMBEDDING_MODEL; + if (trimmed.startsWith("voyage/")) return trimmed.slice("voyage/".length); + return trimmed; +} + +export async function createVoyageEmbeddingProvider( + options: EmbeddingProviderOptions, +): Promise<{ provider: EmbeddingProvider; client: VoyageEmbeddingClient }> { + const client = await resolveVoyageEmbeddingClient(options); + const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`; + + const embed = async (input: string[]): Promise => { + if (input.length === 0) return []; + const res = await fetch(url, { + method: "POST", + headers: client.headers, + body: JSON.stringify({ model: client.model, input }), + }); + if (!res.ok) { + const text = await res.text(); + throw new Error(`voyage embeddings failed: ${res.status} ${text}`); + } + const payload = (await res.json()) as { + data?: Array<{ embedding?: number[] }>; + }; + const data = payload.data ?? []; + return data.map((entry) => entry.embedding ?? []); + }; + + return { + provider: { + id: "voyage", + model: client.model, + embedQuery: async (text) => { + const [vec] = await embed([text]); + return vec ?? []; + }, + embedBatch: embed, + }, + client, + }; +} + +export async function resolveVoyageEmbeddingClient( + options: EmbeddingProviderOptions, +): Promise { + const remote = options.remote; + const remoteApiKey = remote?.apiKey?.trim(); + const remoteBaseUrl = remote?.baseUrl?.trim(); + + const apiKey = remoteApiKey + ? 
remoteApiKey + : requireApiKey( + await resolveApiKeyForProvider({ + provider: "voyage", + cfg: options.config, + agentDir: options.agentDir, + }), + "voyage", + ); + + const providerConfig = options.config.models?.providers?.voyage; + const baseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_VOYAGE_BASE_URL; + const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers); + const headers: Record = { + "Content-Type": "application/json", + Authorization: `Bearer ${apiKey}`, + ...headerOverrides, + }; + const model = normalizeVoyageModel(options.model); + return { baseUrl, headers, model }; +} diff --git a/src/memory/embeddings.ts b/src/memory/embeddings.ts index a2783a1349..6b78c3d738 100644 --- a/src/memory/embeddings.ts +++ b/src/memory/embeddings.ts @@ -4,6 +4,7 @@ import type { OpenClawConfig } from "../config/config.js"; import { resolveUserPath } from "../utils.js"; import { createGeminiEmbeddingProvider, type GeminiEmbeddingClient } from "./embeddings-gemini.js"; import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js"; +import { createVoyageEmbeddingProvider, type VoyageEmbeddingClient } from "./embeddings-voyage.js"; import { importNodeLlamaCpp } from "./node-llama.js"; function sanitizeAndNormalizeEmbedding(vec: number[]): number[] { @@ -17,6 +18,7 @@ function sanitizeAndNormalizeEmbedding(vec: number[]): number[] { export type { GeminiEmbeddingClient } from "./embeddings-gemini.js"; export type { OpenAiEmbeddingClient } from "./embeddings-openai.js"; +export type { VoyageEmbeddingClient } from "./embeddings-voyage.js"; export type EmbeddingProvider = { id: string; @@ -27,24 +29,25 @@ export type EmbeddingProvider = { export type EmbeddingProviderResult = { provider: EmbeddingProvider; - requestedProvider: "openai" | "local" | "gemini" | "auto"; - fallbackFrom?: "openai" | "local" | "gemini"; + requestedProvider: "openai" | "local" | "gemini" | "voyage" | "auto"; + fallbackFrom?: "openai" | "local" | "gemini" | "voyage"; fallbackReason?: string; openAi?: OpenAiEmbeddingClient; gemini?: GeminiEmbeddingClient; + voyage?: VoyageEmbeddingClient; }; export type EmbeddingProviderOptions = { config: OpenClawConfig; agentDir?: string; - provider: "openai" | "local" | "gemini" | "auto"; + provider: "openai" | "local" | "gemini" | "voyage" | "auto"; remote?: { baseUrl?: string; apiKey?: string; headers?: Record; }; model: string; - fallback: "openai" | "gemini" | "local" | "none"; + fallback: "openai" | "gemini" | "local" | "voyage" | "none"; local?: { modelPath?: string; modelCacheDir?: string; @@ -128,7 +131,7 @@ export async function createEmbeddingProvider( const requestedProvider = options.provider; const fallback = options.fallback; - const createProvider = async (id: "openai" | "local" | "gemini") => { + const createProvider = async (id: "openai" | "local" | "gemini" | "voyage") => { if (id === "local") { const provider = await createLocalEmbeddingProvider(options); return { provider }; @@ -137,11 +140,15 @@ export async function createEmbeddingProvider( const { provider, client } = await createGeminiEmbeddingProvider(options); return { provider, gemini: client }; } + if (id === "voyage") { + const { provider, client } = await createVoyageEmbeddingProvider(options); + return { provider, voyage: client }; + } const { provider, client } = await createOpenAiEmbeddingProvider(options); return { provider, openAi: client }; }; - const formatPrimaryError = (err: unknown, provider: "openai" | "local" | 
"gemini") => + const formatPrimaryError = (err: unknown, provider: "openai" | "local" | "gemini" | "voyage") => provider === "local" ? formatLocalSetupError(err) : formatError(err); if (requestedProvider === "auto") { @@ -157,7 +164,7 @@ export async function createEmbeddingProvider( } } - for (const provider of ["openai", "gemini"] as const) { + for (const provider of ["openai", "gemini", "voyage"] as const) { try { const result = await createProvider(provider); return { ...result, requestedProvider }; @@ -240,6 +247,7 @@ function formatLocalSetupError(err: unknown): string { : null, "3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp", 'Or set agents.defaults.memorySearch.provider = "openai" (remote).', + 'Or set agents.defaults.memorySearch.provider = "voyage" (remote).', ] .filter(Boolean) .join("\n"); diff --git a/src/memory/internal.ts b/src/memory/internal.ts index cbdb7c6c6e..5cb1bc8a26 100644 --- a/src/memory/internal.ts +++ b/src/memory/internal.ts @@ -275,3 +275,33 @@ export function cosineSimilarity(a: number[], b: number[]): number { } return dot / (Math.sqrt(normA) * Math.sqrt(normB)); } + +export async function runWithConcurrency( + tasks: Array<() => Promise>, + limit: number, +): Promise { + if (tasks.length === 0) return []; + const resolvedLimit = Math.max(1, Math.min(limit, tasks.length)); + const results: T[] = Array.from({ length: tasks.length }); + let next = 0; + let firstError: unknown = null; + + const workers = Array.from({ length: resolvedLimit }, async () => { + while (true) { + if (firstError) return; + const index = next; + next += 1; + if (index >= tasks.length) return; + try { + results[index] = await tasks[index](); + } catch (err) { + firstError = err; + return; + } + } + }); + + await Promise.allSettled(workers); + if (firstError) throw firstError; + return results; +} diff --git a/src/memory/manager.ts b/src/memory/manager.ts index 3dd290b105..b772d3fda4 100644 --- a/src/memory/manager.ts +++ b/src/memory/manager.ts @@ -26,14 +26,17 @@ import { type OpenAiBatchRequest, runOpenAiEmbeddingBatches, } from "./batch-openai.js"; +import { type VoyageBatchRequest, runVoyageEmbeddingBatches } from "./batch-voyage.js"; import { DEFAULT_GEMINI_EMBEDDING_MODEL } from "./embeddings-gemini.js"; import { DEFAULT_OPENAI_EMBEDDING_MODEL } from "./embeddings-openai.js"; +import { DEFAULT_VOYAGE_EMBEDDING_MODEL } from "./embeddings-voyage.js"; import { createEmbeddingProvider, type EmbeddingProvider, type EmbeddingProviderResult, type GeminiEmbeddingClient, type OpenAiEmbeddingClient, + type VoyageEmbeddingClient, } from "./embeddings.js"; import { bm25RankToScore, buildFtsQuery, mergeHybridResults } from "./hybrid.js"; import { @@ -47,6 +50,7 @@ import { type MemoryChunk, type MemoryFileEntry, parseEmbedding, + runWithConcurrency, } from "./internal.js"; import { searchKeyword, searchVector } from "./manager-search.js"; import { ensureMemoryIndexSchema } from "./memory-schema.js"; @@ -112,11 +116,12 @@ export class MemoryIndexManager implements MemorySearchManager { private readonly workspaceDir: string; private readonly settings: ResolvedMemorySearchConfig; private provider: EmbeddingProvider; - private readonly requestedProvider: "openai" | "local" | "gemini" | "auto"; - private fallbackFrom?: "openai" | "local" | "gemini"; + private readonly requestedProvider: "openai" | "local" | "gemini" | "voyage" | "auto"; + private fallbackFrom?: "openai" | "local" | "gemini" | "voyage"; private fallbackReason?: string; private 
openAi?: OpenAiEmbeddingClient; private gemini?: GeminiEmbeddingClient; + private voyage?: VoyageEmbeddingClient; private batch: { enabled: boolean; wait: boolean; @@ -217,6 +222,7 @@ export class MemoryIndexManager implements MemorySearchManager { this.fallbackReason = params.providerResult.fallbackReason; this.openAi = params.providerResult.openAi; this.gemini = params.providerResult.gemini; + this.voyage = params.providerResult.voyage; this.sources = new Set(params.settings.sources); this.db = this.openDatabase(); this.providerKey = this.computeProviderKey(); @@ -1109,7 +1115,7 @@ export class MemoryIndexManager implements MemorySearchManager { }); } }); - await this.runWithConcurrency(tasks, this.getIndexConcurrency()); + await runWithConcurrency(tasks, this.getIndexConcurrency()); const staleRows = this.db .prepare(`SELECT path FROM files WHERE source = ?`) @@ -1206,7 +1212,7 @@ export class MemoryIndexManager implements MemorySearchManager { }); } }); - await this.runWithConcurrency(tasks, this.getIndexConcurrency()); + await runWithConcurrency(tasks, this.getIndexConcurrency()); const staleRows = this.db .prepare(`SELECT path FROM files WHERE source = ?`) @@ -1346,7 +1352,8 @@ export class MemoryIndexManager implements MemorySearchManager { const enabled = Boolean( batch?.enabled && ((this.openAi && this.provider.id === "openai") || - (this.gemini && this.provider.id === "gemini")), + (this.gemini && this.provider.id === "gemini") || + (this.voyage && this.provider.id === "voyage")), ); return { enabled, @@ -1365,14 +1372,16 @@ export class MemoryIndexManager implements MemorySearchManager { if (this.fallbackFrom) { return false; } - const fallbackFrom = this.provider.id as "openai" | "gemini" | "local"; + const fallbackFrom = this.provider.id as "openai" | "gemini" | "local" | "voyage"; const fallbackModel = fallback === "gemini" ? DEFAULT_GEMINI_EMBEDDING_MODEL : fallback === "openai" ? DEFAULT_OPENAI_EMBEDDING_MODEL - : this.settings.model; + : fallback === "voyage" + ? DEFAULT_VOYAGE_EMBEDDING_MODEL + : this.settings.model; const fallbackResult = await createEmbeddingProvider({ config: this.cfg, @@ -1389,6 +1398,7 @@ export class MemoryIndexManager implements MemorySearchManager { this.provider = fallbackResult.provider; this.openAi = fallbackResult.openAi; this.gemini = fallbackResult.gemini; + this.voyage = fallbackResult.voyage; this.providerKey = this.computeProviderKey(); this.batch = this.resolveBatchConfig(); log.warn(`memory embeddings: switched to fallback provider (${fallback})`, { reason }); @@ -1865,9 +1875,82 @@ export class MemoryIndexManager implements MemorySearchManager { if (this.provider.id === "gemini" && this.gemini) { return this.embedChunksWithGeminiBatch(chunks, entry, source); } + if (this.provider.id === "voyage" && this.voyage) { + return this.embedChunksWithVoyageBatch(chunks, entry, source); + } return this.embedChunksInBatches(chunks); } + private async embedChunksWithVoyageBatch( + chunks: MemoryChunk[], + entry: MemoryFileEntry | SessionFileEntry, + source: MemorySource, + ): Promise { + const voyage = this.voyage; + if (!voyage) { + return this.embedChunksInBatches(chunks); + } + if (chunks.length === 0) return []; + const cached = this.loadEmbeddingCache(chunks.map((chunk) => chunk.hash)); + const embeddings: number[][] = Array.from({ length: chunks.length }, () => []); + const missing: Array<{ index: number; chunk: MemoryChunk }> = []; + + for (let i = 0; i < chunks.length; i += 1) { + const chunk = chunks[i]; + const hit = chunk?.hash ? 
cached.get(chunk.hash) : undefined; + if (hit && hit.length > 0) { + embeddings[i] = hit; + } else if (chunk) { + missing.push({ index: i, chunk }); + } + } + + if (missing.length === 0) return embeddings; + + const requests: VoyageBatchRequest[] = []; + const mapping = new Map(); + for (const item of missing) { + const chunk = item.chunk; + const customId = hashText( + `${source}:${entry.path}:${chunk.startLine}:${chunk.endLine}:${chunk.hash}:${item.index}`, + ); + mapping.set(customId, { index: item.index, hash: chunk.hash }); + requests.push({ + custom_id: customId, + body: { + input: chunk.text, + }, + }); + } + const batchResult = await this.runBatchWithFallback({ + provider: "voyage", + run: async () => + await runVoyageEmbeddingBatches({ + client: voyage, + agentId: this.agentId, + requests, + wait: this.batch.wait, + concurrency: this.batch.concurrency, + pollIntervalMs: this.batch.pollIntervalMs, + timeoutMs: this.batch.timeoutMs, + debug: (message, data) => log.debug(message, { ...data, source, chunks: chunks.length }), + }), + fallback: async () => await this.embedChunksInBatches(chunks), + }); + if (Array.isArray(batchResult)) return batchResult; + const byCustomId = batchResult; + + const toCache: Array<{ hash: string; embedding: number[] }> = []; + for (const [customId, embedding] of byCustomId.entries()) { + const mapped = mapping.get(customId); + if (!mapped) continue; + embeddings[mapped.index] = embedding; + toCache.push({ hash: mapped.hash, embedding }); + } + this.upsertEmbeddingCache(toCache); + return embeddings; + } + private async embedChunksWithOpenAiBatch( chunks: MemoryChunk[], entry: MemoryFileEntry | SessionFileEntry, @@ -2108,41 +2191,6 @@ export class MemoryIndexManager implements MemorySearchManager { } } - private async runWithConcurrency(tasks: Array<() => Promise>, limit: number): Promise { - if (tasks.length === 0) { - return []; - } - const resolvedLimit = Math.max(1, Math.min(limit, tasks.length)); - const results: T[] = Array.from({ length: tasks.length }); - let next = 0; - let firstError: unknown = null; - - const workers = Array.from({ length: resolvedLimit }, async () => { - while (true) { - if (firstError) { - return; - } - const index = next; - next += 1; - if (index >= tasks.length) { - return; - } - try { - results[index] = await tasks[index](); - } catch (err) { - firstError = err; - return; - } - } - }); - - await Promise.allSettled(workers); - if (firstError) { - throw firstError; - } - return results; - } - private async withBatchFailureLock(fn: () => Promise): Promise { let release: () => void; const wait = this.batchFailureLock;