From 6bfe8365599d1c6c3753541f13de8a3808dcae5a Mon Sep 17 00:00:00 2001 From: Jiang Bohan Date: Thu, 5 Feb 2026 14:57:33 +0800 Subject: [PATCH] feat(tools): add keyword-based memory_search tool Implements a simple memory_search tool for searching memory files: - Searches memory.md and memory/*.md files by keyword - Returns matching lines with context (2 lines before/after) - Supports case-sensitive/insensitive search - Respects maxResults limit Tool is only available when a profile is active (has profileDir). System prompt includes memory usage guidance when tool is present. Co-Authored-By: Claude Opus 4.5 --- apps/desktop/electron/ipc/agent.ts | 2 + src/agent/runner.ts | 9 +- src/agent/system-prompt/sections.ts | 17 ++ src/agent/tools.ts | 20 +- src/agent/tools/README.zh-CN.md | 26 +-- src/agent/tools/groups.ts | 3 + src/agent/tools/memory-search.test.ts | 154 ++++++++++++++ src/agent/tools/memory-search.ts | 276 ++++++++++++++++++++++++++ 8 files changed, 490 insertions(+), 17 deletions(-) create mode 100644 src/agent/tools/memory-search.test.ts create mode 100644 src/agent/tools/memory-search.ts diff --git a/apps/desktop/electron/ipc/agent.ts b/apps/desktop/electron/ipc/agent.ts index 575a290b..385e9391 100644 --- a/apps/desktop/electron/ipc/agent.ts +++ b/apps/desktop/electron/ipc/agent.ts @@ -12,6 +12,7 @@ const TOOL_GROUPS: Record = { 'group:fs': ['read', 'write', 'edit', 'glob'], 'group:runtime': ['exec', 'process'], 'group:web': ['web_search', 'web_fetch'], + 'group:memory': ['memory_search'], 'group:subagent': ['sessions_spawn'], } @@ -20,6 +21,7 @@ const ALL_KNOWN_TOOLS = [ ...TOOL_GROUPS['group:fs'], ...TOOL_GROUPS['group:runtime'], ...TOOL_GROUPS['group:web'], + ...TOOL_GROUPS['group:memory'], ...TOOL_GROUPS['group:subagent'], ] diff --git a/src/agent/runner.ts b/src/agent/runner.ts index 00c0e3b2..606ad506 100644 --- a/src/agent/runner.ts +++ b/src/agent/runner.ts @@ -2,7 +2,7 @@ import { Agent as PiAgentCore, type AgentEvent, type AgentMessage } from "@mario import { v7 as uuidv7 } from "uuid"; import type { AgentOptions, AgentRunResult, ReasoningMode } from "./types.js"; import { createAgentOutput } from "./cli/output.js"; -import { resolveModel, resolveTools } from "./tools.js"; +import { resolveModel, resolveTools, type ResolveToolsOptions } from "./tools.js"; import { resolveApiKey, resolveApiKeyForProfile, @@ -78,7 +78,7 @@ export class Agent { private readonly contextWindowGuard: ContextWindowGuardResult; private readonly debug: boolean; private reasoningMode: ReasoningMode; - private toolsOptions: AgentOptions; + private toolsOptions: ResolveToolsOptions; private readonly originalToolsConfig?: ToolsConfig; private readonly stderr: NodeJS.WritableStream; private initialized = false; @@ -280,7 +280,10 @@ export class Agent { // Merge Profile tools config with options.tools (options takes precedence) const profileToolsConfig = this.profile?.getToolsConfig(); const mergedToolsConfig = mergeToolsConfig(profileToolsConfig, options.tools); - this.toolsOptions = mergedToolsConfig ? { ...options, tools: mergedToolsConfig } : options; + const profileDir = this.profile?.getProfileDir(); + this.toolsOptions = mergedToolsConfig + ? { ...options, tools: mergedToolsConfig, profileDir } + : { ...options, profileDir }; const tools = resolveTools(this.toolsOptions); if (this.debug) { diff --git a/src/agent/system-prompt/sections.ts b/src/agent/system-prompt/sections.ts index 03008c3a..b6bd7be8 100644 --- a/src/agent/system-prompt/sections.ts +++ b/src/agent/system-prompt/sections.ts @@ -20,6 +20,7 @@ const CORE_TOOL_SUMMARIES: Record = { process: "Manage background exec sessions", web_search: "Search the web", web_fetch: "Fetch and extract readable content from a URL", + memory_search: "Search memory files by keyword", sessions_spawn: "Spawn a sub-agent session", }; @@ -33,6 +34,7 @@ const TOOL_ORDER = [ "process", "web_search", "web_fetch", + "memory_search", "sessions_spawn", ]; @@ -208,6 +210,21 @@ export function buildConditionalToolSections( const toolSet = new Set(tools.map((t) => t.toLowerCase())); const lines: string[] = []; + // Memory tools + if (toolSet.has("memory_search")) { + lines.push( + "## Memory", + "Before answering questions about prior work, decisions, dates, people, preferences, or todos:", + "1. Use `memory_search` to find relevant entries in memory files", + "2. Use `read` to get full context from matching files", + "", + "To update memory, use `edit` on the appropriate file:", + "- `memory.md` — Long-term knowledge (decisions, preferences, important context)", + "- `memory/YYYY-MM-DD.md` — Daily logs and session notes", + "", + ); + } + // Subagent tools (full mode only — minimal agents cannot spawn) if (mode === "full" && toolSet.has("sessions_spawn")) { lines.push( diff --git a/src/agent/tools.ts b/src/agent/tools.ts index caa1e5fb..f542e65f 100644 --- a/src/agent/tools.ts +++ b/src/agent/tools.ts @@ -7,6 +7,7 @@ import { createProcessTool } from "./tools/process.js"; import { createGlobTool } from "./tools/glob.js"; import { createWebFetchTool, createWebSearchTool } from "./tools/web/index.js"; import { createSessionsSpawnTool } from "./tools/sessions-spawn.js"; +import { createMemorySearchTool } from "./tools/memory-search.js"; import { filterTools } from "./tools/policy.js"; import { isMulticaError, isRetryableError } from "../shared/errors.js"; @@ -16,6 +17,8 @@ export { resolveModel } from "./providers/index.js"; /** Options for creating tools */ export interface CreateToolsOptions { cwd: string; + /** Profile directory for memory_search tool (optional) */ + profileDir?: string | undefined; /** Whether this agent is a subagent (passed to sessions_spawn tool) */ isSubagent?: boolean | undefined; /** Session ID of the agent (passed to sessions_spawn tool) */ @@ -89,7 +92,7 @@ function wrapTool( export function createAllTools(options: CreateToolsOptions | string): AgentTool[] { // Support legacy string argument for backwards compatibility const opts: CreateToolsOptions = typeof options === "string" ? { cwd: options } : options; - const { cwd, isSubagent, sessionId } = opts; + const { cwd, profileDir, isSubagent, sessionId } = opts; const baseTools = createCodingTools(cwd).filter( (tool) => tool.name !== "bash", @@ -110,6 +113,12 @@ export function createAllTools(options: CreateToolsOptions | string): AgentTool< webSearchTool as AgentTool, ]; + // Add memory_search tool if profileDir is provided + if (profileDir) { + const memorySearchTool = createMemorySearchTool(profileDir); + tools.push(memorySearchTool as AgentTool); + } + // Add sessions_spawn tool (will be filtered by policy for subagents) const sessionsSpawnTool = createSessionsSpawnTool({ isSubagent: isSubagent ?? false, @@ -120,6 +129,12 @@ export function createAllTools(options: CreateToolsOptions | string): AgentTool< return tools; } +/** Extended options for resolveTools that includes profileDir */ +export interface ResolveToolsOptions extends AgentOptions { + /** Profile directory for memory_search tool (computed from profileId if not provided) */ + profileDir?: string | undefined; +} + /** * Resolve tools for an agent with policy filtering. * @@ -129,12 +144,13 @@ export function createAllTools(options: CreateToolsOptions | string): AgentTool< * 3. Provider-specific rules * 4. Subagent restrictions */ -export function resolveTools(options: AgentOptions): AgentTool[] { +export function resolveTools(options: ResolveToolsOptions): AgentTool[] { const cwd = options.cwd ?? process.cwd(); // Create all tools const allTools = createAllTools({ cwd, + profileDir: options.profileDir, isSubagent: options.isSubagent, sessionId: options.sessionId, }); diff --git a/src/agent/tools/README.zh-CN.md b/src/agent/tools/README.zh-CN.md index cb759660..675fc6e2 100644 --- a/src/agent/tools/README.zh-CN.md +++ b/src/agent/tools/README.zh-CN.md @@ -49,19 +49,20 @@ ## 可用工具 -| 工具 | 名称 | 描述 | -| -------------- | ---------------- | ------------------------ | -| Read | `read` | 读取文件内容 | -| Write | `write` | 写入文件内容 | -| Edit | `edit` | 编辑现有文件 | -| Glob | `glob` | 按模式查找文件 | -| Exec | `exec` | 执行 Shell 命令 | -| Process | `process` | 管理长时间运行的进程 | -| Web Fetch | `web_fetch` | 从 URL 获取并提取内容 | -| Web Search | `web_search` | 搜索网络(需要 API Key) | -| Sessions Spawn | `sessions_spawn` | 创建子 Agent 会话 | +| 工具 | 名称 | 描述 | +| -------------- | ---------------- | ------------------------------ | +| Read | `read` | 读取文件内容 | +| Write | `write` | 写入文件内容 | +| Edit | `edit` | 编辑现有文件 | +| Glob | `glob` | 按模式查找文件 | +| Exec | `exec` | 执行 Shell 命令 | +| Process | `process` | 管理长时间运行的进程 | +| Web Fetch | `web_fetch` | 从 URL 获取并提取内容 | +| Web Search | `web_search` | 搜索网络(需要 API Key) | +| Memory Search | `memory_search` | 搜索 memory 文件(需要 Profile)| +| Sessions Spawn | `sessions_spawn` | 创建子 Agent 会话 | -> **注意**: Agent 使用基于文件的 memory(`memory.md`、`memory/*.md`),通过 `read` 和 `edit` 工具操作,而非专门的 memory 工具。 +> **注意**: `memory_search` 工具通过关键词搜索 `memory.md` 和 `memory/*.md` 文件。Agent 通过 `read` 和 `edit` 工具操作 memory 文件内容。 ## 工具组 @@ -72,6 +73,7 @@ | `group:fs` | read, write, edit, glob | | `group:runtime` | exec, process | | `group:web` | web_search, web_fetch | +| `group:memory` | memory_search | | `group:subagent` | sessions_spawn | | `group:core` | 所有 fs、runtime 和 web 工具 | diff --git a/src/agent/tools/groups.ts b/src/agent/tools/groups.ts index 56d793bb..821cee91 100644 --- a/src/agent/tools/groups.ts +++ b/src/agent/tools/groups.ts @@ -30,6 +30,9 @@ export const TOOL_GROUPS: Record = { // Web tools "group:web": ["web_search", "web_fetch"], + // Memory tools (requires profile) + "group:memory": ["memory_search"], + // Subagent tools "group:subagent": ["sessions_spawn"], diff --git a/src/agent/tools/memory-search.test.ts b/src/agent/tools/memory-search.test.ts new file mode 100644 index 00000000..db955bb6 --- /dev/null +++ b/src/agent/tools/memory-search.test.ts @@ -0,0 +1,154 @@ +import { describe, it, expect, beforeEach, afterEach } from "vitest"; +import { mkdirSync, writeFileSync, rmSync } from "fs"; +import { join } from "path"; +import { tmpdir } from "os"; +import { createMemorySearchTool } from "./memory-search.js"; + +describe("memory_search tool", () => { + let testDir: string; + + beforeEach(() => { + testDir = join(tmpdir(), `memory-search-test-${Date.now()}`); + mkdirSync(testDir, { recursive: true }); + }); + + afterEach(() => { + rmSync(testDir, { recursive: true, force: true }); + }); + + it("creates tool with correct name and description", () => { + const tool = createMemorySearchTool(testDir); + expect(tool.name).toBe("memory_search"); + expect(tool.label).toBe("Memory Search"); + expect(tool.description).toContain("memory files"); + }); + + it("returns no matches when no memory files exist", async () => { + const tool = createMemorySearchTool(testDir); + const result = await tool.execute("test-call", { query: "test" }, undefined); + expect(result.details?.matches).toHaveLength(0); + expect(result.details?.filesSearched).toBe(0); + }); + + it("searches memory.md file", async () => { + // Create memory.md with test content + writeFileSync( + join(testDir, "memory.md"), + "# Memory\n\nUser prefers TypeScript over JavaScript.\n\nDecision: Use ESLint for linting.\n", + ); + + const tool = createMemorySearchTool(testDir); + const result = await tool.execute("test-call", { query: "TypeScript" }, undefined); + + expect(result.details?.matches).toHaveLength(1); + expect(result.details?.matches[0]?.file).toBe("memory.md"); + expect(result.details?.matches[0]?.content).toContain("TypeScript"); + }); + + it("searches memory/*.md files", async () => { + // Create memory directory with daily logs + const memoryDir = join(testDir, "memory"); + mkdirSync(memoryDir); + writeFileSync( + join(memoryDir, "2024-01-15.md"), + "# 2024-01-15\n\nDiscussed API design with team.\n", + ); + writeFileSync( + join(memoryDir, "2024-01-16.md"), + "# 2024-01-16\n\nImplemented user authentication.\n", + ); + + const tool = createMemorySearchTool(testDir); + const result = await tool.execute("test-call", { query: "API" }, undefined); + + expect(result.details?.matches).toHaveLength(1); + expect(result.details?.matches[0]?.file).toBe("memory/2024-01-15.md"); + }); + + it("searches both memory.md and memory/*.md", async () => { + // Create memory.md + writeFileSync(join(testDir, "memory.md"), "Important: Always test code.\n"); + + // Create memory directory + const memoryDir = join(testDir, "memory"); + mkdirSync(memoryDir); + writeFileSync(join(memoryDir, "2024-01-15.md"), "Remember to test before deploy.\n"); + + const tool = createMemorySearchTool(testDir); + const result = await tool.execute("test-call", { query: "test" }, undefined); + + expect(result.details?.matches).toHaveLength(2); + expect(result.details?.filesSearched).toBe(2); + }); + + it("is case-insensitive by default", async () => { + writeFileSync(join(testDir, "memory.md"), "User prefers TYPESCRIPT.\n"); + + const tool = createMemorySearchTool(testDir); + const result = await tool.execute("test-call", { query: "typescript" }, undefined); + + expect(result.details?.matches).toHaveLength(1); + }); + + it("supports case-sensitive search", async () => { + writeFileSync(join(testDir, "memory.md"), "User prefers TYPESCRIPT.\n"); + + const tool = createMemorySearchTool(testDir); + + // Case-sensitive search should not match + const result1 = await tool.execute( + "test-call", + { query: "typescript", caseSensitive: true }, + undefined, + ); + expect(result1.details?.matches).toHaveLength(0); + + // Case-sensitive search should match + const result2 = await tool.execute( + "test-call", + { query: "TYPESCRIPT", caseSensitive: true }, + undefined, + ); + expect(result2.details?.matches).toHaveLength(1); + }); + + it("includes context lines in results", async () => { + writeFileSync( + join(testDir, "memory.md"), + "Line 1\nLine 2\nMatch here\nLine 4\nLine 5\n", + ); + + const tool = createMemorySearchTool(testDir); + const result = await tool.execute("test-call", { query: "Match" }, undefined); + + expect(result.details?.matches).toHaveLength(1); + expect(result.details?.matches[0]?.context.before).toContain("Line 2"); + expect(result.details?.matches[0]?.context.after).toContain("Line 4"); + }); + + it("respects maxResults limit", async () => { + // Create file with multiple matches + writeFileSync( + join(testDir, "memory.md"), + "test line 1\ntest line 2\ntest line 3\ntest line 4\ntest line 5\n", + ); + + const tool = createMemorySearchTool(testDir); + const result = await tool.execute( + "test-call", + { query: "test", maxResults: 2 }, + undefined, + ); + + expect(result.details?.matches).toHaveLength(2); + expect(result.details?.totalMatches).toBe(5); + expect(result.details?.truncated).toBe(true); + }); + + it("throws error for empty query", async () => { + const tool = createMemorySearchTool(testDir); + await expect(tool.execute("test-call", { query: "" }, undefined)).rejects.toThrow( + "Query must not be empty", + ); + }); +}); diff --git a/src/agent/tools/memory-search.ts b/src/agent/tools/memory-search.ts new file mode 100644 index 00000000..666f0426 --- /dev/null +++ b/src/agent/tools/memory-search.ts @@ -0,0 +1,276 @@ +import { Type } from "@sinclair/typebox"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; +import * as fs from "fs/promises"; +import * as path from "path"; +import fg from "fast-glob"; + +const MemorySearchSchema = Type.Object({ + query: Type.String({ + description: "Search query - keywords or phrases to find in memory files.", + }), + maxResults: Type.Optional( + Type.Number({ + description: "Maximum number of results to return. Defaults to 10.", + minimum: 1, + maximum: 50, + }), + ), + caseSensitive: Type.Optional( + Type.Boolean({ + description: "Whether the search is case-sensitive. Defaults to false.", + }), + ), +}); + +type MemorySearchArgs = { + query: string; + maxResults?: number; + caseSensitive?: boolean; +}; + +export type MemorySearchMatch = { + file: string; + line: number; + content: string; + context: { + before: string[]; + after: string[]; + }; +}; + +export type MemorySearchResult = { + matches: MemorySearchMatch[]; + totalMatches: number; + filesSearched: number; + truncated: boolean; +}; + +const DEFAULT_MAX_RESULTS = 10; +const CONTEXT_LINES = 2; + +/** + * Create a memory_search tool for searching memory files. + * + * @param profileDir - Profile directory containing memory.md and memory/ folder + */ +export function createMemorySearchTool( + profileDir: string, +): AgentTool { + return { + name: "memory_search", + label: "Memory Search", + description: + "Search through memory files (memory.md and memory/*.md) for keywords or phrases. " + + "Use this before answering questions about prior work, decisions, dates, people, preferences, or todos. " + + "Returns matching lines with context.", + parameters: MemorySearchSchema, + execute: async (_toolCallId, args, _signal) => { + const { query, maxResults, caseSensitive } = args as MemorySearchArgs; + + if (!query || query.trim() === "") { + throw new Error("Query must not be empty"); + } + + const limit = Math.min(maxResults || DEFAULT_MAX_RESULTS, 50); + const searchQuery = caseSensitive ? query : query.toLowerCase(); + + // Find all memory files + const memoryFiles = await findMemoryFiles(profileDir); + + if (memoryFiles.length === 0) { + return { + content: [{ type: "text", text: "No memory files found." }], + details: { + matches: [], + totalMatches: 0, + filesSearched: 0, + truncated: false, + }, + }; + } + + // Search each file + const allMatches: MemorySearchMatch[] = []; + + for (const file of memoryFiles) { + const matches = await searchFile(file, searchQuery, caseSensitive ?? false, profileDir); + allMatches.push(...matches); + } + + // Sort by relevance (files with more matches first, then by line number) + allMatches.sort((a, b) => { + if (a.file !== b.file) { + // Count matches per file + const aCount = allMatches.filter((m) => m.file === a.file).length; + const bCount = allMatches.filter((m) => m.file === b.file).length; + return bCount - aCount; + } + return a.line - b.line; + }); + + const totalMatches = allMatches.length; + const truncated = allMatches.length > limit; + const limitedMatches = allMatches.slice(0, limit); + + // Format output + const output = formatSearchResults(limitedMatches, totalMatches, truncated, memoryFiles.length); + + return { + content: [{ type: "text", text: output }], + details: { + matches: limitedMatches, + totalMatches, + filesSearched: memoryFiles.length, + truncated, + }, + }; + }, + }; +} + +/** + * Find all memory files in the profile directory. + */ +async function findMemoryFiles(profileDir: string): Promise { + const files: string[] = []; + + // Check for memory.md in profile root + const memoryMd = path.join(profileDir, "memory.md"); + try { + await fs.access(memoryMd); + files.push(memoryMd); + } catch { + // File doesn't exist + } + + // Check for memory/*.md files + const memoryDir = path.join(profileDir, "memory"); + try { + await fs.access(memoryDir); + const mdFiles = await fg("*.md", { + cwd: memoryDir, + onlyFiles: true, + absolute: true, + }); + files.push(...mdFiles); + } catch { + // Directory doesn't exist + } + + return files; +} + +/** + * Search a single file for the query. + */ +async function searchFile( + filePath: string, + query: string, + caseSensitive: boolean, + profileDir: string, +): Promise { + const matches: MemorySearchMatch[] = []; + + try { + const content = await fs.readFile(filePath, "utf-8"); + const lines = content.split("\n"); + + for (let i = 0; i < lines.length; i++) { + const line = lines[i]!; + const searchLine = caseSensitive ? line : line.toLowerCase(); + + if (searchLine.includes(query)) { + // Get context lines + const beforeLines: string[] = []; + const afterLines: string[] = []; + + for (let j = Math.max(0, i - CONTEXT_LINES); j < i; j++) { + beforeLines.push(lines[j]!); + } + + for (let j = i + 1; j <= Math.min(lines.length - 1, i + CONTEXT_LINES); j++) { + afterLines.push(lines[j]!); + } + + // Get relative path for display + const relativePath = path.relative(profileDir, filePath); + + matches.push({ + file: relativePath, + line: i + 1, // 1-indexed + content: line, + context: { + before: beforeLines, + after: afterLines, + }, + }); + } + } + } catch (err) { + // Skip files that can't be read + console.error(`Failed to read ${filePath}:`, err); + } + + return matches; +} + +/** + * Format search results for display. + */ +function formatSearchResults( + matches: MemorySearchMatch[], + totalMatches: number, + truncated: boolean, + filesSearched: number, +): string { + if (matches.length === 0) { + return `No matches found in ${filesSearched} memory file(s).`; + } + + const lines: string[] = []; + lines.push(`Found ${totalMatches} match(es) in ${filesSearched} file(s):`); + + if (truncated) { + lines.push(`(Showing first ${matches.length} results)`); + } + + lines.push(""); + + // Group by file + const byFile = new Map(); + for (const match of matches) { + const existing = byFile.get(match.file) || []; + existing.push(match); + byFile.set(match.file, existing); + } + + for (const [file, fileMatches] of byFile) { + lines.push(`## ${file}`); + lines.push(""); + + for (const match of fileMatches) { + lines.push(`**Line ${match.line}:**`); + + // Show context before + if (match.context.before.length > 0) { + for (const ctx of match.context.before) { + lines.push(` ${ctx}`); + } + } + + // Show matching line (highlighted) + lines.push(`> ${match.content}`); + + // Show context after + if (match.context.after.length > 0) { + for (const ctx of match.context.after) { + lines.push(` ${ctx}`); + } + } + + lines.push(""); + } + } + + return lines.join("\n"); +}