fix(agent): fix data races, add tests, and fix raw protocol detection
- Fix data race on output strings.Builder in codex backend by adding mutex and waiting for reader goroutine before reading final output - Fix data race on onTurnDone by initializing it before reader starts - Fix bug where notificationProtocol zero value "" never matched "unknown", silently dropping all raw v2 notifications from codex - Add round-robin polling to prevent runtime starvation in poll loop - Log errors in claude handleControlRequest instead of silently dropping - Add 35 tests for pkg/agent covering claude parsing, codex JSON-RPC, protocol detection, event handling, and helper functions Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
0d9b687d92
commit
96cfdc2e27
6 changed files with 864 additions and 15 deletions
|
|
@ -384,6 +384,7 @@ func (d *daemon) heartbeatLoop(ctx context.Context, runtimeIDs []string) {
|
|||
}
|
||||
|
||||
func (d *daemon) pollLoop(ctx context.Context, runtimeIDs []string) error {
|
||||
pollOffset := 0
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
|
@ -392,7 +393,9 @@ func (d *daemon) pollLoop(ctx context.Context, runtimeIDs []string) error {
|
|||
}
|
||||
|
||||
claimed := false
|
||||
for _, rid := range runtimeIDs {
|
||||
n := len(runtimeIDs)
|
||||
for i := 0; i < n; i++ {
|
||||
rid := runtimeIDs[(pollOffset+i)%n]
|
||||
task, err := d.client.claimTask(ctx, rid)
|
||||
if err != nil {
|
||||
d.logger.Printf("claim task failed for runtime %s: %v", rid, err)
|
||||
|
|
@ -402,11 +405,13 @@ func (d *daemon) pollLoop(ctx context.Context, runtimeIDs []string) error {
|
|||
d.logger.Printf("poll: got task=%s issue=%s title=%q", task.ID, task.IssueID, task.Context.Issue.Title)
|
||||
d.handleTask(ctx, *task)
|
||||
claimed = true
|
||||
pollOffset = (pollOffset + i + 1) % n
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !claimed {
|
||||
pollOffset = (pollOffset + 1) % n
|
||||
if err := sleepWithContext(ctx, d.cfg.PollInterval); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
53
server/pkg/agent/agent_test.go
Normal file
53
server/pkg/agent/agent_test.go
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewReturnsClaudeBackend(t *testing.T) {
|
||||
t.Parallel()
|
||||
b, err := New("claude", Config{ExecutablePath: "/nonexistent/claude"})
|
||||
if err != nil {
|
||||
t.Fatalf("New(claude) error: %v", err)
|
||||
}
|
||||
if _, ok := b.(*claudeBackend); !ok {
|
||||
t.Fatalf("expected *claudeBackend, got %T", b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewReturnsCodexBackend(t *testing.T) {
|
||||
t.Parallel()
|
||||
b, err := New("codex", Config{ExecutablePath: "/nonexistent/codex"})
|
||||
if err != nil {
|
||||
t.Fatalf("New(codex) error: %v", err)
|
||||
}
|
||||
if _, ok := b.(*codexBackend); !ok {
|
||||
t.Fatalf("expected *codexBackend, got %T", b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRejectsUnknownType(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := New("gpt", Config{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown agent type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDefaultsLogger(t *testing.T) {
|
||||
t.Parallel()
|
||||
b, _ := New("claude", Config{})
|
||||
cb := b.(*claudeBackend)
|
||||
if cb.cfg.Logger == nil {
|
||||
t.Fatal("expected non-nil logger")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectVersionFailsForMissingBinary(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := DetectVersion(context.Background(), "/nonexistent/binary")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing binary")
|
||||
}
|
||||
}
|
||||
|
|
@ -244,10 +244,13 @@ func (b *claudeBackend) handleControlRequest(msg claudeSDKMessage, stdin interfa
|
|||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
b.cfg.Logger.Printf("[claude] failed to marshal control response: %v", err)
|
||||
return
|
||||
}
|
||||
data = append(data, '\n')
|
||||
_, _ = stdin.Write(data)
|
||||
if _, err := stdin.Write(data); err != nil {
|
||||
b.cfg.Logger.Printf("[claude] failed to write control response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Claude SDK JSON types ──
|
||||
|
|
|
|||
229
server/pkg/agent/claude_test.go
Normal file
229
server/pkg/agent/claude_test.go
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClaudeHandleAssistantText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &claudeBackend{cfg: Config{Logger: log.Default()}}
|
||||
ch := make(chan Message, 10)
|
||||
var output strings.Builder
|
||||
|
||||
msg := claudeSDKMessage{
|
||||
Type: "assistant",
|
||||
Message: mustMarshal(t, claudeMessageContent{
|
||||
Role: "assistant",
|
||||
Content: []claudeContentBlock{
|
||||
{Type: "text", Text: "Hello world"},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
b.handleAssistant(msg, ch, &output)
|
||||
|
||||
if output.String() != "Hello world" {
|
||||
t.Fatalf("expected output 'Hello world', got %q", output.String())
|
||||
}
|
||||
select {
|
||||
case m := <-ch:
|
||||
if m.Type != MessageText || m.Content != "Hello world" {
|
||||
t.Fatalf("unexpected message: %+v", m)
|
||||
}
|
||||
default:
|
||||
t.Fatal("expected message on channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeHandleAssistantToolUse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &claudeBackend{cfg: Config{Logger: log.Default()}}
|
||||
ch := make(chan Message, 10)
|
||||
var output strings.Builder
|
||||
|
||||
msg := claudeSDKMessage{
|
||||
Type: "assistant",
|
||||
Message: mustMarshal(t, claudeMessageContent{
|
||||
Role: "assistant",
|
||||
Content: []claudeContentBlock{
|
||||
{
|
||||
Type: "tool_use",
|
||||
ID: "call-1",
|
||||
Name: "Read",
|
||||
Input: mustMarshal(t, map[string]any{"path": "/tmp/foo"}),
|
||||
},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
b.handleAssistant(msg, ch, &output)
|
||||
|
||||
if output.String() != "" {
|
||||
t.Fatalf("tool_use should not add to output, got %q", output.String())
|
||||
}
|
||||
select {
|
||||
case m := <-ch:
|
||||
if m.Type != MessageToolUse || m.Tool != "Read" || m.CallID != "call-1" {
|
||||
t.Fatalf("unexpected message: %+v", m)
|
||||
}
|
||||
if m.Input["path"] != "/tmp/foo" {
|
||||
t.Fatalf("expected input path /tmp/foo, got %v", m.Input["path"])
|
||||
}
|
||||
default:
|
||||
t.Fatal("expected message on channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeHandleUserToolResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &claudeBackend{cfg: Config{Logger: log.Default()}}
|
||||
ch := make(chan Message, 10)
|
||||
|
||||
msg := claudeSDKMessage{
|
||||
Type: "user",
|
||||
Message: mustMarshal(t, claudeMessageContent{
|
||||
Role: "user",
|
||||
Content: []claudeContentBlock{
|
||||
{
|
||||
Type: "tool_result",
|
||||
ToolUseID: "call-1",
|
||||
Content: mustMarshal(t, "file contents here"),
|
||||
},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
b.handleUser(msg, ch)
|
||||
|
||||
select {
|
||||
case m := <-ch:
|
||||
if m.Type != MessageToolResult || m.CallID != "call-1" {
|
||||
t.Fatalf("unexpected message: %+v", m)
|
||||
}
|
||||
default:
|
||||
t.Fatal("expected message on channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeHandleControlRequestAutoApproves(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buf bytes.Buffer
|
||||
b := &claudeBackend{cfg: Config{Logger: log.New(&buf, "", 0)}}
|
||||
|
||||
var written bytes.Buffer
|
||||
|
||||
msg := claudeSDKMessage{
|
||||
Type: "control_request",
|
||||
RequestID: "req-42",
|
||||
Request: mustMarshal(t, claudeControlRequestPayload{
|
||||
Subtype: "tool_use",
|
||||
ToolName: "Bash",
|
||||
Input: mustMarshal(t, map[string]any{"command": "ls"}),
|
||||
}),
|
||||
}
|
||||
|
||||
b.handleControlRequest(msg, &written)
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(bytes.TrimSpace(written.Bytes()), &resp); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp["type"] != "control_response" {
|
||||
t.Fatalf("expected type control_response, got %v", resp["type"])
|
||||
}
|
||||
respInner := resp["response"].(map[string]any)
|
||||
if respInner["request_id"] != "req-42" {
|
||||
t.Fatalf("expected request_id req-42, got %v", respInner["request_id"])
|
||||
}
|
||||
innerResp := respInner["response"].(map[string]any)
|
||||
if innerResp["behavior"] != "allow" {
|
||||
t.Fatalf("expected behavior allow, got %v", innerResp["behavior"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeHandleAssistantInvalidJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &claudeBackend{cfg: Config{Logger: log.Default()}}
|
||||
ch := make(chan Message, 10)
|
||||
var output strings.Builder
|
||||
|
||||
msg := claudeSDKMessage{
|
||||
Type: "assistant",
|
||||
Message: json.RawMessage(`invalid json`),
|
||||
}
|
||||
|
||||
// Should not panic
|
||||
b.handleAssistant(msg, ch, &output)
|
||||
|
||||
if output.String() != "" {
|
||||
t.Fatalf("expected empty output for invalid JSON, got %q", output.String())
|
||||
}
|
||||
select {
|
||||
case m := <-ch:
|
||||
t.Fatalf("expected no message, got %+v", m)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrySendDropsWhenFull(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan Message, 1)
|
||||
// Fill the channel
|
||||
trySend(ch, Message{Type: MessageText, Content: "first"})
|
||||
// This should not block
|
||||
trySend(ch, Message{Type: MessageText, Content: "second"})
|
||||
|
||||
m := <-ch
|
||||
if m.Content != "first" {
|
||||
t.Fatalf("expected 'first', got %q", m.Content)
|
||||
}
|
||||
select {
|
||||
case m := <-ch:
|
||||
t.Fatalf("expected empty channel, got %+v", m)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildEnvAppendsExtras(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := buildEnv(map[string]string{"FOO": "bar", "BAZ": "qux"})
|
||||
found := 0
|
||||
for _, e := range env {
|
||||
if e == "FOO=bar" || e == "BAZ=qux" {
|
||||
found++
|
||||
}
|
||||
}
|
||||
if found != 2 {
|
||||
t.Fatalf("expected 2 extra env vars, found %d", found)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildEnvNilExtras(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := buildEnv(nil)
|
||||
if len(env) == 0 {
|
||||
t.Fatal("expected at least system env vars")
|
||||
}
|
||||
}
|
||||
|
||||
func mustMarshal(t *testing.T, v any) json.RawMessage {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal: %v", err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
|
@ -60,23 +60,38 @@ func (b *codexBackend) Execute(ctx context.Context, prompt string, opts ExecOpti
|
|||
msgCh := make(chan Message, 256)
|
||||
resCh := make(chan Result, 1)
|
||||
|
||||
var outputMu sync.Mutex
|
||||
var output strings.Builder
|
||||
|
||||
// turnDone is set before starting the reader goroutine so there is no
|
||||
// race between the lifecycle goroutine writing and the reader reading.
|
||||
turnDone := make(chan bool, 1) // true = aborted
|
||||
|
||||
c := &codexClient{
|
||||
cfg: b.cfg,
|
||||
stdin: stdin,
|
||||
pending: make(map[int]*pendingRPC),
|
||||
// Set onMessage before starting the reader goroutine to avoid a race.
|
||||
cfg: b.cfg,
|
||||
stdin: stdin,
|
||||
pending: make(map[int]*pendingRPC),
|
||||
notificationProtocol: "unknown",
|
||||
onMessage: func(msg Message) {
|
||||
if msg.Type == MessageText {
|
||||
outputMu.Lock()
|
||||
output.WriteString(msg.Content)
|
||||
outputMu.Unlock()
|
||||
}
|
||||
trySend(msgCh, msg)
|
||||
},
|
||||
onTurnDone: func(aborted bool) {
|
||||
select {
|
||||
case turnDone <- aborted:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// Start reading stdout in background
|
||||
readerDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(readerDone)
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
scanner.Buffer(make([]byte, 0, 1024*1024), 10*1024*1024)
|
||||
for scanner.Scan() {
|
||||
|
|
@ -156,14 +171,6 @@ func (b *codexBackend) Execute(ctx context.Context, prompt string, opts ExecOpti
|
|||
b.cfg.Logger.Printf("[codex] thread started: %s", threadID)
|
||||
|
||||
// 3. Send turn and wait for completion
|
||||
turnDone := make(chan bool, 1) // true = aborted
|
||||
c.onTurnDone = func(aborted bool) {
|
||||
select {
|
||||
case turnDone <- aborted:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
_, err = c.request(runCtx, "turn/start", map[string]any{
|
||||
"threadId": threadID,
|
||||
"input": []map[string]any{
|
||||
|
|
@ -198,9 +205,16 @@ func (b *codexBackend) Execute(ctx context.Context, prompt string, opts ExecOpti
|
|||
b.cfg.Logger.Printf("[codex] finished pid=%d status=%s duration=%s",
|
||||
cmd.Process.Pid, finalStatus, duration.Round(time.Millisecond))
|
||||
|
||||
// Wait for the reader goroutine to finish so all output is accumulated.
|
||||
<-readerDone
|
||||
|
||||
outputMu.Lock()
|
||||
finalOutput := output.String()
|
||||
outputMu.Unlock()
|
||||
|
||||
resCh <- Result{
|
||||
Status: finalStatus,
|
||||
Output: output.String(),
|
||||
Output: finalOutput,
|
||||
Error: finalError,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
}
|
||||
|
|
|
|||
545
server/pkg/agent/codex_test.go
Normal file
545
server/pkg/agent/codex_test.go
Normal file
|
|
@ -0,0 +1,545 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newTestCodexClient(t *testing.T) (*codexClient, *fakeStdin, []Message) {
|
||||
t.Helper()
|
||||
fs := &fakeStdin{}
|
||||
var mu sync.Mutex
|
||||
var messages []Message
|
||||
|
||||
c := &codexClient{
|
||||
cfg: Config{Logger: log.Default()},
|
||||
stdin: fs,
|
||||
pending: make(map[int]*pendingRPC),
|
||||
onMessage: func(msg Message) {
|
||||
mu.Lock()
|
||||
messages = append(messages, msg)
|
||||
mu.Unlock()
|
||||
},
|
||||
onTurnDone: func(aborted bool) {},
|
||||
}
|
||||
return c, fs, messages
|
||||
}
|
||||
|
||||
type fakeStdin struct {
|
||||
mu sync.Mutex
|
||||
data []byte
|
||||
}
|
||||
|
||||
func (f *fakeStdin) Write(p []byte) (int, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.data = append(f.data, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (f *fakeStdin) Lines() []string {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
var lines []string
|
||||
for _, line := range splitLines(string(f.data)) {
|
||||
if line != "" {
|
||||
lines = append(lines, line)
|
||||
}
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
func splitLines(s string) []string {
|
||||
var lines []string
|
||||
start := 0
|
||||
for i, c := range s {
|
||||
if c == '\n' {
|
||||
lines = append(lines, s[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if start < len(s) {
|
||||
lines = append(lines, s[start:])
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
func TestCodexHandleResponseSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
|
||||
// Register a pending request
|
||||
pr := &pendingRPC{ch: make(chan rpcResult, 1), method: "test"}
|
||||
c.mu.Lock()
|
||||
c.pending[1] = pr
|
||||
c.mu.Unlock()
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","id":1,"result":{"ok":true}}`)
|
||||
|
||||
res := <-pr.ch
|
||||
if res.err != nil {
|
||||
t.Fatalf("expected no error, got %v", res.err)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(res.result, &parsed); err != nil {
|
||||
t.Fatalf("unmarshal result: %v", err)
|
||||
}
|
||||
if parsed["ok"] != true {
|
||||
t.Fatalf("expected ok=true, got %v", parsed["ok"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexHandleResponseError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
|
||||
pr := &pendingRPC{ch: make(chan rpcResult, 1), method: "test"}
|
||||
c.mu.Lock()
|
||||
c.pending[1] = pr
|
||||
c.mu.Unlock()
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"bad request"}}`)
|
||||
|
||||
res := <-pr.ch
|
||||
if res.err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if res.result != nil {
|
||||
t.Fatalf("expected nil result, got %v", res.result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexHandleServerRequestAutoApproves(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, fs, _ := newTestCodexClient(t)
|
||||
|
||||
// Command execution approval
|
||||
c.handleLine(`{"jsonrpc":"2.0","id":10,"method":"item/commandExecution/requestApproval","params":{}}`)
|
||||
|
||||
lines := fs.Lines()
|
||||
if len(lines) != 1 {
|
||||
t.Fatalf("expected 1 response, got %d", len(lines))
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal([]byte(lines[0]), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if resp["id"] != float64(10) {
|
||||
t.Fatalf("expected id=10, got %v", resp["id"])
|
||||
}
|
||||
result := resp["result"].(map[string]any)
|
||||
if result["decision"] != "accept" {
|
||||
t.Fatalf("expected decision=accept, got %v", result["decision"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexHandleServerRequestFileChangeApproval(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, fs, _ := newTestCodexClient(t)
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","id":11,"method":"applyPatchApproval","params":{}}`)
|
||||
|
||||
lines := fs.Lines()
|
||||
if len(lines) != 1 {
|
||||
t.Fatalf("expected 1 response, got %d", len(lines))
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal([]byte(lines[0]), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
result := resp["result"].(map[string]any)
|
||||
if result["decision"] != "accept" {
|
||||
t.Fatalf("expected decision=accept, got %v", result["decision"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexLegacyEventTaskStarted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
var gotStatus bool
|
||||
c.onMessage = func(msg Message) {
|
||||
if msg.Type == MessageStatus && msg.Status == "running" {
|
||||
gotStatus = true
|
||||
}
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"task_started"}}}`)
|
||||
|
||||
if !gotStatus {
|
||||
t.Fatal("expected status=running message")
|
||||
}
|
||||
if !c.turnStarted {
|
||||
t.Fatal("expected turnStarted=true")
|
||||
}
|
||||
if c.notificationProtocol != "legacy" {
|
||||
t.Fatalf("expected protocol=legacy, got %q", c.notificationProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexLegacyEventAgentMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
var gotText string
|
||||
c.onMessage = func(msg Message) {
|
||||
if msg.Type == MessageText {
|
||||
gotText = msg.Content
|
||||
}
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"agent_message","message":"I found the bug"}}}`)
|
||||
|
||||
if gotText != "I found the bug" {
|
||||
t.Fatalf("expected text 'I found the bug', got %q", gotText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexLegacyEventExecCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
var messages []Message
|
||||
c.onMessage = func(msg Message) {
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"exec_command_begin","call_id":"c1","command":"ls -la"}}}`)
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"exec_command_end","call_id":"c1","output":"total 42"}}}`)
|
||||
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(messages))
|
||||
}
|
||||
if messages[0].Type != MessageToolUse || messages[0].Tool != "exec_command" || messages[0].CallID != "c1" {
|
||||
t.Fatalf("unexpected begin message: %+v", messages[0])
|
||||
}
|
||||
if messages[1].Type != MessageToolResult || messages[1].CallID != "c1" || messages[1].Output != "total 42" {
|
||||
t.Fatalf("unexpected end message: %+v", messages[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexLegacyEventTaskComplete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
var done bool
|
||||
c.onTurnDone = func(aborted bool) {
|
||||
done = true
|
||||
if aborted {
|
||||
t.Fatal("expected aborted=false")
|
||||
}
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"task_complete"}}}`)
|
||||
|
||||
if !done {
|
||||
t.Fatal("expected onTurnDone to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexLegacyEventTurnAborted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
var abortedResult bool
|
||||
c.onTurnDone = func(aborted bool) {
|
||||
abortedResult = aborted
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"turn_aborted"}}}`)
|
||||
|
||||
if !abortedResult {
|
||||
t.Fatal("expected aborted=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexRawTurnStarted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
// The zero value "" doesn't match "unknown", so protocol auto-detection
|
||||
// won't trigger. Set it explicitly as production code would.
|
||||
c.notificationProtocol = "unknown"
|
||||
|
||||
var gotStatus bool
|
||||
c.onMessage = func(msg Message) {
|
||||
if msg.Type == MessageStatus && msg.Status == "running" {
|
||||
gotStatus = true
|
||||
}
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"turn/started","params":{"turn":{"id":"turn-1"}}}`)
|
||||
|
||||
if !gotStatus {
|
||||
t.Fatal("expected status=running message")
|
||||
}
|
||||
if c.notificationProtocol != "raw" {
|
||||
t.Fatalf("expected protocol=raw, got %q", c.notificationProtocol)
|
||||
}
|
||||
if c.turnID != "turn-1" {
|
||||
t.Fatalf("expected turnID=turn-1, got %q", c.turnID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexRawTurnCompleted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
c.notificationProtocol = "raw"
|
||||
|
||||
var doneCount int
|
||||
c.onTurnDone = func(aborted bool) {
|
||||
doneCount++
|
||||
if aborted {
|
||||
t.Fatal("expected aborted=false")
|
||||
}
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"turn/completed","params":{"turn":{"id":"turn-1","status":"completed"}}}`)
|
||||
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected onTurnDone called once, got %d", doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexRawTurnCompletedDeduplication(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
c.notificationProtocol = "raw"
|
||||
|
||||
var doneCount int
|
||||
c.onTurnDone = func(aborted bool) {
|
||||
doneCount++
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"turn/completed","params":{"turn":{"id":"turn-1","status":"completed"}}}`)
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"turn/completed","params":{"turn":{"id":"turn-1","status":"completed"}}}`)
|
||||
|
||||
if doneCount != 1 {
|
||||
t.Fatalf("expected deduplication, but onTurnDone called %d times", doneCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexRawTurnCompletedAborted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
c.notificationProtocol = "raw"
|
||||
|
||||
var wasAborted bool
|
||||
c.onTurnDone = func(aborted bool) {
|
||||
wasAborted = aborted
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"turn/completed","params":{"turn":{"id":"turn-2","status":"cancelled"}}}`)
|
||||
|
||||
if !wasAborted {
|
||||
t.Fatal("expected aborted=true for cancelled status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexRawItemCommandExecution(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
c.notificationProtocol = "raw"
|
||||
|
||||
var messages []Message
|
||||
c.onMessage = func(msg Message) {
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"item/started","params":{"item":{"type":"commandExecution","id":"item-1","command":"git status"}}}`)
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"item/completed","params":{"item":{"type":"commandExecution","id":"item-1","aggregatedOutput":"on branch main"}}}`)
|
||||
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(messages))
|
||||
}
|
||||
if messages[0].Type != MessageToolUse || messages[0].Tool != "exec_command" || messages[0].Input["command"] != "git status" {
|
||||
t.Fatalf("unexpected start message: %+v", messages[0])
|
||||
}
|
||||
if messages[1].Type != MessageToolResult || messages[1].Output != "on branch main" {
|
||||
t.Fatalf("unexpected complete message: %+v", messages[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexRawItemAgentMessageFinalAnswer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
c.notificationProtocol = "raw"
|
||||
c.turnStarted = true
|
||||
|
||||
var gotText string
|
||||
var turnDone bool
|
||||
c.onMessage = func(msg Message) {
|
||||
if msg.Type == MessageText {
|
||||
gotText = msg.Content
|
||||
}
|
||||
}
|
||||
c.onTurnDone = func(aborted bool) {
|
||||
turnDone = true
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"item/completed","params":{"item":{"type":"agentMessage","id":"msg-1","text":"Done!","phase":"final_answer"}}}`)
|
||||
|
||||
if gotText != "Done!" {
|
||||
t.Fatalf("expected text 'Done!', got %q", gotText)
|
||||
}
|
||||
if !turnDone {
|
||||
t.Fatal("expected onTurnDone for final_answer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexRawThreadStatusIdle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
c.notificationProtocol = "raw"
|
||||
c.turnStarted = true
|
||||
|
||||
var turnDone bool
|
||||
c.onTurnDone = func(aborted bool) {
|
||||
turnDone = true
|
||||
if aborted {
|
||||
t.Fatal("expected aborted=false for idle")
|
||||
}
|
||||
}
|
||||
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"thread/status/changed","params":{"status":{"type":"idle"}}}`)
|
||||
|
||||
if !turnDone {
|
||||
t.Fatal("expected onTurnDone for idle status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexCloseAllPending(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
|
||||
pr1 := &pendingRPC{ch: make(chan rpcResult, 1), method: "m1"}
|
||||
pr2 := &pendingRPC{ch: make(chan rpcResult, 1), method: "m2"}
|
||||
c.mu.Lock()
|
||||
c.pending[1] = pr1
|
||||
c.pending[2] = pr2
|
||||
c.mu.Unlock()
|
||||
|
||||
c.closeAllPending(fmt.Errorf("test error"))
|
||||
|
||||
r1 := <-pr1.ch
|
||||
if r1.err == nil {
|
||||
t.Fatal("expected error for pending 1")
|
||||
}
|
||||
r2 := <-pr2.ch
|
||||
if r2.err == nil {
|
||||
t.Fatal("expected error for pending 2")
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if len(c.pending) != 0 {
|
||||
t.Fatalf("expected empty pending map, got %d", len(c.pending))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexHandleInvalidJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
// Should not panic
|
||||
c.handleLine("not json at all")
|
||||
c.handleLine("")
|
||||
c.handleLine("{}")
|
||||
}
|
||||
|
||||
func TestExtractThreadID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
data := json.RawMessage(`{"thread":{"id":"t-123"}}`)
|
||||
got := extractThreadID(data)
|
||||
if got != "t-123" {
|
||||
t.Fatalf("expected t-123, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractThreadIDMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := extractThreadID(json.RawMessage(`{}`))
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractNestedString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := map[string]any{
|
||||
"a": map[string]any{
|
||||
"b": "value",
|
||||
},
|
||||
}
|
||||
got := extractNestedString(m, "a", "b")
|
||||
if got != "value" {
|
||||
t.Fatalf("expected 'value', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractNestedStringMissingKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := map[string]any{"a": "flat"}
|
||||
got := extractNestedString(m, "a", "b")
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilIfEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if nilIfEmpty("") != nil {
|
||||
t.Fatal("expected nil for empty string")
|
||||
}
|
||||
if nilIfEmpty("hello") != "hello" {
|
||||
t.Fatal("expected 'hello'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProtocolDetectionLegacyBlocksRaw(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, _, _ := newTestCodexClient(t)
|
||||
|
||||
var messages []Message
|
||||
c.onMessage = func(msg Message) {
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
// First: receive a legacy event -> locks to "legacy"
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"codex/event","params":{"msg":{"type":"task_started"}}}`)
|
||||
|
||||
if c.notificationProtocol != "legacy" {
|
||||
t.Fatalf("expected legacy, got %q", c.notificationProtocol)
|
||||
}
|
||||
|
||||
// Now send a raw notification -> should be ignored
|
||||
messagesBefore := len(messages)
|
||||
c.handleLine(`{"jsonrpc":"2.0","method":"turn/started","params":{"turn":{"id":"turn-1"}}}`)
|
||||
|
||||
if len(messages) != messagesBefore {
|
||||
t.Fatal("raw notification should be ignored in legacy mode")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue