fix(agent): fix data races, add tests, and fix raw protocol detection
- Fix data race on output strings.Builder in codex backend by adding mutex and waiting for reader goroutine before reading final output - Fix data race on onTurnDone by initializing it before reader starts - Fix bug where notificationProtocol zero value "" never matched "unknown", silently dropping all raw v2 notifications from codex - Add round-robin polling to prevent runtime starvation in poll loop - Log errors in claude handleControlRequest instead of silently dropping - Add 35 tests for pkg/agent covering claude parsing, codex JSON-RPC, protocol detection, event handling, and helper functions Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
0d9b687d92
commit
96cfdc2e27
6 changed files with 864 additions and 15 deletions
545
server/pkg/agent/codex_test.go
Normal file
545
server/pkg/agent/codex_test.go
Normal file
|
|
@ -0,0 +1,545 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newTestCodexClient(t *testing.T) (*codexClient, *fakeStdin, []Message) {
|
||||
t.Helper()
|
||||
fs := &fakeStdin{}
|
||||
var mu sync.Mutex
|
||||
var messages []Message
|
||||
|
||||
c := &codexClient{
|
||||
cfg: Config{Logger: log.Default()},
|
||||
stdin: fs,
|
||||
pending: make(map[int]*pendingRPC),
|
||||
onMessage: func(msg Message) {
|
||||
mu.Lock()
|
||||
messages = append(messages, msg)
|
||||
mu.Unlock()
|
||||
},
|
||||
onTurnDone: func(aborted bool) {},
|
||||
}
|
||||
return c, fs, messages
|
||||
}
|
||||
|
||||
type fakeStdin struct {
|
||||
mu sync.Mutex
|
||||
data []byte
|
||||
}
|
||||
|
||||
func (f *fakeStdin) Write(p []byte) (int, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.data = append(f.data, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (f *fakeStdin) Lines() []string {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
var lines []string
|
||||
for _, line := range splitLines(string(f.data)) {
|
||||
if line != "" {
|
||||
lines = append(lines, line)
|
||||
}
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
func splitLines(s string) []string {
|
||||
var lines []string
|
||||
start := 0
|
||||
for i, c := range s {
|
||||
if c == '\n' {
|
||||
lines = append(lines, s[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if start < len(s) {
|
||||
lines = append(lines, s[start:])
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
func TestCodexHandleResponseSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
|
||||
// Register a pending request
|
||||
pr := &pendingRPC{ch: make(chan rpcResult, 1), method: "test"}
|
||||
c.mu.Lock()
|
||||
c.pending[1] = pr
|
||||
c.mu.Unlock()
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","id":1,"result":{"ok":true}}`)
|
||||
|
||||
res := <-pr.ch
|
||||
if res.err != nil {
|
||||
t.Fatalf("expected no error, got %v", res.err)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(res.result, &parsed); err != nil {
|
||||
t.Fatalf("unmarshal result: %v", err)
|
||||
}
|
||||
if parsed["ok"] != true {
|
||||
t.Fatalf("expected ok=true, got %v", parsed["ok"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexHandleResponseError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
|
||||
pr := &pendingRPC{ch: make(chan rpcResult, 1), method: "test"}
|
||||
c.mu.Lock()
|
||||
c.pending[1] = pr
|
||||
c.mu.Unlock()
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"bad request"}}`)
|
||||
|
||||
res := <-pr.ch
|
||||
if res.err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if res.result != nil {
|
||||
t.Fatalf("expected nil result, got %v", res.result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexHandleServerRequestAutoApproves(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, fs, _ := newTestCodexClient(t)
|
||||
|
||||
// Command execution approval
|
||||
c.handleLine(`{"jsonrpc":"2.0","id":10,"method":"item/commandExecution/requestApproval","params":{}}`)
|
||||
|
||||
lines := fs.Lines()
|
||||
if len(lines) != 1 {
|
||||
t.Fatalf("expected 1 response, got %d", len(lines))
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal([]byte(lines[0]), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if resp["id"] != float64(10) {
|
||||
t.Fatalf("expected id=10, got %v", resp["id"])
|
||||
}
|
||||
result := resp["result"].(map[string]any)
|
||||
if result["decision"] != "accept" {
|
||||
t.Fatalf("expected decision=accept, got %v", result["decision"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexHandleServerRequestFileChangeApproval(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, fs, _ := newTestCodexClient(t)
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","id":11,"method":"applyPatchApproval","params":{}}`)
|
||||
|
||||
lines := fs.Lines()
|
||||
if len(lines) != 1 {
|
||||
t.Fatalf("expected 1 response, got %d", len(lines))
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal([]byte(lines[0]), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
result := resp["result"].(map[string]any)
|
||||
if result["decision"] != "accept" {
|
||||
t.Fatalf("expected decision=accept, got %v", result["decision"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexLegacyEventTaskStarted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
var gotStatus bool
|
||||
c.onMessage = func(msg Message) {
|
||||
if msg.Type == MessageStatus && msg.Status == "running" {
|
||||
gotStatus = true
|
||||
}
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"task_started"}}}`)
|
||||
|
||||
if !gotStatus {
|
||||
t.Fatal("expected status=running message")
|
||||
}
|
||||
if !c.turnStarted {
|
||||
t.Fatal("expected turnStarted=true")
|
||||
}
|
||||
if c.notificationProtocol != "legacy" {
|
||||
t.Fatalf("expected protocol=legacy, got %q", c.notificationProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexLegacyEventAgentMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
var gotText string
|
||||
c.onMessage = func(msg Message) {
|
||||
if msg.Type == MessageText {
|
||||
gotText = msg.Content
|
||||
}
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"agent_message","message":"I found the bug"}}}`)
|
||||
|
||||
if gotText != "I found the bug" {
|
||||
t.Fatalf("expected text 'I found the bug', got %q", gotText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexLegacyEventExecCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
var messages []Message
|
||||
c.onMessage = func(msg Message) {
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"exec_command_begin","call_id":"c1","command":"ls -la"}}}`)
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"exec_command_end","call_id":"c1","output":"total 42"}}}`)
|
||||
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(messages))
|
||||
}
|
||||
if messages[0].Type != MessageToolUse || messages[0].Tool != "exec_command" || messages[0].CallID != "c1" {
|
||||
t.Fatalf("unexpected begin message: %+v", messages[0])
|
||||
}
|
||||
if messages[1].Type != MessageToolResult || messages[1].CallID != "c1" || messages[1].Output != "total 42" {
|
||||
t.Fatalf("unexpected end message: %+v", messages[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexLegacyEventTaskComplete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
var done bool
|
||||
c.onTurnDone = func(aborted bool) {
|
||||
done = true
|
||||
if aborted {
|
||||
t.Fatal("expected aborted=false")
|
||||
}
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"task_complete"}}}`)
|
||||
|
||||
if !done {
|
||||
t.Fatal("expected onTurnDone to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexLegacyEventTurnAborted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
var abortedResult bool
|
||||
c.onTurnDone = func(aborted bool) {
|
||||
abortedResult = aborted
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"turn_aborted"}}}`)
|
||||
|
||||
if !abortedResult {
|
||||
t.Fatal("expected aborted=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexRawTurnStarted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
// The zero value "" doesn't match "unknown", so protocol auto-detection
|
||||
// won't trigger. Set it explicitly as production code would.
|
||||
c.notificationProtocol = "unknown"
|
||||
|
||||
var gotStatus bool
|
||||
c.onMessage = func(msg Message) {
|
||||
if msg.Type == MessageStatus && msg.Status == "running" {
|
||||
gotStatus = true
|
||||
}
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"turn/started","params":{"turn":{"id":"turn-1"}}}`)
|
||||
|
||||
if !gotStatus {
|
||||
t.Fatal("expected status=running message")
|
||||
}
|
||||
if c.notificationProtocol != "raw" {
|
||||
t.Fatalf("expected protocol=raw, got %q", c.notificationProtocol)
|
||||
}
|
||||
if c.turnID != "turn-1" {
|
||||
t.Fatalf("expected turnID=turn-1, got %q", c.turnID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexRawTurnCompleted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
c.notificationProtocol = "raw"
|
||||
|
||||
var doneCount int
|
||||
c.onTurnDone = func(aborted bool) {
|
||||
doneCount++
|
||||
if aborted {
|
||||
t.Fatal("expected aborted=false")
|
||||
}
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"turn/completed","params":{"turn":{"id":"turn-1","status":"completed"}}}`)
|
||||
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onTurnDone called once, got %d", doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexRawTurnCompletedDeduplication(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
c.notificationProtocol = "raw"
|
||||
|
||||
var doneCount int
|
||||
c.onTurnDone = func(aborted bool) {
|
||||
doneCount++
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"turn/completed","params":{"turn":{"id":"turn-1","status":"completed"}}}`)
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"turn/completed","params":{"turn":{"id":"turn-1","status":"completed"}}}`)
|
||||
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected deduplication, but onTurnDone called %d times", doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexRawTurnCompletedAborted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
c.notificationProtocol = "raw"
|
||||
|
||||
var wasAborted bool
|
||||
c.onTurnDone = func(aborted bool) {
|
||||
wasAborted = aborted
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"turn/completed","params":{"turn":{"id":"turn-2","status":"cancelled"}}}`)
|
||||
|
||||
if !wasAborted {
|
||||
t.Fatal("expected aborted=true for cancelled status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexRawItemCommandExecution(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
c.notificationProtocol = "raw"
|
||||
|
||||
var messages []Message
|
||||
c.onMessage = func(msg Message) {
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"item/started","params":{"item":{"type":"commandExecution","id":"item-1","command":"git status"}}}`)
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"item/completed","params":{"item":{"type":"commandExecution","id":"item-1","aggregatedOutput":"on branch main"}}}`)
|
||||
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(messages))
|
||||
}
|
||||
if messages[0].Type != MessageToolUse || messages[0].Tool != "exec_command" || messages[0].Input["command"] != "git status" {
|
||||
t.Fatalf("unexpected start message: %+v", messages[0])
|
||||
}
|
||||
if messages[1].Type != MessageToolResult || messages[1].Output != "on branch main" {
|
||||
t.Fatalf("unexpected complete message: %+v", messages[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexRawItemAgentMessageFinalAnswer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
c.notificationProtocol = "raw"
|
||||
c.turnStarted = true
|
||||
|
||||
var gotText string
|
||||
var turnDone bool
|
||||
c.onMessage = func(msg Message) {
|
||||
if msg.Type == MessageText {
|
||||
gotText = msg.Content
|
||||
}
|
||||
}
|
||||
c.onTurnDone = func(aborted bool) {
|
||||
turnDone = true
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"item/completed","params":{"item":{"type":"agentMessage","id":"msg-1","text":"Done!","phase":"final_answer"}}}`)
|
||||
|
||||
if gotText != "Done!" {
|
||||
t.Fatalf("expected text 'Done!', got %q", gotText)
|
||||
}
|
||||
if !turnDone {
|
||||
t.Fatal("expected onTurnDone for final_answer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexRawThreadStatusIdle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
c.notificationProtocol = "raw"
|
||||
c.turnStarted = true
|
||||
|
||||
var turnDone bool
|
||||
c.onTurnDone = func(aborted bool) {
|
||||
turnDone = true
|
||||
if aborted {
|
||||
t.Fatal("expected aborted=false for idle")
|
||||
}
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"thread/status/changed","params":{"status":{"type":"idle"}}}`)
|
||||
|
||||
if !turnDone {
|
||||
t.Fatal("expected onTurnDone for idle status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexCloseAllPending(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
|
||||
pr1 := &pendingRPC{ch: make(chan rpcResult, 1), method: "m1"}
|
||||
pr2 := &pendingRPC{ch: make(chan rpcResult, 1), method: "m2"}
|
||||
c.mu.Lock()
|
||||
c.pending[1] = pr1
|
||||
c.pending[2] = pr2
|
||||
c.mu.Unlock()
|
||||
|
||||
c.closeAllPending(fmt.Errorf("test error"))
|
||||
|
||||
r1 := <-pr1.ch
|
||||
if r1.err == nil {
|
||||
t.Fatal("expected error for pending 1")
|
||||
}
|
||||
r2 := <-pr2.ch
|
||||
if r2.err == nil {
|
||||
t.Fatal("expected error for pending 2")
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if len(c.pending) != 0 {
|
||||
t.Fatalf("expected empty pending map, got %d", len(c.pending))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexHandleInvalidJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
// Should not panic
|
||||
c.handleLine("not json at all")
|
||||
c.handleLine("")
|
||||
c.handleLine("{}")
|
||||
}
|
||||
|
||||
func TestExtractThreadID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
data := json.RawMessage(`{"thread":{"id":"t-123"}}`)
|
||||
got := extractThreadID(data)
|
||||
if got != "t-123" {
|
||||
t.Fatalf("expected t-123, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractThreadIDMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := extractThreadID(json.RawMessage(`{}`))
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractNestedString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := map[string]any{
|
||||
"a": map[string]any{
|
||||
"b": "value",
|
||||
},
|
||||
}
|
||||
got := extractNestedString(m, "a", "b")
|
||||
if got != "value" {
|
||||
t.Fatalf("expected 'value', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractNestedStringMissingKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := map[string]any{"a": "flat"}
|
||||
got := extractNestedString(m, "a", "b")
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilIfEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if nilIfEmpty("") != nil {
|
||||
t.Fatal("expected nil for empty string")
|
||||
}
|
||||
if nilIfEmpty("hello") != "hello" {
|
||||
t.Fatal("expected 'hello'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProtocolDetectionLegacyBlocksRaw(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
|
||||
var messages []Message
|
||||
c.onMessage = func(msg Message) {
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
// First: receive a legacy event -> locks to "legacy"
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"task_started"}}}`)
|
||||
|
||||
if c.notificationProtocol != "legacy" {
|
||||
t.Fatalf("expected legacy, got %q", c.notificationProtocol)
|
||||
}
|
||||
|
||||
// Now send a raw notification -> should be ignored
|
||||
messagesBefore := len(messages)
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"turn/started","params":{"turn":{"id":"turn-1"}}}`)
|
||||
|
||||
if len(messages) != messagesBefore {
|
||||
t.Fatal("raw notification should be ignored in legacy mode")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue