- Fix data race on output strings.Builder in codex backend by adding a mutex and waiting for the reader goroutine before reading the final output
- Fix data race on onTurnDone by initializing it before the reader starts
- Fix bug where the notificationProtocol zero value "" never matched "unknown", silently dropping all raw v2 notifications from codex
- Add round-robin polling to prevent runtime starvation in the poll loop
- Log errors in claude handleControlRequest instead of silently dropping them
- Add 35 tests for pkg/agent covering claude parsing, codex JSON-RPC, protocol detection, event handling, and helper functions

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
348 lines
8.6 KiB
Go
348 lines
8.6 KiB
Go
package agent
|
|
|
|
import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"os"
	"os/exec"
	"strings"
	"time"
)
|
|
|
|
// claudeBackend implements Backend by spawning the Claude Code CLI
|
|
// with --output-format stream-json.
|
|
type claudeBackend struct {
|
|
cfg Config
|
|
}
|
|
|
|
func (b *claudeBackend) Execute(ctx context.Context, prompt string, opts ExecOptions) (*Session, error) {
|
|
execPath := b.cfg.ExecutablePath
|
|
if execPath == "" {
|
|
execPath = "claude"
|
|
}
|
|
if _, err := exec.LookPath(execPath); err != nil {
|
|
return nil, fmt.Errorf("claude executable not found at %q: %w", execPath, err)
|
|
}
|
|
|
|
timeout := opts.Timeout
|
|
if timeout == 0 {
|
|
timeout = 20 * time.Minute
|
|
}
|
|
runCtx, cancel := context.WithTimeout(ctx, timeout)
|
|
|
|
args := []string{
|
|
"--output-format", "stream-json",
|
|
"--verbose",
|
|
"--permission-mode", "bypassPermissions",
|
|
}
|
|
if opts.Model != "" {
|
|
args = append(args, "--model", opts.Model)
|
|
}
|
|
if opts.MaxTurns > 0 {
|
|
args = append(args, "--max-turns", fmt.Sprintf("%d", opts.MaxTurns))
|
|
}
|
|
if opts.SystemPrompt != "" {
|
|
args = append(args, "--append-system-prompt", opts.SystemPrompt)
|
|
}
|
|
args = append(args, "-p", prompt)
|
|
|
|
cmd := exec.CommandContext(runCtx, execPath, args...)
|
|
if opts.Cwd != "" {
|
|
cmd.Dir = opts.Cwd
|
|
}
|
|
cmd.Env = buildEnv(b.cfg.Env)
|
|
|
|
stdout, err := cmd.StdoutPipe()
|
|
if err != nil {
|
|
cancel()
|
|
return nil, fmt.Errorf("claude stdout pipe: %w", err)
|
|
}
|
|
stdin, err := cmd.StdinPipe()
|
|
if err != nil {
|
|
cancel()
|
|
return nil, fmt.Errorf("claude stdin pipe: %w", err)
|
|
}
|
|
cmd.Stderr = newLogWriter(b.cfg.Logger, "[claude:stderr] ")
|
|
|
|
if err := cmd.Start(); err != nil {
|
|
cancel()
|
|
return nil, fmt.Errorf("start claude: %w", err)
|
|
}
|
|
|
|
b.cfg.Logger.Printf("[claude] started pid=%d cwd=%s model=%s", cmd.Process.Pid, opts.Cwd, opts.Model)
|
|
|
|
msgCh := make(chan Message, 256)
|
|
resCh := make(chan Result, 1)
|
|
|
|
go func() {
|
|
defer cancel()
|
|
defer close(msgCh)
|
|
defer close(resCh)
|
|
defer stdin.Close()
|
|
|
|
startTime := time.Now()
|
|
var output strings.Builder
|
|
var sessionID string
|
|
finalStatus := "completed"
|
|
var finalError string
|
|
|
|
scanner := bufio.NewScanner(stdout)
|
|
scanner.Buffer(make([]byte, 0, 1024*1024), 10*1024*1024)
|
|
|
|
for scanner.Scan() {
|
|
line := strings.TrimSpace(scanner.Text())
|
|
if line == "" {
|
|
continue
|
|
}
|
|
|
|
var msg claudeSDKMessage
|
|
if err := json.Unmarshal([]byte(line), &msg); err != nil {
|
|
continue
|
|
}
|
|
|
|
switch msg.Type {
|
|
case "assistant":
|
|
b.handleAssistant(msg, msgCh, &output)
|
|
case "user":
|
|
b.handleUser(msg, msgCh)
|
|
case "system":
|
|
if msg.SessionID != "" {
|
|
sessionID = msg.SessionID
|
|
}
|
|
trySend(msgCh, Message{Type: MessageStatus, Status: "running"})
|
|
case "result":
|
|
sessionID = msg.SessionID
|
|
if msg.ResultText != "" {
|
|
output.Reset()
|
|
output.WriteString(msg.ResultText)
|
|
}
|
|
if msg.IsError {
|
|
finalStatus = "failed"
|
|
finalError = msg.ResultText
|
|
}
|
|
case "log":
|
|
if msg.Log != nil {
|
|
trySend(msgCh, Message{
|
|
Type: MessageLog,
|
|
Level: msg.Log.Level,
|
|
Content: msg.Log.Message,
|
|
})
|
|
}
|
|
case "control_request":
|
|
b.handleControlRequest(msg, stdin)
|
|
}
|
|
}
|
|
|
|
// Wait for process exit
|
|
exitErr := cmd.Wait()
|
|
duration := time.Since(startTime)
|
|
|
|
if runCtx.Err() == context.DeadlineExceeded {
|
|
finalStatus = "timeout"
|
|
finalError = fmt.Sprintf("claude timed out after %s", timeout)
|
|
} else if runCtx.Err() == context.Canceled {
|
|
finalStatus = "aborted"
|
|
finalError = "execution cancelled"
|
|
} else if exitErr != nil && finalStatus == "completed" {
|
|
finalStatus = "failed"
|
|
finalError = fmt.Sprintf("claude exited with error: %v", exitErr)
|
|
}
|
|
|
|
b.cfg.Logger.Printf("[claude] finished pid=%d status=%s duration=%s",
|
|
cmd.Process.Pid, finalStatus, duration.Round(time.Millisecond))
|
|
|
|
resCh <- Result{
|
|
Status: finalStatus,
|
|
Output: output.String(),
|
|
Error: finalError,
|
|
DurationMs: duration.Milliseconds(),
|
|
SessionID: sessionID,
|
|
}
|
|
}()
|
|
|
|
return &Session{Messages: msgCh, Result: resCh}, nil
|
|
}
|
|
|
|
func (b *claudeBackend) handleAssistant(msg claudeSDKMessage, ch chan<- Message, output *strings.Builder) {
|
|
var content claudeMessageContent
|
|
if err := json.Unmarshal(msg.Message, &content); err != nil {
|
|
return
|
|
}
|
|
|
|
for _, block := range content.Content {
|
|
switch block.Type {
|
|
case "text":
|
|
if block.Text != "" {
|
|
output.WriteString(block.Text)
|
|
trySend(ch, Message{Type: MessageText, Content: block.Text})
|
|
}
|
|
case "tool_use":
|
|
var input map[string]any
|
|
if block.Input != nil {
|
|
_ = json.Unmarshal(block.Input, &input)
|
|
}
|
|
trySend(ch, Message{
|
|
Type: MessageToolUse,
|
|
Tool: block.Name,
|
|
CallID: block.ID,
|
|
Input: input,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
func (b *claudeBackend) handleUser(msg claudeSDKMessage, ch chan<- Message) {
|
|
var content claudeMessageContent
|
|
if err := json.Unmarshal(msg.Message, &content); err != nil {
|
|
return
|
|
}
|
|
|
|
for _, block := range content.Content {
|
|
if block.Type == "tool_result" {
|
|
resultStr := ""
|
|
if block.Content != nil {
|
|
resultStr = string(block.Content)
|
|
}
|
|
trySend(ch, Message{
|
|
Type: MessageToolResult,
|
|
CallID: block.ToolUseID,
|
|
Output: resultStr,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
func (b *claudeBackend) handleControlRequest(msg claudeSDKMessage, stdin interface{ Write([]byte) (int, error) }) {
|
|
// Auto-approve all tool uses in autonomous/daemon mode.
|
|
var req claudeControlRequestPayload
|
|
if err := json.Unmarshal(msg.Request, &req); err != nil {
|
|
return
|
|
}
|
|
|
|
var inputMap map[string]any
|
|
if req.Input != nil {
|
|
_ = json.Unmarshal(req.Input, &inputMap)
|
|
}
|
|
if inputMap == nil {
|
|
inputMap = map[string]any{}
|
|
}
|
|
|
|
response := map[string]any{
|
|
"type": "control_response",
|
|
"response": map[string]any{
|
|
"subtype": "success",
|
|
"request_id": msg.RequestID,
|
|
"response": map[string]any{
|
|
"behavior": "allow",
|
|
"updatedInput": inputMap,
|
|
},
|
|
},
|
|
}
|
|
|
|
data, err := json.Marshal(response)
|
|
if err != nil {
|
|
b.cfg.Logger.Printf("[claude] failed to marshal control response: %v", err)
|
|
return
|
|
}
|
|
data = append(data, '\n')
|
|
if _, err := stdin.Write(data); err != nil {
|
|
b.cfg.Logger.Printf("[claude] failed to write control response: %v", err)
|
|
}
|
|
}
|
|
|
|
// ── Claude SDK JSON types ──
|
|
|
|
type claudeSDKMessage struct {
|
|
Type string `json:"type"`
|
|
Message json.RawMessage `json:"message,omitempty"`
|
|
Subtype string `json:"subtype,omitempty"`
|
|
SessionID string `json:"session_id,omitempty"`
|
|
|
|
// result fields
|
|
ResultText string `json:"result,omitempty"`
|
|
IsError bool `json:"is_error,omitempty"`
|
|
DurationMs float64 `json:"duration_ms,omitempty"`
|
|
NumTurns int `json:"num_turns,omitempty"`
|
|
|
|
// log fields
|
|
Log *claudeLogEntry `json:"log,omitempty"`
|
|
|
|
// control request fields
|
|
RequestID string `json:"request_id,omitempty"`
|
|
Request json.RawMessage `json:"request,omitempty"`
|
|
}
|
|
|
|
// claudeLogEntry is the payload of a "log" event. Execute forwards it as a
// MessageLog whose Level is Level and whose Content is Message.
type claudeLogEntry struct {
	Level   string `json:"level"`
	Message string `json:"message"`
}
|
|
|
|
type claudeMessageContent struct {
|
|
Role string `json:"role"`
|
|
Content []claudeContentBlock `json:"content"`
|
|
}
|
|
|
|
// claudeContentBlock is one entry of claudeMessageContent.Content. The fields
// read depend on Type:
//   - "text": Text
//   - "tool_use": ID, Name and Input (raw JSON arguments)
//   - "tool_result": ToolUseID and Content (raw JSON result)
type claudeContentBlock struct {
	Type      string          `json:"type"`
	Text      string          `json:"text,omitempty"`
	ID        string          `json:"id,omitempty"`
	Name      string          `json:"name,omitempty"`
	Input     json.RawMessage `json:"input,omitempty"`
	ToolUseID string          `json:"tool_use_id,omitempty"`
	Content   json.RawMessage `json:"content,omitempty"`
}
|
|
|
|
// claudeControlRequestPayload is the decoded "request" object of a
// control_request event. In this file only Input is consulted: it is echoed
// back as updatedInput. Subtype and ToolName are decoded but not acted on.
type claudeControlRequestPayload struct {
	Subtype  string          `json:"subtype"`
	ToolName string          `json:"tool_name,omitempty"`
	Input    json.RawMessage `json:"input,omitempty"`
}
|
|
|
|
// ── Shared helpers ──
|
|
|
|
func trySend(ch chan<- Message, msg Message) {
|
|
select {
|
|
case ch <- msg:
|
|
default:
|
|
// Channel full — drop message. Final output is accumulated separately
|
|
// in Result.Output, so only streaming consumers are affected.
|
|
}
|
|
}
|
|
|
|
// buildEnv returns the current process environment followed by one
// "KEY=VALUE" entry per element of extra. The order among the extra entries
// is unspecified (map iteration). Because they come last, they override
// inherited variables of the same name when the result is used as
// exec.Cmd.Env, which keeps the last value for duplicate keys.
func buildEnv(extra map[string]string) []string {
	inherited := os.Environ()
	merged := make([]string, 0, len(inherited)+len(extra))
	merged = append(merged, inherited...)
	for key, val := range extra {
		merged = append(merged, key+"="+val)
	}
	return merged
}
|
|
|
|
// detectCLIVersion runs "<execPath> --version", bounded by ctx, and returns
// its standard output with surrounding whitespace trimmed. Failure to start
// the command, or a non-zero exit, is returned wrapped with execPath.
func detectCLIVersion(ctx context.Context, execPath string) (string, error) {
	raw, runErr := exec.CommandContext(ctx, execPath, "--version").Output()
	if runErr == nil {
		return strings.TrimSpace(string(raw)), nil
	}
	return "", fmt.Errorf("detect version for %s: %w", execPath, runErr)
}
|
|
|
|
// logWriter adapts a *log.Logger to an io.Writer for capturing stderr.
// Every non-blank line of a written chunk becomes its own log entry, trimmed
// of surrounding whitespace and preceded by prefix.
type logWriter struct {
	logger *log.Logger
	prefix string
}

// newLogWriter returns a logWriter that logs to logger, starting each entry
// with prefix.
func newLogWriter(logger *log.Logger, prefix string) *logWriter {
	return &logWriter{logger: logger, prefix: prefix}
}

// Write logs p one line at a time and always reports len(p) bytes written
// with a nil error. Logging therefore never turns into a short-write error for
// whatever is copying the child's stderr into this writer.
//
// Splitting is per call, with no buffering across calls: a line delivered
// across two writes is logged as two entries, but no output is held back
// waiting for a newline that may never come.
func (w *logWriter) Write(p []byte) (int, error) {
	for _, line := range strings.Split(string(p), "\n") {
		if text := strings.TrimSpace(line); text != "" {
			w.logger.Printf("%s%s", w.prefix, text)
		}
	}
	return len(p), nil
}
|