Report task token usage through a dedicated daemon endpoint.

Token usage is no longer carried in the CompleteTask request. The daemon now
posts it to /api/daemon/tasks/{taskId}/usage (Client.ReportTaskUsage ->
Handler.ReportTaskUsage) before it issues the complete/fail call for the task
result, so usage is also recorded for failed and blocked tasks. A failure to
report usage is logged as a warning and does not change the task outcome.

The codex, openclaw and opencode backends accumulate token counts from their
event streams and return them in Result.Usage, keyed by opts.Model (or
"unknown" when no model is configured).

Review notes:
  * In server/internal/handler/daemon.go the ReportTaskUsage doc comment was
    sitting directly above `type TaskUsagePayload`, which attaches it to the
    wrong declaration. It now precedes the func, and TaskUsagePayload keeps
    its own doc comment. The second and third hunk headers of that file were
    adjusted for the extra line; the `index` hash is kept from the original
    patch and no longer describes the post-image exactly.
  * Leading whitespace was reconstructed (one tab per nesting level,
    gofmt-style column alignment) and may not match the target files byte for
    byte. Use `git apply --ignore-whitespace` if context fails to match.

diff --git a/server/cmd/server/router.go b/server/cmd/server/router.go
index 0faa51c6..fa6587fe 100644
--- a/server/cmd/server/router.go
+++ b/server/cmd/server/router.go
@@ -103,6 +103,7 @@ func NewRouter(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus) chi.Route
 			r.Post("/tasks/{taskId}/progress", h.ReportTaskProgress)
 			r.Post("/tasks/{taskId}/complete", h.CompleteTask)
 			r.Post("/tasks/{taskId}/fail", h.FailTask)
+			r.Post("/tasks/{taskId}/usage", h.ReportTaskUsage)
 			r.Post("/tasks/{taskId}/messages", h.ReportTaskMessages)
 			r.Get("/tasks/{taskId}/messages", h.ListTaskMessages)
 		})
diff --git a/server/internal/daemon/client.go b/server/internal/daemon/client.go
index eb55ebf4..6351cb78 100644
--- a/server/internal/daemon/client.go
+++ b/server/internal/daemon/client.go
@@ -99,7 +99,7 @@ func (c *Client) ReportTaskMessages(ctx context.Context, taskID string, messages
 	}, nil)
 }
 
-func (c *Client) CompleteTask(ctx context.Context, taskID, output, branchName, sessionID, workDir string, usage []TaskUsageEntry) error {
+func (c *Client) CompleteTask(ctx context.Context, taskID, output, branchName, sessionID, workDir string) error {
 	body := map[string]any{"output": output}
 	if branchName != "" {
 		body["branch_name"] = branchName
@@ -110,12 +110,18 @@ func (c *Client) CompleteTask(ctx context.Context, taskID, output, branchName, s
 	if workDir != "" {
 		body["work_dir"] = workDir
 	}
-	if len(usage) > 0 {
-		body["usage"] = usage
-	}
 	return c.postJSON(ctx, fmt.Sprintf("/api/daemon/tasks/%s/complete", taskID), body, nil)
 }
 
+func (c *Client) ReportTaskUsage(ctx context.Context, taskID string, usage []TaskUsageEntry) error {
+	if len(usage) == 0 {
+		return nil
+	}
+	return c.postJSON(ctx, fmt.Sprintf("/api/daemon/tasks/%s/usage", taskID), map[string]any{
+		"usage": usage,
+	}, nil)
+}
+
 func (c *Client) FailTask(ctx context.Context, taskID, errMsg string) error {
 	return c.postJSON(ctx, fmt.Sprintf("/api/daemon/tasks/%s/fail", taskID), map[string]any{
 		"error": errMsg,
diff --git a/server/internal/daemon/daemon.go b/server/internal/daemon/daemon.go
index 2470dcc9..62131d12 100644
--- a/server/internal/daemon/daemon.go
+++ b/server/internal/daemon/daemon.go
@@ -837,6 +837,13 @@ func (d *Daemon) handleTask(ctx context.Context, task Task) {
 		return
 	}
 
+	// Report usage independently so it's captured even for failed/blocked tasks.
+	if len(result.Usage) > 0 {
+		if err := d.client.ReportTaskUsage(ctx, task.ID, result.Usage); err != nil {
+			taskLog.Warn("report task usage failed", "error", err)
+		}
+	}
+
 	switch result.Status {
 	case "blocked":
 		if err := d.client.FailTask(ctx, task.ID, result.Comment); err != nil {
@@ -844,7 +851,7 @@ func (d *Daemon) handleTask(ctx context.Context, task Task) {
 		}
 	default:
 		taskLog.Info("task completed", "status", result.Status)
-		if err := d.client.CompleteTask(ctx, task.ID, result.Comment, result.BranchName, result.SessionID, result.WorkDir, result.Usage); err != nil {
+		if err := d.client.CompleteTask(ctx, task.ID, result.Comment, result.BranchName, result.SessionID, result.WorkDir); err != nil {
 			taskLog.Error("complete task failed, falling back to fail", "error", err)
 			if failErr := d.client.FailTask(ctx, task.ID, fmt.Sprintf("complete task failed: %s", err.Error())); failErr != nil {
 				taskLog.Error("fail task fallback also failed", "error", failErr)
diff --git a/server/internal/handler/daemon.go b/server/internal/handler/daemon.go
index d3c24690..d0577145 100644
--- a/server/internal/handler/daemon.go
+++ b/server/internal/handler/daemon.go
@@ -341,21 +341,10 @@ func (h *Handler) ReportTaskProgress(w http.ResponseWriter, r *http.Request) {
 
 // CompleteTask marks a running task as completed.
 type TaskCompleteRequest struct {
-	PRURL     string             `json:"pr_url"`
-	Output    string             `json:"output"`
-	SessionID string             `json:"session_id"` // Claude session ID for future resumption
-	WorkDir   string             `json:"work_dir"`   // working directory used during execution
-	Usage     []TaskUsagePayload `json:"usage,omitempty"`
-}
-
-// TaskUsagePayload is the per-model token usage reported by the daemon.
-type TaskUsagePayload struct {
-	Provider         string `json:"provider"`
-	Model            string `json:"model"`
-	InputTokens      int64  `json:"input_tokens"`
-	OutputTokens     int64  `json:"output_tokens"`
-	CacheReadTokens  int64  `json:"cache_read_tokens"`
-	CacheWriteTokens int64  `json:"cache_write_tokens"`
+	PRURL     string `json:"pr_url"`
+	Output    string `json:"output"`
+	SessionID string `json:"session_id"` // Claude session ID for future resumption
+	WorkDir   string `json:"work_dir"`   // working directory used during execution
 }
 
 func (h *Handler) CompleteTask(w http.ResponseWriter, r *http.Request) {
@@ -375,7 +364,33 @@ func (h *Handler) CompleteTask(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// Store per-task token usage (best-effort, don't fail the request).
+	slog.Info("task completed", "task_id", taskID, "agent_id", uuidToString(task.AgentID))
+	writeJSON(w, http.StatusOK, taskToResponse(*task))
+}
+
+// TaskUsagePayload is the per-model token usage reported by the daemon.
+type TaskUsagePayload struct {
+	Provider         string `json:"provider"`
+	Model            string `json:"model"`
+	InputTokens      int64  `json:"input_tokens"`
+	OutputTokens     int64  `json:"output_tokens"`
+	CacheReadTokens  int64  `json:"cache_read_tokens"`
+	CacheWriteTokens int64  `json:"cache_write_tokens"`
+}
+
+// ReportTaskUsage stores per-task token usage. Called independently of
+// complete/fail so usage is captured even when tasks fail or are blocked.
+func (h *Handler) ReportTaskUsage(w http.ResponseWriter, r *http.Request) {
+	taskID := chi.URLParam(r, "taskId")
+
+	var req struct {
+		Usage []TaskUsagePayload `json:"usage"`
+	}
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		writeError(w, http.StatusBadRequest, "invalid request body")
+		return
+	}
+
 	for _, u := range req.Usage {
 		if err := h.Queries.UpsertTaskUsage(r.Context(), db.UpsertTaskUsageParams{
 			TaskID: parseUUID(taskID),
@@ -390,8 +405,7 @@ func (h *Handler) CompleteTask(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 
-	slog.Info("task completed", "task_id", taskID, "agent_id", uuidToString(task.AgentID))
-	writeJSON(w, http.StatusOK, taskToResponse(*task))
+	writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
 }
 
 // GetTaskStatus returns the current status of a task.
diff --git a/server/pkg/agent/codex.go b/server/pkg/agent/codex.go
index a76e42ea..772d6a8d 100644
--- a/server/pkg/agent/codex.go
+++ b/server/pkg/agent/codex.go
@@ -220,11 +220,25 @@ func (b *codexBackend) Execute(ctx context.Context, prompt string, opts ExecOpti
 		finalOutput := output.String()
 		outputMu.Unlock()
 
+		// Build usage map from accumulated codex usage.
+		var usageMap map[string]TokenUsage
+		c.usageMu.Lock()
+		u := c.usage
+		c.usageMu.Unlock()
+		if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 || u.CacheWriteTokens > 0 {
+			model := opts.Model
+			if model == "" {
+				model = "unknown"
+			}
+			usageMap = map[string]TokenUsage{model: u}
+		}
+
 		resCh <- Result{
 			Status:     finalStatus,
 			Output:     finalOutput,
 			Error:      finalError,
 			DurationMs: duration.Milliseconds(),
+			Usage:      usageMap,
 		}
 	}()
 
@@ -247,6 +261,9 @@ type codexClient struct {
 	notificationProtocol string // "unknown", "legacy", "raw"
 	turnStarted          bool
 	completedTurnIDs     map[string]bool
+
+	usageMu sync.Mutex
+	usage   TokenUsage // accumulated from turn events
 }
 
 type pendingRPC struct {
@@ -498,6 +515,8 @@ func (c *codexClient) handleEvent(msg map[string]any) {
 			})
 		}
 	case "task_complete":
+		// Extract usage from legacy task_complete if present.
+		c.extractUsageFromMap(msg)
 		if c.onTurnDone != nil {
 			c.onTurnDone(false)
 		}
@@ -535,6 +554,11 @@ func (c *codexClient) handleRawNotification(method string, params map[string]any
 			c.completedTurnIDs[turnID] = true
 		}
 
+		// Extract usage from turn/completed if present (e.g. params.turn.usage).
+		if turn, ok := params["turn"].(map[string]any); ok {
+			c.extractUsageFromMap(turn)
+		}
+
 		if c.onTurnDone != nil {
 			c.onTurnDone(aborted)
 		}
@@ -618,6 +642,48 @@ func (c *codexClient) handleItemNotification(method string, params map[string]an
 	}
 }
 
+// extractUsageFromMap extracts token usage from a map that may contain
+// "usage", "token_usage", or "tokens" fields. Handles various Codex formats.
+func (c *codexClient) extractUsageFromMap(data map[string]any) {
+	// Try common field names for usage data.
+	var usageMap map[string]any
+	for _, key := range []string{"usage", "token_usage", "tokens"} {
+		if v, ok := data[key].(map[string]any); ok {
+			usageMap = v
+			break
+		}
+	}
+	if usageMap == nil {
+		return
+	}
+
+	c.usageMu.Lock()
+	defer c.usageMu.Unlock()
+
+	// Try various key conventions.
+	c.usage.InputTokens += codexInt64(usageMap, "input_tokens", "input", "prompt_tokens")
+	c.usage.OutputTokens += codexInt64(usageMap, "output_tokens", "output", "completion_tokens")
+	c.usage.CacheReadTokens += codexInt64(usageMap, "cache_read_tokens", "cache_read_input_tokens")
+	c.usage.CacheWriteTokens += codexInt64(usageMap, "cache_write_tokens", "cache_creation_input_tokens")
+}
+
+// codexInt64 returns the first non-zero int64 value from the map for the given keys.
+func codexInt64(m map[string]any, keys ...string) int64 {
+	for _, key := range keys {
+		switch v := m[key].(type) {
+		case float64:
+			if v != 0 {
+				return int64(v)
+			}
+		case int64:
+			if v != 0 {
+				return v
+			}
+		}
+	}
+	return 0
+}
+
 // ── Helpers ──
 
 func extractThreadID(result json.RawMessage) string {
diff --git a/server/pkg/agent/openclaw.go b/server/pkg/agent/openclaw.go
index 96c240d8..a2e6b85c 100644
--- a/server/pkg/agent/openclaw.go
+++ b/server/pkg/agent/openclaw.go
@@ -96,12 +96,25 @@ func (b *openclawBackend) Execute(ctx context.Context, prompt string, opts ExecO
 
 		b.cfg.Logger.Info("openclaw finished", "pid", cmd.Process.Pid, "status", scanResult.status, "duration", duration.Round(time.Millisecond).String())
 
+		// Build usage map. OpenClaw doesn't report model per-step, so we
+		// attribute all usage to the configured model (or "unknown").
+		var usage map[string]TokenUsage
+		u := scanResult.usage
+		if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 || u.CacheWriteTokens > 0 {
+			model := opts.Model
+			if model == "" {
+				model = "unknown"
+			}
+			usage = map[string]TokenUsage{model: u}
+		}
+
 		resCh <- Result{
 			Status:     scanResult.status,
 			Output:     scanResult.output,
 			Error:      scanResult.errMsg,
 			DurationMs: duration.Milliseconds(),
 			SessionID:  scanResult.sessionID,
+			Usage:      usage,
 		}
 	}()
 
@@ -116,6 +129,7 @@ type openclawEventResult struct {
 	errMsg    string
 	output    string
 	sessionID string
+	usage     TokenUsage
 }
 
 // processEvents reads NDJSON lines from r, dispatches events to ch, and returns
@@ -123,6 +137,7 @@ type openclawEventResult struct {
 func (b *openclawBackend) processEvents(r io.Reader, ch chan<- Message) openclawEventResult {
 	var output strings.Builder
 	var sessionID string
+	var usage TokenUsage
 	finalStatus := "completed"
 	var finalError string
 
@@ -160,7 +175,13 @@ func (b *openclawBackend) processEvents(r io.Reader, ch chan<- Message) openclaw
 		case "step_start":
 			trySend(ch, Message{Type: MessageStatus, Status: "running"})
 		case "step_end":
-			// Captures final session ID from step_end if present.
+			// Accumulate token usage from step_end events if present.
+			if event.Data != nil {
+				usage.InputTokens += openclawInt64(event.Data, "inputTokens")
+				usage.OutputTokens += openclawInt64(event.Data, "outputTokens")
+				usage.CacheReadTokens += openclawInt64(event.Data, "cacheReadTokens")
+				usage.CacheWriteTokens += openclawInt64(event.Data, "cacheWriteTokens")
+			}
 		case "result":
 			// The result event only updates status on explicit failure. A
 			// "completed" result is a no-op because finalStatus defaults to
@@ -193,6 +214,24 @@ func (b *openclawBackend) processEvents(r io.Reader, ch chan<- Message) openclaw
 		errMsg:    finalError,
 		output:    output.String(),
 		sessionID: sessionID,
+		usage:     usage,
+	}
+}
+
+// openclawInt64 safely extracts an int64 from a JSON-decoded map value (which
+// may be float64 due to Go's JSON number handling).
+func openclawInt64(data map[string]any, key string) int64 {
+	v, ok := data[key]
+	if !ok {
+		return 0
+	}
+	switch n := v.(type) {
+	case float64:
+		return int64(n)
+	case int64:
+		return n
+	default:
+		return 0
 	}
 }
 
diff --git a/server/pkg/agent/opencode.go b/server/pkg/agent/opencode.go
index 66b678f5..4c17a039 100644
--- a/server/pkg/agent/opencode.go
+++ b/server/pkg/agent/opencode.go
@@ -99,12 +99,25 @@ func (b *opencodeBackend) Execute(ctx context.Context, prompt string, opts ExecO
 
 		b.cfg.Logger.Info("opencode finished", "pid", cmd.Process.Pid, "status", scanResult.status, "duration", duration.Round(time.Millisecond).String())
 
+		// Build usage map. OpenCode doesn't report model per-step, so we
+		// attribute all usage to the configured model (or "unknown").
+		var usage map[string]TokenUsage
+		u := scanResult.usage
+		if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 || u.CacheWriteTokens > 0 {
+			model := opts.Model
+			if model == "" {
+				model = "unknown"
+			}
+			usage = map[string]TokenUsage{model: u}
+		}
+
 		resCh <- Result{
 			Status:     scanResult.status,
 			Output:     scanResult.output,
 			Error:      scanResult.errMsg,
 			DurationMs: duration.Milliseconds(),
 			SessionID:  scanResult.sessionID,
+			Usage:      usage,
 		}
 	}()
 
@@ -119,6 +132,7 @@ type eventResult struct {
 	errMsg    string
 	output    string
 	sessionID string
+	usage     TokenUsage // accumulated token usage across all steps
 }
 
 // processEvents reads JSON lines from r, dispatches events to ch, and returns
@@ -126,6 +140,7 @@ type eventResult struct {
 func (b *opencodeBackend) processEvents(r io.Reader, ch chan<- Message) eventResult {
 	var output strings.Builder
 	var sessionID string
+	var usage TokenUsage
 	finalStatus := "completed"
 	var finalError string
 
@@ -157,7 +172,15 @@ func (b *opencodeBackend) processEvents(r io.Reader, ch chan<- Message) eventRes
 		case "step_start":
 			trySend(ch, Message{Type: MessageStatus, Status: "running"})
 		case "step_finish":
-			// Captures final session ID from step_finish if present.
+			// Accumulate token usage from step_finish events.
+			if t := event.Part.Tokens; t != nil {
+				usage.InputTokens += t.Input
+				usage.OutputTokens += t.Output
+				if t.Cache != nil {
+					usage.CacheReadTokens += t.Cache.Read
+					usage.CacheWriteTokens += t.Cache.Write
+				}
+			}
 		}
 	}
 
@@ -175,6 +198,7 @@ func (b *opencodeBackend) processEvents(r io.Reader, ch chan<- Message) eventRes
 		errMsg:    finalError,
 		output:    output.String(),
 		sessionID: sessionID,
+		usage:     usage,
 	}
 }
 
@@ -281,6 +305,21 @@ type opencodeEventPart struct {
 	Tool   string             `json:"tool,omitempty"`
 	CallID string             `json:"callID,omitempty"`
 	State  *opencodeToolState `json:"state,omitempty"`
+
+	// step_finish token usage
+	Tokens *opencodeTokens `json:"tokens,omitempty"`
+}
+
+// opencodeTokens represents token usage in a step_finish event.
+type opencodeTokens struct {
+	Input  int64                `json:"input"`
+	Output int64                `json:"output"`
+	Cache  *opencodeCacheTokens `json:"cache,omitempty"`
+}
+
+type opencodeCacheTokens struct {
+	Read  int64 `json:"read"`
+	Write int64 `json:"write"`
 }
 
 // opencodeToolState represents the state of a tool invocation.