From befa421a57cf4f83cb220e3032e1847a6d12f365 Mon Sep 17 00:00:00 2001 From: Shakker Date: Mon, 2 Feb 2026 23:46:34 +0000 Subject: [PATCH] Agents: flush pending tool results on drop --- src/agents/session-tool-result-guard.test.ts | 144 ++++++++++++------- src/agents/session-tool-result-guard.ts | 3 + 2 files changed, 99 insertions(+), 48 deletions(-) diff --git a/src/agents/session-tool-result-guard.test.ts b/src/agents/session-tool-result-guard.test.ts index b5780b2d05..9f0959b6a9 100644 --- a/src/agents/session-tool-result-guard.test.ts +++ b/src/agents/session-tool-result-guard.test.ts @@ -3,10 +3,14 @@ import { SessionManager } from "@mariozechner/pi-coding-agent"; import { describe, expect, it } from "vitest"; import { installSessionToolResultGuard } from "./session-tool-result-guard.js"; -const toolCallMessage = { +type AppendMessage = Parameters<SessionManager["appendMessage"]>[0]; + +const asAppendMessage = (message: unknown) => message as AppendMessage; + +const toolCallMessage = asAppendMessage({ role: "assistant", content: [{ type: "toolCall", id: "call_1", name: "read", arguments: {} }], -} satisfies AgentMessage; +}); describe("installSessionToolResultGuard", () => { it("inserts synthetic toolResult before non-tool message when pending", () => { @@ -14,11 +18,13 @@ describe("installSessionToolResultGuard", () => { installSessionToolResultGuard(sm); sm.appendMessage(toolCallMessage); - sm.appendMessage({ - role: "assistant", - content: [{ type: "text", text: "error" }], - stopReason: "error", - } as AgentMessage); + sm.appendMessage( + asAppendMessage({ + role: "assistant", + content: [{ type: "text", text: "error" }], + stopReason: "error", + }), + ); const entries = sm .getEntries() @@ -56,12 +62,14 @@ describe("installSessionToolResultGuard", () => { installSessionToolResultGuard(sm); sm.appendMessage(toolCallMessage); - sm.appendMessage({ - role: "toolResult", - toolCallId: "call_1", - content: [{ type: "text", text: "ok" }], - isError: false, - } as AgentMessage); + 
sm.appendMessage( + asAppendMessage({ + role: "toolResult", + toolCallId: "call_1", + content: [{ type: "text", text: "ok" }], + isError: false, + }), + ); const messages = sm .getEntries() @@ -75,23 +83,29 @@ describe("installSessionToolResultGuard", () => { const sm = SessionManager.inMemory(); const guard = installSessionToolResultGuard(sm); - sm.appendMessage({ - role: "assistant", - content: [ - { type: "toolCall", id: "call_a", name: "one", arguments: {} }, - { type: "toolUse", id: "call_b", name: "two", arguments: {} }, - ], - } as AgentMessage); - sm.appendMessage({ - role: "toolResult", - toolUseId: "call_a", - content: [{ type: "text", text: "a" }], - isError: false, - } as AgentMessage); - sm.appendMessage({ - role: "assistant", - content: [{ type: "text", text: "after tools" }], - } as AgentMessage); + sm.appendMessage( + asAppendMessage({ + role: "assistant", + content: [ + { type: "toolCall", id: "call_a", name: "one", arguments: {} }, + { type: "toolUse", id: "call_b", name: "two", arguments: {} }, + ], + }), + ); + sm.appendMessage( + asAppendMessage({ + role: "toolResult", + toolUseId: "call_a", + content: [{ type: "text", text: "a" }], + isError: false, + }), + ); + sm.appendMessage( + asAppendMessage({ + role: "assistant", + content: [{ type: "text", text: "after tools" }], + }), + ); const messages = sm .getEntries() @@ -113,11 +127,13 @@ describe("installSessionToolResultGuard", () => { const guard = installSessionToolResultGuard(sm); sm.appendMessage(toolCallMessage); - sm.appendMessage({ - role: "assistant", - content: [{ type: "text", text: "hard error" }], - stopReason: "error", - } as AgentMessage); + sm.appendMessage( + asAppendMessage({ + role: "assistant", + content: [{ type: "text", text: "hard error" }], + stopReason: "error", + }), + ); expect(guard.getPendingIds()).toEqual([]); }); @@ -125,15 +141,19 @@ describe("installSessionToolResultGuard", () => { const sm = SessionManager.inMemory(); installSessionToolResultGuard(sm); - 
sm.appendMessage({ - role: "assistant", - content: [{ type: "toolUse", id: "use_1", name: "f", arguments: {} }], - } as AgentMessage); - sm.appendMessage({ - role: "toolResult", - toolUseId: "use_1", - content: [{ type: "text", text: "ok" }], - } as AgentMessage); + sm.appendMessage( + asAppendMessage({ + role: "assistant", + content: [{ type: "toolUse", id: "use_1", name: "f", arguments: {} }], + }), + ); + sm.appendMessage( + asAppendMessage({ + role: "toolResult", + toolUseId: "use_1", + content: [{ type: "text", text: "ok" }], + }), + ); const messages = sm .getEntries() @@ -146,10 +166,12 @@ describe("installSessionToolResultGuard", () => { const sm = SessionManager.inMemory(); installSessionToolResultGuard(sm); - sm.appendMessage({ - role: "assistant", - content: [{ type: "toolCall", id: "call_1", name: "read" }], - } as AgentMessage); + sm.appendMessage( + asAppendMessage({ + role: "assistant", + content: [{ type: "toolCall", id: "call_1", name: "read" }], + }), + ); const messages = sm .getEntries() @@ -158,4 +180,30 @@ describe("installSessionToolResultGuard", () => { expect(messages).toHaveLength(0); }); + + it("flushes pending tool results when a sanitized assistant message is dropped", () => { + const sm = SessionManager.inMemory(); + installSessionToolResultGuard(sm); + + sm.appendMessage( + asAppendMessage({ + role: "assistant", + content: [{ type: "toolCall", id: "call_1", name: "read", arguments: {} }], + }), + ); + + sm.appendMessage( + asAppendMessage({ + role: "assistant", + content: [{ type: "toolCall", id: "call_2", name: "read" }], + }), + ); + + const messages = sm + .getEntries() + .filter((e) => e.type === "message") + .map((e) => (e as { message: AgentMessage }).message); + + expect(messages.map((m) => m.role)).toEqual(["assistant", "toolResult"]); + }); }); diff --git a/src/agents/session-tool-result-guard.ts b/src/agents/session-tool-result-guard.ts index d7810e2cef..ea0152ac76 100644 --- a/src/agents/session-tool-result-guard.ts +++ 
b/src/agents/session-tool-result-guard.ts @@ -101,6 +101,9 @@ export function installSessionToolResultGuard( if (role === "assistant") { const sanitized = sanitizeToolCallInputs([message]); if (sanitized.length === 0) { + if (allowSyntheticToolResults && pending.size > 0) { + flushPendingToolResults(); + } return undefined; } nextMessage = sanitized[0];