diff --git a/server/cmd/daemon/daemon.go b/server/cmd/daemon/daemon.go index ddc6e0a5..71f1c864 100644 --- a/server/cmd/daemon/daemon.go +++ b/server/cmd/daemon/daemon.go @@ -384,6 +384,7 @@ func (d *daemon) heartbeatLoop(ctx context.Context, runtimeIDs []string) { } func (d *daemon) pollLoop(ctx context.Context, runtimeIDs []string) error { + pollOffset := 0 for { select { case <-ctx.Done(): @@ -392,7 +393,9 @@ func (d *daemon) pollLoop(ctx context.Context, runtimeIDs []string) error { } claimed := false - for _, rid := range runtimeIDs { + n := len(runtimeIDs) + for i := 0; i < n; i++ { + rid := runtimeIDs[(pollOffset+i)%n] task, err := d.client.claimTask(ctx, rid) if err != nil { d.logger.Printf("claim task failed for runtime %s: %v", rid, err) @@ -402,11 +405,13 @@ func (d *daemon) pollLoop(ctx context.Context, runtimeIDs []string) error { d.logger.Printf("poll: got task=%s issue=%s title=%q", task.ID, task.IssueID, task.Context.Issue.Title) d.handleTask(ctx, *task) claimed = true + pollOffset = (pollOffset + i + 1) % n break } } if !claimed { + pollOffset = (pollOffset + 1) % n if err := sleepWithContext(ctx, d.cfg.PollInterval); err != nil { return err } diff --git a/server/pkg/agent/agent_test.go b/server/pkg/agent/agent_test.go new file mode 100644 index 00000000..7ddec759 --- /dev/null +++ b/server/pkg/agent/agent_test.go @@ -0,0 +1,53 @@ +package agent + +import ( + "context" + "testing" +) + +func TestNewReturnsClaudeBackend(t *testing.T) { + t.Parallel() + b, err := New("claude", Config{ExecutablePath: "/nonexistent/claude"}) + if err != nil { + t.Fatalf("New(claude) error: %v", err) + } + if _, ok := b.(*claudeBackend); !ok { + t.Fatalf("expected *claudeBackend, got %T", b) + } +} + +func TestNewReturnsCodexBackend(t *testing.T) { + t.Parallel() + b, err := New("codex", Config{ExecutablePath: "/nonexistent/codex"}) + if err != nil { + t.Fatalf("New(codex) error: %v", err) + } + if _, ok := b.(*codexBackend); !ok { + t.Fatalf("expected *codexBackend, got %T", b) + } +} + +func TestNewRejectsUnknownType(t *testing.T) { + t.Parallel() + _, err := New("gpt", Config{}) + if err == nil { + t.Fatal("expected error for unknown agent type") + } +} + +func TestNewDefaultsLogger(t *testing.T) { + t.Parallel() + b, _ := New("claude", Config{}) + cb := b.(*claudeBackend) + if cb.cfg.Logger == nil { + t.Fatal("expected non-nil logger") + } +} + +func TestDetectVersionFailsForMissingBinary(t *testing.T) { + t.Parallel() + _, err := DetectVersion(context.Background(), "/nonexistent/binary") + if err == nil { + t.Fatal("expected error for missing binary") + } +} diff --git a/server/pkg/agent/claude.go b/server/pkg/agent/claude.go index cb735936..1093b739 100644 --- a/server/pkg/agent/claude.go +++ b/server/pkg/agent/claude.go @@ -244,10 +244,13 @@ func (b *claudeBackend) handleControlRequest(msg claudeSDKMessage, stdin interfa data, err := json.Marshal(response) if err != nil { + b.cfg.Logger.Printf("[claude] failed to marshal control response: %v", err) return } data = append(data, '\n') - _, _ = stdin.Write(data) + if _, err := stdin.Write(data); err != nil { + b.cfg.Logger.Printf("[claude] failed to write control response: %v", err) + } } // ── Claude SDK JSON types ── diff --git a/server/pkg/agent/claude_test.go b/server/pkg/agent/claude_test.go new file mode 100644 index 00000000..8018a2bf --- /dev/null +++ b/server/pkg/agent/claude_test.go @@ -0,0 +1,229 @@ +package agent + +import ( + "bytes" + "encoding/json" + "log" + "strings" + "testing" +) + +func TestClaudeHandleAssistantText(t *testing.T) { + t.Parallel() + + b := &claudeBackend{cfg: Config{Logger: log.Default()}} + ch := make(chan Message, 10) + var output strings.Builder + + msg := claudeSDKMessage{ + Type: "assistant", + Message: mustMarshal(t, claudeMessageContent{ + Role: "assistant", + Content: []claudeContentBlock{ + {Type: "text", Text: "Hello world"}, + }, + }), + } + + b.handleAssistant(msg, ch, &output) + + if output.String() != "Hello world" { + t.Fatalf("expected output 'Hello world', got %q", output.String()) + } + select { + case m := <-ch: + if m.Type != MessageText || m.Content != "Hello world" { + t.Fatalf("unexpected message: %+v", m) + } + default: + t.Fatal("expected message on channel") + } +} + +func TestClaudeHandleAssistantToolUse(t *testing.T) { + t.Parallel() + + b := &claudeBackend{cfg: Config{Logger: log.Default()}} + ch := make(chan Message, 10) + var output strings.Builder + + msg := claudeSDKMessage{ + Type: "assistant", + Message: mustMarshal(t, claudeMessageContent{ + Role: "assistant", + Content: []claudeContentBlock{ + { + Type: "tool_use", + ID: "call-1", + Name: "Read", + Input: mustMarshal(t, map[string]any{"path": "/tmp/foo"}), + }, + }, + }), + } + + b.handleAssistant(msg, ch, &output) + + if output.String() != "" { + t.Fatalf("tool_use should not add to output, got %q", output.String()) + } + select { + case m := <-ch: + if m.Type != MessageToolUse || m.Tool != "Read" || m.CallID != "call-1" { + t.Fatalf("unexpected message: %+v", m) + } + if m.Input["path"] != "/tmp/foo" { + t.Fatalf("expected input path /tmp/foo, got %v", m.Input["path"]) + } + default: + t.Fatal("expected message on channel") + } +} + +func TestClaudeHandleUserToolResult(t *testing.T) { + t.Parallel() + + b := &claudeBackend{cfg: Config{Logger: log.Default()}} + ch := make(chan Message, 10) + + msg := claudeSDKMessage{ + Type: "user", + Message: mustMarshal(t, claudeMessageContent{ + Role: "user", + Content: []claudeContentBlock{ + { + Type: "tool_result", + ToolUseID: "call-1", + Content: mustMarshal(t, "file contents here"), + }, + }, + }), + } + + b.handleUser(msg, ch) + + select { + case m := <-ch: + if m.Type != MessageToolResult || m.CallID != "call-1" { + t.Fatalf("unexpected message: %+v", m) + } + default: + t.Fatal("expected message on channel") + } +} + +func TestClaudeHandleControlRequestAutoApproves(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + b := &claudeBackend{cfg: Config{Logger: log.New(&buf, "", 0)}} + + var written bytes.Buffer + + msg := claudeSDKMessage{ + Type: "control_request", + RequestID: "req-42", + Request: mustMarshal(t, claudeControlRequestPayload{ + Subtype: "tool_use", + ToolName: "Bash", + Input: mustMarshal(t, map[string]any{"command": "ls"}), + }), + } + + b.handleControlRequest(msg, &written) + + var resp map[string]any + if err := json.Unmarshal(bytes.TrimSpace(written.Bytes()), &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + if resp["type"] != "control_response" { + t.Fatalf("expected type control_response, got %v", resp["type"]) + } + respInner := resp["response"].(map[string]any) + if respInner["request_id"] != "req-42" { + t.Fatalf("expected request_id req-42, got %v", respInner["request_id"]) + } + innerResp := respInner["response"].(map[string]any) + if innerResp["behavior"] != "allow" { + t.Fatalf("expected behavior allow, got %v", innerResp["behavior"]) + } +} + +func TestClaudeHandleAssistantInvalidJSON(t *testing.T) { + t.Parallel() + + b := &claudeBackend{cfg: Config{Logger: log.Default()}} + ch := make(chan Message, 10) + var output strings.Builder + + msg := claudeSDKMessage{ + Type: "assistant", + Message: json.RawMessage(`invalid json`), + } + + // Should not panic + b.handleAssistant(msg, ch, &output) + + if output.String() != "" { + t.Fatalf("expected empty output for invalid JSON, got %q", output.String()) + } + select { + case m := <-ch: + t.Fatalf("expected no message, got %+v", m) + default: + } +} + +func TestTrySendDropsWhenFull(t *testing.T) { + t.Parallel() + + ch := make(chan Message, 1) + // Fill the channel + trySend(ch, Message{Type: MessageText, Content: "first"}) + // This should not block + trySend(ch, Message{Type: MessageText, Content: "second"}) + + m := <-ch + if m.Content != "first" { + t.Fatalf("expected 'first', got %q", m.Content) + } + select { + case m := <-ch: + t.Fatalf("expected empty channel, got %+v", m) + default: + } +} + +func TestBuildEnvAppendsExtras(t *testing.T) { + t.Parallel() + + env := buildEnv(map[string]string{"FOO": "bar", "BAZ": "qux"}) + found := 0 + for _, e := range env { + if e == "FOO=bar" || e == "BAZ=qux" { + found++ + } + } + if found != 2 { + t.Fatalf("expected 2 extra env vars, found %d", found) + } +} + +func TestBuildEnvNilExtras(t *testing.T) { + t.Parallel() + + env := buildEnv(nil) + if len(env) == 0 { + t.Fatal("expected at least system env vars") + } +} + +func mustMarshal(t *testing.T, v any) json.RawMessage { + t.Helper() + data, err := json.Marshal(v) + if err != nil { + t.Fatalf("json.Marshal: %v", err) + } + return data +} diff --git a/server/pkg/agent/codex.go b/server/pkg/agent/codex.go index 6028cc69..97a67a89 100644 --- a/server/pkg/agent/codex.go +++ b/server/pkg/agent/codex.go @@ -60,23 +60,38 @@ func (b *codexBackend) Execute(ctx context.Context, prompt string, opts ExecOpti msgCh := make(chan Message, 256) resCh := make(chan Result, 1) + var outputMu sync.Mutex var output strings.Builder + // turnDone is set before starting the reader goroutine so there is no + // race between the lifecycle goroutine writing and the reader reading. + turnDone := make(chan bool, 1) // true = aborted + c := &codexClient{ - cfg: b.cfg, - stdin: stdin, - pending: make(map[int]*pendingRPC), - // Set onMessage before starting the reader goroutine to avoid a race. + cfg: b.cfg, + stdin: stdin, + pending: make(map[int]*pendingRPC), + notificationProtocol: "unknown", onMessage: func(msg Message) { if msg.Type == MessageText { + outputMu.Lock() output.WriteString(msg.Content) + outputMu.Unlock() } trySend(msgCh, msg) }, + onTurnDone: func(aborted bool) { + select { + case turnDone <- aborted: + default: + } + }, } // Start reading stdout in background + readerDone := make(chan struct{}) go func() { + defer close(readerDone) scanner := bufio.NewScanner(stdout) scanner.Buffer(make([]byte, 0, 1024*1024), 10*1024*1024) for scanner.Scan() { @@ -156,14 +171,6 @@ func (b *codexBackend) Execute(ctx context.Context, prompt string, opts ExecOpti b.cfg.Logger.Printf("[codex] thread started: %s", threadID) // 3. Send turn and wait for completion - turnDone := make(chan bool, 1) // true = aborted - c.onTurnDone = func(aborted bool) { - select { - case turnDone <- aborted: - default: - } - } - _, err = c.request(runCtx, "turn/start", map[string]any{ "threadId": threadID, "input": []map[string]any{ @@ -198,9 +205,16 @@ func (b *codexBackend) Execute(ctx context.Context, prompt string, opts ExecOpti b.cfg.Logger.Printf("[codex] finished pid=%d status=%s duration=%s", cmd.Process.Pid, finalStatus, duration.Round(time.Millisecond)) + // Wait for the reader goroutine to finish so all output is accumulated. + <-readerDone + + outputMu.Lock() + finalOutput := output.String() + outputMu.Unlock() + resCh <- Result{ Status: finalStatus, - Output: output.String(), + Output: finalOutput, Error: finalError, DurationMs: duration.Milliseconds(), } diff --git a/server/pkg/agent/codex_test.go b/server/pkg/agent/codex_test.go new file mode 100644 index 00000000..dc3f64b1 --- /dev/null +++ b/server/pkg/agent/codex_test.go @@ -0,0 +1,545 @@ +package agent + +import ( + "encoding/json" + "fmt" + "log" + "sync" + "testing" +) + +func newTestCodexClient(t *testing.T) (*codexClient, *fakeStdin, []Message) { + t.Helper() + fs := &fakeStdin{} + var mu sync.Mutex + var messages []Message + + c := &codexClient{ + cfg: Config{Logger: log.Default()}, + stdin: fs, + pending: make(map[int]*pendingRPC), + onMessage: func(msg Message) { + mu.Lock() + messages = append(messages, msg) + mu.Unlock() + }, + onTurnDone: func(aborted bool) {}, + } + return c, fs, messages +} + +type fakeStdin struct { + mu sync.Mutex + data []byte +} + +func (f *fakeStdin) Write(p []byte) (int, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.data = append(f.data, p...) + return len(p), nil +} + +func (f *fakeStdin) Lines() []string { + f.mu.Lock() + defer f.mu.Unlock() + var lines []string + for _, line := range splitLines(string(f.data)) { + if line != "" { + lines = append(lines, line) + } + } + return lines +} + +func splitLines(s string) []string { + var lines []string + start := 0 + for i, c := range s { + if c == '\n' { + lines = append(lines, s[start:i]) + start = i + 1 + } + } + if start < len(s) { + lines = append(lines, s[start:]) + } + return lines +} + +func TestCodexHandleResponseSuccess(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + + // Register a pending request + pr := &pendingRPC{ch: make(chan rpcResult, 1), method: "test"} + c.mu.Lock() + c.pending[1] = pr + c.mu.Unlock() + + c.handleLine(`{"jsonrpc":"2.0","id":1,"result":{"ok":true}}`) + + res := <-pr.ch + if res.err != nil { + t.Fatalf("expected no error, got %v", res.err) + } + + var parsed map[string]any + if err := json.Unmarshal(res.result, &parsed); err != nil { + t.Fatalf("unmarshal result: %v", err) + } + if parsed["ok"] != true { + t.Fatalf("expected ok=true, got %v", parsed["ok"]) + } +} + +func TestCodexHandleResponseError(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + + pr := &pendingRPC{ch: make(chan rpcResult, 1), method: "test"} + c.mu.Lock() + c.pending[1] = pr + c.mu.Unlock() + + c.handleLine(`{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"bad request"}}`) + + res := <-pr.ch + if res.err == nil { + t.Fatal("expected error") + } + if res.result != nil { + t.Fatalf("expected nil result, got %v", res.result) + } +} + +func TestCodexHandleServerRequestAutoApproves(t *testing.T) { + t.Parallel() + + c, fs, _ := newTestCodexClient(t) + + // Command execution approval + c.handleLine(`{"jsonrpc":"2.0","id":10,"method":"item/commandExecution/requestApproval","params":{}}`) + + lines := fs.Lines() + if len(lines) != 1 { + t.Fatalf("expected 1 response, got %d", len(lines)) + } + + var resp map[string]any + if err := json.Unmarshal([]byte(lines[0]), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if resp["id"] != float64(10) { + t.Fatalf("expected id=10, got %v", resp["id"]) + } + result := resp["result"].(map[string]any) + if result["decision"] != "accept" { + t.Fatalf("expected decision=accept, got %v", result["decision"]) + } +} + +func TestCodexHandleServerRequestFileChangeApproval(t *testing.T) { + t.Parallel() + + c, fs, _ := newTestCodexClient(t) + + c.handleLine(`{"jsonrpc":"2.0","id":11,"method":"applyPatchApproval","params":{}}`) + + lines := fs.Lines() + if len(lines) != 1 { + t.Fatalf("expected 1 response, got %d", len(lines)) + } + + var resp map[string]any + if err := json.Unmarshal([]byte(lines[0]), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + result := resp["result"].(map[string]any) + if result["decision"] != "accept" { + t.Fatalf("expected decision=accept, got %v", result["decision"]) + } +} + +func TestCodexLegacyEventTaskStarted(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + var gotStatus bool + c.onMessage = func(msg Message) { + if msg.Type == MessageStatus && msg.Status == "running" { + gotStatus = true + } + } + + c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"task_started"}}}`) + + if !gotStatus { + t.Fatal("expected status=running message") + } + if !c.turnStarted { + t.Fatal("expected turnStarted=true") + } + if c.notificationProtocol != "legacy" { + t.Fatalf("expected protocol=legacy, got %q", c.notificationProtocol) + } +} + +func TestCodexLegacyEventAgentMessage(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + var gotText string + c.onMessage = func(msg Message) { + if msg.Type == MessageText { + gotText = msg.Content + } + } + + c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"agent_message","message":"I found the bug"}}}`) + + if gotText != "I found the bug" { + t.Fatalf("expected text 'I found the bug', got %q", gotText) + } +} + +func TestCodexLegacyEventExecCommand(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + var messages []Message + c.onMessage = func(msg Message) { + messages = append(messages, msg) + } + + c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"exec_command_begin","call_id":"c1","command":"ls -la"}}}`) + c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"exec_command_end","call_id":"c1","output":"total 42"}}}`) + + if len(messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(messages)) + } + if messages[0].Type != MessageToolUse || messages[0].Tool != "exec_command" || messages[0].CallID != "c1" { + t.Fatalf("unexpected begin message: %+v", messages[0]) + } + if messages[1].Type != MessageToolResult || messages[1].CallID != "c1" || messages[1].Output != "total 42" { + t.Fatalf("unexpected end message: %+v", messages[1]) + } +} + +func TestCodexLegacyEventTaskComplete(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + var done bool + c.onTurnDone = func(aborted bool) { + done = true + if aborted { + t.Fatal("expected aborted=false") + } + } + + c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"task_complete"}}}`) + + if !done { + t.Fatal("expected onTurnDone to be called") + } +} + +func TestCodexLegacyEventTurnAborted(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + var abortedResult bool + c.onTurnDone = func(aborted bool) { + abortedResult = aborted + } + + c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"turn_aborted"}}}`) + + if !abortedResult { + t.Fatal("expected aborted=true") + } +} + +func TestCodexRawTurnStarted(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + // The zero value "" doesn't match "unknown", so protocol auto-detection + // won't trigger. Set it explicitly as production code would. + c.notificationProtocol = "unknown" + + var gotStatus bool + c.onMessage = func(msg Message) { + if msg.Type == MessageStatus && msg.Status == "running" { + gotStatus = true + } + } + + c.handleLine(`{"jsonrpc":"2.0","method":"turn/started","params":{"turn":{"id":"turn-1"}}}`) + + if !gotStatus { + t.Fatal("expected status=running message") + } + if c.notificationProtocol != "raw" { + t.Fatalf("expected protocol=raw, got %q", c.notificationProtocol) + } + if c.turnID != "turn-1" { + t.Fatalf("expected turnID=turn-1, got %q", c.turnID) + } +} + +func TestCodexRawTurnCompleted(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + c.notificationProtocol = "raw" + + var doneCount int + c.onTurnDone = func(aborted bool) { + doneCount++ + if aborted { + t.Fatal("expected aborted=false") + } + } + + c.handleLine(`{"jsonrpc":"2.0","method":"turn/completed","params":{"turn":{"id":"turn-1","status":"completed"}}}`) + + if doneCount != 1 { + t.Fatalf("expected onTurnDone called once, got %d", doneCount) + } +} + +func TestCodexRawTurnCompletedDeduplication(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + c.notificationProtocol = "raw" + + var doneCount int + c.onTurnDone = func(aborted bool) { + doneCount++ + } + + c.handleLine(`{"jsonrpc":"2.0","method":"turn/completed","params":{"turn":{"id":"turn-1","status":"completed"}}}`) + c.handleLine(`{"jsonrpc":"2.0","method":"turn/completed","params":{"turn":{"id":"turn-1","status":"completed"}}}`) + + if doneCount != 1 { + t.Fatalf("expected deduplication, but onTurnDone called %d times", doneCount) + } +} + +func TestCodexRawTurnCompletedAborted(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + c.notificationProtocol = "raw" + + var wasAborted bool + c.onTurnDone = func(aborted bool) { + wasAborted = aborted + } + + c.handleLine(`{"jsonrpc":"2.0","method":"turn/completed","params":{"turn":{"id":"turn-2","status":"cancelled"}}}`) + + if !wasAborted { + t.Fatal("expected aborted=true for cancelled status") + } +} + +func TestCodexRawItemCommandExecution(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + c.notificationProtocol = "raw" + + var messages []Message + c.onMessage = func(msg Message) { + messages = append(messages, msg) + } + + c.handleLine(`{"jsonrpc":"2.0","method":"item/started","params":{"item":{"type":"commandExecution","id":"item-1","command":"git status"}}}`) + c.handleLine(`{"jsonrpc":"2.0","method":"item/completed","params":{"item":{"type":"commandExecution","id":"item-1","aggregatedOutput":"on branch main"}}}`) + + if len(messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(messages)) + } + if messages[0].Type != MessageToolUse || messages[0].Tool != "exec_command" || messages[0].Input["command"] != "git status" { + t.Fatalf("unexpected start message: %+v", messages[0]) + } + if messages[1].Type != MessageToolResult || messages[1].Output != "on branch main" { + t.Fatalf("unexpected complete message: %+v", messages[1]) + } +} + +func TestCodexRawItemAgentMessageFinalAnswer(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + c.notificationProtocol = "raw" + c.turnStarted = true + + var gotText string + var turnDone bool + c.onMessage = func(msg Message) { + if msg.Type == MessageText { + gotText = msg.Content + } + } + c.onTurnDone = func(aborted bool) { + turnDone = true + } + + c.handleLine(`{"jsonrpc":"2.0","method":"item/completed","params":{"item":{"type":"agentMessage","id":"msg-1","text":"Done!","phase":"final_answer"}}}`) + + if gotText != "Done!" { + t.Fatalf("expected text 'Done!', got %q", gotText) + } + if !turnDone { + t.Fatal("expected onTurnDone for final_answer") + } +} + +func TestCodexRawThreadStatusIdle(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + c.notificationProtocol = "raw" + c.turnStarted = true + + var turnDone bool + c.onTurnDone = func(aborted bool) { + turnDone = true + if aborted { + t.Fatal("expected aborted=false for idle") + } + } + + c.handleLine(`{"jsonrpc":"2.0","method":"thread/status/changed","params":{"status":{"type":"idle"}}}`) + + if !turnDone { + t.Fatal("expected onTurnDone for idle status") + } +} + +func TestCodexCloseAllPending(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + + pr1 := &pendingRPC{ch: make(chan rpcResult, 1), method: "m1"} + pr2 := &pendingRPC{ch: make(chan rpcResult, 1), method: "m2"} + c.mu.Lock() + c.pending[1] = pr1 + c.pending[2] = pr2 + c.mu.Unlock() + + c.closeAllPending(fmt.Errorf("test error")) + + r1 := <-pr1.ch + if r1.err == nil { + t.Fatal("expected error for pending 1") + } + r2 := <-pr2.ch + if r2.err == nil { + t.Fatal("expected error for pending 2") + } + + c.mu.Lock() + defer c.mu.Unlock() + if len(c.pending) != 0 { + t.Fatalf("expected empty pending map, got %d", len(c.pending)) + } +} + +func TestCodexHandleInvalidJSON(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + // Should not panic + c.handleLine("not json at all") + c.handleLine("") + c.handleLine("{}") +} + +func TestExtractThreadID(t *testing.T) { + t.Parallel() + + data := json.RawMessage(`{"thread":{"id":"t-123"}}`) + got := extractThreadID(data) + if got != "t-123" { + t.Fatalf("expected t-123, got %q", got) + } +} + +func TestExtractThreadIDMissing(t *testing.T) { + t.Parallel() + + got := extractThreadID(json.RawMessage(`{}`)) + if got != "" { + t.Fatalf("expected empty, got %q", got) + } +} + +func TestExtractNestedString(t *testing.T) { + t.Parallel() + + m := map[string]any{ + "a": map[string]any{ + "b": "value", + }, + } + got := extractNestedString(m, "a", "b") + if got != "value" { + t.Fatalf("expected 'value', got %q", got) + } +} + +func TestExtractNestedStringMissingKey(t *testing.T) { + t.Parallel() + + m := map[string]any{"a": "flat"} + got := extractNestedString(m, "a", "b") + if got != "" { + t.Fatalf("expected empty, got %q", got) + } +} + +func TestNilIfEmpty(t *testing.T) { + t.Parallel() + + if nilIfEmpty("") != nil { + t.Fatal("expected nil for empty string") + } + if nilIfEmpty("hello") != "hello" { + t.Fatal("expected 'hello'") + } +} + +func TestCodexProtocolDetectionLegacyBlocksRaw(t *testing.T) { + t.Parallel() + + c, _, _ := newTestCodexClient(t) + + var messages []Message + c.onMessage = func(msg Message) { + messages = append(messages, msg) + } + + // First: receive a legacy event -> locks to "legacy" + c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"task_started"}}}`) + + if c.notificationProtocol != "legacy" { + t.Fatalf("expected legacy, got %q", c.notificationProtocol) + } + + // Now send a raw notification -> should be ignored + messagesBefore := len(messages) + c.handleLine(`{"jsonrpc":"2.0","method":"turn/started","params":{"turn":{"id":"turn-1"}}}`) + + if len(messages) != messagesBefore { + t.Fatal("raw notification should be ignored in legacy mode") + } +}