From a3a4996adb9abdd6d22ae2430906b890e107d55a Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Sun, 18 Jan 2026 09:09:13 +0000 Subject: [PATCH] feat: add gemini memory embeddings --- CHANGELOG.md | 17 ++- docs/concepts/memory.md | 25 +++- docs/gateway/configuration-examples.md | 5 +- src/agents/memory-search.test.ts | 20 +++ src/agents/memory-search.ts | 12 +- src/config/schema.ts | 4 +- src/config/types.tools.ts | 2 +- src/config/zod-schema.agent-runtime.ts | 2 +- src/memory/embeddings.test.ts | 33 +++++ src/memory/embeddings.ts | 129 +++++++++++++++++- src/memory/manager.atomic-reindex.test.ts | 91 +++++++++++++ src/memory/manager.ts | 154 +++++++++++++++++++++- src/memory/provider-key.ts | 12 ++ 13 files changed, 482 insertions(+), 24 deletions(-) create mode 100644 src/memory/manager.atomic-reindex.test.ts diff --git a/CHANGELOG.md b/CHANGELOG.md index b20b1fd191..1407de19d4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,22 +9,35 @@ Docs: https://docs.clawd.bot - macOS: stop syncing Peekaboo as a git submodule in postinstall. - Swabble: use the tagged Commander Swift package release. - CLI: add `clawdbot acp client` interactive ACP harness for debugging. +- Memory: add native Gemini embeddings provider for memory search. (#1151) — thanks @gumadeiras. ### Fixes - Auth profiles: keep auto-pinned preference while allowing rotation on failover; user pins stay locked. (#1138) — thanks @cheeeee. - macOS: avoid touching launchd in Remote over SSH so quitting the app no longer disables the remote gateway. (#1105) +- Memory: index atomically so failed reindex preserves the previous memory database. (#1151) — thanks @gumadeiras. + ## 2026.1.18-3 ### Changes - Exec: add host/security/ask routing for gateway + node exec. +- Exec: add `/exec` directive for per-session exec defaults (host/security/ask/node). - macOS: migrate exec approvals to `~/.clawdbot/exec-approvals.json` with per-agent allowlists and skill auto-allow toggle. - macOS: add approvals socket UI server + node exec lifecycle events. -- Plugins: ship Discord/Slack/Telegram/Signal/WhatsApp as bundled channel plugins via the shared SDK (iMessage now bundled + opt-in). +- Nodes: add headless node host (`clawdbot node start`) for `system.run`/`system.which`. +- Nodes: add node daemon service install/status/start/stop/restart. +- Bridge: add `skills.bins` RPC to support node host auto-allow skill bins. +- Slash commands: replace `/cost` with `/usage off|tokens|full` to control per-response usage footer; `/usage` no longer aliases `/status`. (Supersedes #1140) — thanks @Nachx639. +- Sessions: add daily reset policy with per-type overrides and idle windows (default 4am local), preserving legacy idle-only configs. (#1146) — thanks @austinm911. +- Agents: auto-inject local image references for vision models and avoid reloading history images. (#1098) — thanks @tyler6204. - Docs: refresh exec/elevated/exec-approvals docs for the new flow. https://docs.clawd.bot/tools/exec-approvals +- Docs: add node host CLI + update exec approvals/bridge protocol docs. https://docs.clawd.bot/cli/node +- ACP: add experimental ACP support for IDE integrations (`clawdbot acp`). Thanks @visionik. ### Fixes +- Exec approvals: enforce allowlist when ask is off; prefer raw command for node approvals/events. - Tools: return a companion-app-required message when node exec is requested with no paired node. -- Tests: avoid extension imports when wiring plugin registries in unit tests. +- Streaming: emit assistant deltas for OpenAI-compatible SSE chunks. (#1147) — thanks @alauppe. +- Model fallback: treat timeout aborts as failover while preserving user aborts. (#1137) — thanks @cheeeee. ## 2026.1.18-2 diff --git a/docs/concepts/memory.md b/docs/concepts/memory.md index 8e7eade9d5..2183e499e0 100644 --- a/docs/concepts/memory.md +++ b/docs/concepts/memory.md @@ -89,7 +89,26 @@ OAuth only covers chat/completions and does **not** satisfy embeddings for memory search. When using a custom OpenAI-compatible endpoint, set `memorySearch.remote.apiKey` (and optional `memorySearch.remote.headers`). -If you want to use a **custom OpenAI-compatible endpoint** (like Gemini, OpenRouter, or a proxy), +If you want to use **Gemini embeddings** directly, set the provider to `gemini`: + +```json5 +agents: { + defaults: { + memorySearch: { + provider: "gemini", + model: "gemini-embedding-001", // default + remote: { + apiKey: "${GEMINI_API_KEY}" + } + } + } +} +``` + +Gemini uses `GEMINI_API_KEY` (or `models.providers.google.apiKey`). Override +`memorySearch.remote.baseUrl` to point at a custom Gemini-compatible endpoint. + +If you want to use a **custom OpenAI-compatible endpoint** (like OpenRouter or a proxy), you can use the `remote` configuration: ```json5 @@ -99,8 +118,8 @@ agents: { provider: "openai", model: "text-embedding-3-small", remote: { - baseUrl: "https://generativelanguage.googleapis.com/v1beta/openai/", - apiKey: "YOUR_GEMINI_API_KEY", + baseUrl: "https://proxy.example/v1", + apiKey: "YOUR_PROXY_KEY", headers: { "X-Custom-Header": "value" } } } diff --git a/docs/gateway/configuration-examples.md b/docs/gateway/configuration-examples.md index 1075064b67..793ece4126 100644 --- a/docs/gateway/configuration-examples.md +++ b/docs/gateway/configuration-examples.md @@ -261,10 +261,9 @@ Save to `~/.clawdbot/clawdbot.json` and you can DM the bot from that number. ackMaxChars: 300 }, memorySearch: { - provider: "openai", - model: "text-embedding-004", + provider: "gemini", + model: "gemini-embedding-001", remote: { - baseUrl: "https://generativelanguage.googleapis.com/v1beta/openai/", apiKey: "${GEMINI_API_KEY}" } }, diff --git a/src/agents/memory-search.test.ts b/src/agents/memory-search.test.ts index ef58a91162..d5271ea540 100644 --- a/src/agents/memory-search.test.ts +++ b/src/agents/memory-search.test.ts @@ -101,6 +101,26 @@ describe("memory search config", () => { expect(resolved?.remote).toBeUndefined(); }); + it("includes remote defaults for gemini without overrides", () => { + const cfg = { + agents: { + defaults: { + memorySearch: { + provider: "gemini", + }, + }, + }, + }; + const resolved = resolveMemorySearchConfig(cfg, "main"); + expect(resolved?.remote?.batch).toEqual({ + enabled: true, + wait: true, + concurrency: 2, + pollIntervalMs: 2000, + timeoutMinutes: 60, + }); + }); + it("merges remote defaults with agent overrides", () => { const cfg = { agents: { diff --git a/src/agents/memory-search.ts b/src/agents/memory-search.ts index 2f4a5db604..d8fc8a5623 100644 --- a/src/agents/memory-search.ts +++ b/src/agents/memory-search.ts @@ -9,7 +9,7 @@ import { resolveAgentConfig } from "./agent-scope.js"; export type ResolvedMemorySearchConfig = { enabled: boolean; sources: Array<"memory" | "sessions">; - provider: "openai" | "local"; + provider: "openai" | "gemini" | "local"; remote?: { baseUrl?: string; apiKey?: string; @@ -66,7 +66,8 @@ export type ResolvedMemorySearchConfig = { }; }; -const DEFAULT_MODEL = "text-embedding-3-small"; +const DEFAULT_OPENAI_MODEL = "text-embedding-3-small"; +const DEFAULT_GEMINI_MODEL = "gemini-embedding-001"; const DEFAULT_CHUNK_TOKENS = 400; const DEFAULT_CHUNK_OVERLAP = 80; const DEFAULT_WATCH_DEBOUNCE_MS = 1500; @@ -111,7 +112,7 @@ function mergeConfig( overrides?.experimental?.sessionMemory ?? defaults?.experimental?.sessionMemory ?? false; const provider = overrides?.provider ?? defaults?.provider ?? "openai"; const hasRemote = Boolean(defaults?.remote || overrides?.remote); - const includeRemote = hasRemote || provider === "openai"; + const includeRemote = hasRemote || provider === "openai" || provider === "gemini"; const batch = { enabled: overrides?.remote?.batch?.enabled ?? defaults?.remote?.batch?.enabled ?? true, wait: overrides?.remote?.batch?.wait ?? defaults?.remote?.batch?.wait ?? true, @@ -133,7 +134,10 @@ function mergeConfig( } : undefined; const fallback = overrides?.fallback ?? defaults?.fallback ?? "openai"; - const model = overrides?.model ?? defaults?.model ?? DEFAULT_MODEL; + const model = + overrides?.model ?? + defaults?.model ?? + (provider === "gemini" ? DEFAULT_GEMINI_MODEL : DEFAULT_OPENAI_MODEL); const local = { modelPath: overrides?.local?.modelPath ?? defaults?.local?.modelPath, modelCacheDir: overrides?.local?.modelCacheDir ?? defaults?.local?.modelCacheDir, diff --git a/src/config/schema.ts b/src/config/schema.ts index 721b3ab395..06a1938ae6 100644 --- a/src/config/schema.ts +++ b/src/config/schema.ts @@ -374,9 +374,9 @@ const FIELD_HELP: Record = { 'Sources to index for memory search (default: ["memory"]; add "sessions" to include session transcripts).', "agents.defaults.memorySearch.experimental.sessionMemory": "Enable experimental session transcript indexing for memory search (default: false).", - "agents.defaults.memorySearch.provider": 'Embedding provider ("openai" or "local").', + "agents.defaults.memorySearch.provider": 'Embedding provider ("openai", "gemini", or "local").', "agents.defaults.memorySearch.remote.baseUrl": - "Custom OpenAI-compatible base URL (e.g. for Gemini/OpenRouter proxies).", + "Custom base URL for remote embeddings (OpenAI-compatible proxies or Gemini overrides).", "agents.defaults.memorySearch.remote.apiKey": "Custom API key for the remote embedding provider.", "agents.defaults.memorySearch.remote.headers": "Extra headers for remote embeddings (merged; remote overrides OpenAI headers).", diff --git a/src/config/types.tools.ts b/src/config/types.tools.ts index 456060d061..4336de27d5 100644 --- a/src/config/types.tools.ts +++ b/src/config/types.tools.ts @@ -170,7 +170,7 @@ export type MemorySearchConfig = { sessionMemory?: boolean; }; /** Embedding provider mode. */ - provider?: "openai" | "local"; + provider?: "openai" | "gemini" | "local"; remote?: { baseUrl?: string; apiKey?: string; diff --git a/src/config/zod-schema.agent-runtime.ts b/src/config/zod-schema.agent-runtime.ts index 33d647ea29..7afd6c901c 100644 --- a/src/config/zod-schema.agent-runtime.ts +++ b/src/config/zod-schema.agent-runtime.ts @@ -218,7 +218,7 @@ export const MemorySearchSchema = z sessionMemory: z.boolean().optional(), }) .optional(), - provider: z.union([z.literal("openai"), z.literal("local")]).optional(), + provider: z.union([z.literal("openai"), z.literal("gemini"), z.literal("local")]).optional(), remote: z .object({ baseUrl: z.string().optional(), diff --git a/src/memory/embeddings.test.ts b/src/memory/embeddings.test.ts index bd4e47be6c..fb51cbee37 100644 --- a/src/memory/embeddings.test.ts +++ b/src/memory/embeddings.test.ts @@ -107,6 +107,39 @@ describe("embedding provider remote overrides", () => { const headers = (fetchMock.mock.calls[0]?.[1]?.headers as Record) ?? {}; expect(headers.Authorization).toBe("Bearer provider-key"); }); + + it("uses gemini embedContent endpoint with x-goog-api-key", async () => { + const fetchMock = vi.fn(async () => ({ + ok: true, + status: 200, + json: async () => ({ embedding: { values: [1, 2, 3] } }), + })) as unknown as typeof fetch; + vi.stubGlobal("fetch", fetchMock); + + const { createEmbeddingProvider } = await import("./embeddings.js"); + const authModule = await import("../agents/model-auth.js"); + vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ + apiKey: "gemini-key", + }); + + const result = await createEmbeddingProvider({ + config: {} as never, + provider: "gemini", + remote: { + baseUrl: "https://gemini.example/v1beta", + }, + model: "gemini-embedding-001", + fallback: "openai", + }); + + await result.provider.embedQuery("hello"); + + const [url, init] = fetchMock.mock.calls[0] ?? []; + expect(url).toBe("https://gemini.example/v1beta/models/gemini-embedding-001:embedContent"); + const headers = (init?.headers ?? {}) as Record; + expect(headers["x-goog-api-key"]).toBe("gemini-key"); + expect(headers["Content-Type"]).toBe("application/json"); + }); }); describe("embedding provider local fallback", () => { diff --git a/src/memory/embeddings.ts b/src/memory/embeddings.ts index 3fe235beb8..9a88a8119b 100644 --- a/src/memory/embeddings.ts +++ b/src/memory/embeddings.ts @@ -12,10 +12,11 @@ export type EmbeddingProvider = { export type EmbeddingProviderResult = { provider: EmbeddingProvider; - requestedProvider: "openai" | "local"; + requestedProvider: "openai" | "gemini" | "local"; fallbackFrom?: "local"; fallbackReason?: string; openAi?: OpenAiEmbeddingClient; + gemini?: GeminiEmbeddingClient; }; export type OpenAiEmbeddingClient = { @@ -24,10 +25,16 @@ export type OpenAiEmbeddingClient = { model: string; }; +export type GeminiEmbeddingClient = { + baseUrl: string; + headers: Record; + model: string; +}; + export type EmbeddingProviderOptions = { config: ClawdbotConfig; agentDir?: string; - provider: "openai" | "local"; + provider: "openai" | "gemini" | "local"; remote?: { baseUrl?: string; apiKey?: string; @@ -43,6 +50,8 @@ export type EmbeddingProviderOptions = { const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"; const DEFAULT_LOCAL_MODEL = "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf"; +const DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"; +const DEFAULT_GEMINI_MODEL = "gemini-embedding-001"; function normalizeOpenAiModel(model: string): string { const trimmed = model.trim(); @@ -51,6 +60,14 @@ function normalizeOpenAiModel(model: string): string { return trimmed; } +function normalizeGeminiModel(model: string): string { + const trimmed = model.trim(); + if (!trimmed) return DEFAULT_GEMINI_MODEL; + if (trimmed.startsWith("models/")) return trimmed.slice("models/".length); + if (trimmed.startsWith("google/")) return trimmed.slice("google/".length); + return trimmed; +} + async function createOpenAiEmbeddingProvider( options: EmbeddingProviderOptions, ): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> { @@ -89,6 +106,83 @@ async function createOpenAiEmbeddingProvider( }; } +function extractGeminiEmbeddingValues(entry: unknown): number[] { + if (!entry || typeof entry !== "object") return []; + const record = entry as { values?: unknown; embedding?: { values?: unknown } }; + const values = record.values ?? record.embedding?.values; + if (!Array.isArray(values)) return []; + return values.filter((value): value is number => typeof value === "number"); +} + +function parseGeminiEmbeddings(payload: unknown): number[][] { + if (!payload || typeof payload !== "object") return []; + const data = payload as { embedding?: unknown; embeddings?: unknown[] }; + if (Array.isArray(data.embeddings)) { + return data.embeddings.map((entry) => extractGeminiEmbeddingValues(entry)); + } + if (data.embedding) { + return [extractGeminiEmbeddingValues(data.embedding)]; + } + return []; +} + +async function createGeminiEmbeddingProvider( + options: EmbeddingProviderOptions, +): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> { + const client = await resolveGeminiEmbeddingClient(options); + const baseUrl = client.baseUrl.replace(/\/$/, ""); + const model = `models/${client.model}`; + + const embedContent = async (input: string): Promise => { + const res = await fetch(`${baseUrl}/${model}:embedContent`, { + method: "POST", + headers: client.headers, + body: JSON.stringify({ + model, + content: { parts: [{ text: input }] }, + }), + }); + if (!res.ok) { + const text = await res.text(); + throw new Error(`gemini embeddings failed: ${res.status} ${text}`); + } + const payload = await res.json(); + const embeddings = parseGeminiEmbeddings(payload); + return embeddings[0] ?? []; + }; + + const embedBatch = async (input: string[]): Promise => { + if (input.length === 0) return []; + const res = await fetch(`${baseUrl}/${model}:batchEmbedContents`, { + method: "POST", + headers: client.headers, + body: JSON.stringify({ + requests: input.map((text) => ({ + model, + content: { parts: [{ text }] }, + })), + }), + }); + if (!res.ok) { + const text = await res.text(); + throw new Error(`gemini embeddings failed: ${res.status} ${text}`); + } + const payload = await res.json(); + const embeddings = parseGeminiEmbeddings(payload); + return embeddings; + }; + + return { + provider: { + id: "gemini", + model: client.model, + embedQuery: embedContent, + embedBatch, + }, + client, + }; +} + async function resolveOpenAiEmbeddingClient( options: EmbeddingProviderOptions, ): Promise { @@ -116,6 +210,33 @@ async function resolveOpenAiEmbeddingClient( return { baseUrl, headers, model }; } +async function resolveGeminiEmbeddingClient( + options: EmbeddingProviderOptions, +): Promise { + const remote = options.remote; + const remoteApiKey = remote?.apiKey?.trim(); + const remoteBaseUrl = remote?.baseUrl?.trim(); + + const { apiKey } = remoteApiKey + ? { apiKey: remoteApiKey } + : await resolveApiKeyForProvider({ + provider: "google", + cfg: options.config, + agentDir: options.agentDir, + }); + + const providerConfig = options.config.models?.providers?.google; + const baseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GEMINI_BASE_URL; + const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers); + const headers: Record = { + "Content-Type": "application/json", + "x-goog-api-key": apiKey, + ...headerOverrides, + }; + const model = normalizeGeminiModel(options.model); + return { baseUrl, headers, model }; +} + async function createLocalEmbeddingProvider( options: EmbeddingProviderOptions, ): Promise { @@ -168,6 +289,10 @@ export async function createEmbeddingProvider( options: EmbeddingProviderOptions, ): Promise { const requestedProvider = options.provider; + if (options.provider === "gemini") { + const { provider, client } = await createGeminiEmbeddingProvider(options); + return { provider, requestedProvider, gemini: client }; + } if (options.provider === "local") { try { const provider = await createLocalEmbeddingProvider(options); diff --git a/src/memory/manager.atomic-reindex.test.ts b/src/memory/manager.atomic-reindex.test.ts new file mode 100644 index 0000000000..801bdba4b3 --- /dev/null +++ b/src/memory/manager.atomic-reindex.test.ts @@ -0,0 +1,91 @@ +import fs from "node:fs/promises"; +import os from "node:os"; +import path from "node:path"; + +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; + +import { getMemorySearchManager, type MemoryIndexManager } from "./index.js"; + +let shouldFail = false; + +vi.mock("chokidar", () => ({ + default: { + watch: vi.fn(() => ({ + on: vi.fn(), + close: vi.fn(async () => undefined), + })), + }, +})); + +vi.mock("./embeddings.js", () => { + return { + createEmbeddingProvider: async () => ({ + requestedProvider: "openai", + provider: { + id: "mock", + model: "mock-embed", + embedQuery: async () => [0.1, 0.2, 0.3], + embedBatch: async (texts: string[]) => { + if (shouldFail) { + throw new Error("embedding failure"); + } + return texts.map((_, index) => [index + 1, 0, 0]); + }, + }, + }), + }; +}); + +describe("memory manager atomic reindex", () => { + let workspaceDir: string; + let indexPath: string; + let manager: MemoryIndexManager | null = null; + + beforeEach(async () => { + shouldFail = false; + workspaceDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-mem-")); + indexPath = path.join(workspaceDir, "index.sqlite"); + await fs.mkdir(path.join(workspaceDir, "memory")); + await fs.writeFile(path.join(workspaceDir, "MEMORY.md"), "Hello memory."); + }); + + afterEach(async () => { + if (manager) { + await manager.close(); + manager = null; + } + await fs.rm(workspaceDir, { recursive: true, force: true }); + }); + + it("keeps the prior index when a full reindex fails", async () => { + const cfg = { + agents: { + defaults: { + workspace: workspaceDir, + memorySearch: { + provider: "openai", + model: "mock-embed", + store: { path: indexPath }, + sync: { watch: false, onSessionStart: false, onSearch: false }, + }, + }, + list: [{ id: "main", default: true }], + }, + }; + + const result = await getMemorySearchManager({ cfg, agentId: "main" }); + expect(result.manager).not.toBeNull(); + if (!result.manager) throw new Error("manager missing"); + manager = result.manager; + + await manager.sync({ force: true }); + const before = await manager.search("Hello"); + expect(before.length).toBeGreaterThan(0); + + shouldFail = true; + await expect(manager.sync({ force: true })).rejects.toThrow("embedding failure"); + + const after = await manager.search("Hello"); + expect(after.length).toBeGreaterThan(0); + }); +}); diff --git a/src/memory/manager.ts b/src/memory/manager.ts index 637b231274..e5cb9c1cae 100644 --- a/src/memory/manager.ts +++ b/src/memory/manager.ts @@ -16,6 +16,7 @@ import { createEmbeddingProvider, type EmbeddingProvider, type EmbeddingProviderResult, + type GeminiEmbeddingClient, type OpenAiEmbeddingClient, } from "./embeddings.js"; import { @@ -104,9 +105,10 @@ export class MemoryIndexManager { private readonly workspaceDir: string; private readonly settings: ResolvedMemorySearchConfig; private readonly provider: EmbeddingProvider; - private readonly requestedProvider: "openai" | "local"; + private readonly requestedProvider: "openai" | "gemini" | "local"; private readonly fallbackReason?: string; private readonly openAi?: OpenAiEmbeddingClient; + private readonly gemini?: GeminiEmbeddingClient; private readonly batch: { enabled: boolean; wait: boolean; @@ -114,7 +116,7 @@ export class MemoryIndexManager { pollIntervalMs: number; timeoutMs: number; }; - private readonly db: DatabaseSync; + private db: DatabaseSync; private readonly sources: Set; private readonly providerKey: string; private readonly cache: { enabled: boolean; maxEntries?: number }; @@ -142,6 +144,7 @@ export class MemoryIndexManager { private sessionsDirtyFiles = new Set(); private sessionWarm = new Set(); private syncing: Promise | null = null; + private readonly allowAtomicReindex: boolean; static async get(params: { cfg: ClawdbotConfig; @@ -182,6 +185,7 @@ export class MemoryIndexManager { workspaceDir: string; settings: ResolvedMemorySearchConfig; providerResult: EmbeddingProviderResult; + options?: { allowAtomicReindex?: boolean; enableBackgroundSync?: boolean }; }) { this.cacheKey = params.cacheKey; this.cfg = params.cfg; @@ -192,6 +196,8 @@ export class MemoryIndexManager { this.requestedProvider = params.providerResult.requestedProvider; this.fallbackReason = params.providerResult.fallbackReason; this.openAi = params.providerResult.openAi; + this.gemini = params.providerResult.gemini; + this.allowAtomicReindex = params.options?.allowAtomicReindex ?? true; this.sources = new Set(params.settings.sources); this.db = this.openDatabase(); this.providerKey = computeEmbeddingProviderKey({ @@ -200,6 +206,13 @@ export class MemoryIndexManager { openAi: this.openAi ? { baseUrl: this.openAi.baseUrl, model: this.openAi.model, headers: this.openAi.headers } : undefined, + gemini: this.gemini + ? { + baseUrl: this.gemini.baseUrl, + model: this.gemini.model, + headers: this.gemini.headers, + } + : undefined, }); this.cache = { enabled: params.settings.cache.enabled, @@ -216,9 +229,12 @@ export class MemoryIndexManager { if (meta?.vectorDims) { this.vector.dims = meta.vectorDims; } - this.ensureWatcher(); - this.ensureSessionListener(); - this.ensureIntervalSync(); + const enableBackgroundSync = params.options?.enableBackgroundSync ?? true; + if (enableBackgroundSync) { + this.ensureWatcher(); + this.ensureSessionListener(); + this.ensureIntervalSync(); + } this.dirty = this.sources.has("memory"); if (this.sources.has("sessions")) { this.sessionsDirty = true; @@ -782,7 +798,7 @@ export class MemoryIndexManager { force?: boolean; progress?: (update: MemorySyncProgressUpdate) => void; }) { - const progress = params?.progress ? this.createSyncProgress(params.progress) : undefined; + const progressCallback = params?.progress; const vectorReady = await this.ensureVectorReady(); const meta = this.readMeta(); const needsFullReindex = @@ -794,6 +810,12 @@ export class MemoryIndexManager { meta.chunkTokens !== this.settings.chunking.tokens || meta.chunkOverlap !== this.settings.chunking.overlap || (vectorReady && !meta?.vectorDims); + if (needsFullReindex && this.allowAtomicReindex) { + await this.runAtomicReindex({ reason: params?.reason, progress: progressCallback }); + return; + } + + const progress = progressCallback ? this.createSyncProgress(progressCallback) : undefined; if (needsFullReindex) { this.resetIndex(); } @@ -833,6 +855,126 @@ export class MemoryIndexManager { } } + private createScratchManager(tempPath: string): MemoryIndexManager { + const scratchSettings: ResolvedMemorySearchConfig = { + ...this.settings, + store: { + ...this.settings.store, + path: tempPath, + }, + sync: { + ...this.settings.sync, + watch: false, + intervalMinutes: 0, + }, + }; + + return new MemoryIndexManager({ + cacheKey: `${this.cacheKey}:scratch:${Date.now()}`, + cfg: this.cfg, + agentId: this.agentId, + workspaceDir: this.workspaceDir, + settings: scratchSettings, + providerResult: { + provider: this.provider, + requestedProvider: this.requestedProvider, + fallbackReason: this.fallbackReason, + openAi: this.openAi, + gemini: this.gemini, + }, + options: { + allowAtomicReindex: false, + enableBackgroundSync: false, + }, + }); + } + + private buildTempIndexPath(): string { + const basePath = resolveUserPath(this.settings.store.path); + const dir = path.dirname(basePath); + ensureDir(dir); + const stamp = `${Date.now()}-${Math.random().toString(16).slice(2, 10)}`; + return path.join(dir, `${path.basename(basePath)}.tmp-${stamp}`); + } + + private reopenDatabase() { + this.db = this.openDatabase(); + this.fts.available = false; + this.fts.loadError = undefined; + this.ensureSchema(); + this.vector.available = null; + this.vector.loadError = undefined; + this.vectorReady = null; + this.vector.dims = undefined; + const meta = this.readMeta(); + if (meta?.vectorDims) { + this.vector.dims = meta.vectorDims; + } + } + + private async swapIndexFile(tempPath: string): Promise { + const dbPath = resolveUserPath(this.settings.store.path); + const backupPath = `${dbPath}.bak-${Date.now()}`; + let hasBackup = false; + let shouldReopen = false; + + this.db.close(); + + try { + try { + await fs.rename(dbPath, backupPath); + hasBackup = true; + } catch (err) { + const code = (err as NodeJS.ErrnoException).code; + if (code !== "ENOENT") throw err; + } + await fs.rename(tempPath, dbPath); + shouldReopen = true; + if (hasBackup) { + await fs.rm(backupPath, { force: true }); + } + } catch (err) { + if (hasBackup) { + try { + await fs.rename(backupPath, dbPath); + shouldReopen = true; + } catch {} + } + if (!shouldReopen) { + try { + await fs.access(dbPath); + shouldReopen = true; + } catch {} + } + throw err; + } finally { + await fs.rm(tempPath, { force: true }); + if (shouldReopen) { + this.reopenDatabase(); + } + } + } + + private async runAtomicReindex(params: { + reason?: string; + progress?: (update: MemorySyncProgressUpdate) => void; + }) { + const tempPath = this.buildTempIndexPath(); + const scratch = this.createScratchManager(tempPath); + try { + await scratch.sync({ reason: params.reason, force: true, progress: params.progress }); + } catch (err) { + await fs.rm(tempPath, { force: true }); + throw err; + } finally { + await scratch.close().catch(() => undefined); + } + await this.swapIndexFile(tempPath); + this.dirty = false; + this.sessionsDirty = false; + this.sessionsDirtyFiles.clear(); + } + private resetIndex() { this.db.exec(`DELETE FROM files`); this.db.exec(`DELETE FROM chunks`); diff --git a/src/memory/provider-key.ts b/src/memory/provider-key.ts index 53877af77b..09485c0f2e 100644 --- a/src/memory/provider-key.ts +++ b/src/memory/provider-key.ts @@ -5,6 +5,7 @@ export function computeEmbeddingProviderKey(params: { providerId: string; providerModel: string; openAi?: { baseUrl: string; model: string; headers: Record }; + gemini?: { baseUrl: string; model: string; headers: Record }; }): string { if (params.openAi) { const headerNames = fingerprintHeaderNames(params.openAi.headers); @@ -17,5 +18,16 @@ export function computeEmbeddingProviderKey(params: { }), ); } + if (params.gemini) { + const headerNames = fingerprintHeaderNames(params.gemini.headers); + return hashText( + JSON.stringify({ + provider: "gemini", + baseUrl: params.gemini.baseUrl, + model: params.gemini.model, + headerNames, + }), + ); + } return hashText(JSON.stringify({ provider: params.providerId, model: params.providerModel })); }