fix(usage): address review feedback — independent usage reporting + all providers
1. Separate ReportTaskUsage endpoint (POST /api/daemon/tasks/{id}/usage)
so usage is captured independently of complete/fail — fixes usage loss
for failed/blocked tasks.
2. Add usage tracking for all four providers:
- Claude: already done (stream-json message.usage)
- OpenCode: extract from step_finish.part.tokens
- OpenClaw: extract from step_end.data token fields
- Codex: extract from turn/completed and task_complete usage fields
3. Remove usage from CompleteTask payload — all usage goes through the
dedicated endpoint now.
This commit is contained in:
parent
8a8d3ea20e
commit
fa0c0fe747
7 changed files with 196 additions and 25 deletions
|
|
@ -103,6 +103,7 @@ func NewRouter(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus) chi.Route
|
|||
r.Post("/tasks/{taskId}/progress", h.ReportTaskProgress)
|
||||
r.Post("/tasks/{taskId}/complete", h.CompleteTask)
|
||||
r.Post("/tasks/{taskId}/fail", h.FailTask)
|
||||
r.Post("/tasks/{taskId}/usage", h.ReportTaskUsage)
|
||||
r.Post("/tasks/{taskId}/messages", h.ReportTaskMessages)
|
||||
r.Get("/tasks/{taskId}/messages", h.ListTaskMessages)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ func (c *Client) ReportTaskMessages(ctx context.Context, taskID string, messages
|
|||
}, nil)
|
||||
}
|
||||
|
||||
func (c *Client) CompleteTask(ctx context.Context, taskID, output, branchName, sessionID, workDir string, usage []TaskUsageEntry) error {
|
||||
func (c *Client) CompleteTask(ctx context.Context, taskID, output, branchName, sessionID, workDir string) error {
|
||||
body := map[string]any{"output": output}
|
||||
if branchName != "" {
|
||||
body["branch_name"] = branchName
|
||||
|
|
@ -110,12 +110,18 @@ func (c *Client) CompleteTask(ctx context.Context, taskID, output, branchName, s
|
|||
if workDir != "" {
|
||||
body["work_dir"] = workDir
|
||||
}
|
||||
if len(usage) > 0 {
|
||||
body["usage"] = usage
|
||||
}
|
||||
return c.postJSON(ctx, fmt.Sprintf("/api/daemon/tasks/%s/complete", taskID), body, nil)
|
||||
}
|
||||
|
||||
func (c *Client) ReportTaskUsage(ctx context.Context, taskID string, usage []TaskUsageEntry) error {
|
||||
if len(usage) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.postJSON(ctx, fmt.Sprintf("/api/daemon/tasks/%s/usage", taskID), map[string]any{
|
||||
"usage": usage,
|
||||
}, nil)
|
||||
}
|
||||
|
||||
func (c *Client) FailTask(ctx context.Context, taskID, errMsg string) error {
|
||||
return c.postJSON(ctx, fmt.Sprintf("/api/daemon/tasks/%s/fail", taskID), map[string]any{
|
||||
"error": errMsg,
|
||||
|
|
|
|||
|
|
@ -837,6 +837,13 @@ func (d *Daemon) handleTask(ctx context.Context, task Task) {
|
|||
return
|
||||
}
|
||||
|
||||
// Report usage independently so it's captured even for failed/blocked tasks.
|
||||
if len(result.Usage) > 0 {
|
||||
if err := d.client.ReportTaskUsage(ctx, task.ID, result.Usage); err != nil {
|
||||
taskLog.Warn("report task usage failed", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
switch result.Status {
|
||||
case "blocked":
|
||||
if err := d.client.FailTask(ctx, task.ID, result.Comment); err != nil {
|
||||
|
|
@ -844,7 +851,7 @@ func (d *Daemon) handleTask(ctx context.Context, task Task) {
|
|||
}
|
||||
default:
|
||||
taskLog.Info("task completed", "status", result.Status)
|
||||
if err := d.client.CompleteTask(ctx, task.ID, result.Comment, result.BranchName, result.SessionID, result.WorkDir, result.Usage); err != nil {
|
||||
if err := d.client.CompleteTask(ctx, task.ID, result.Comment, result.BranchName, result.SessionID, result.WorkDir); err != nil {
|
||||
taskLog.Error("complete task failed, falling back to fail", "error", err)
|
||||
if failErr := d.client.FailTask(ctx, task.ID, fmt.Sprintf("complete task failed: %s", err.Error())); failErr != nil {
|
||||
taskLog.Error("fail task fallback also failed", "error", failErr)
|
||||
|
|
|
|||
|
|
@ -341,21 +341,10 @@ func (h *Handler) ReportTaskProgress(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// CompleteTask marks a running task as completed.
|
||||
type TaskCompleteRequest struct {
|
||||
PRURL string `json:"pr_url"`
|
||||
Output string `json:"output"`
|
||||
SessionID string `json:"session_id"` // Claude session ID for future resumption
|
||||
WorkDir string `json:"work_dir"` // working directory used during execution
|
||||
Usage []TaskUsagePayload `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// TaskUsagePayload is the per-model token usage reported by the daemon.
|
||||
type TaskUsagePayload struct {
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||
CacheWriteTokens int64 `json:"cache_write_tokens"`
|
||||
PRURL string `json:"pr_url"`
|
||||
Output string `json:"output"`
|
||||
SessionID string `json:"session_id"` // Claude session ID for future resumption
|
||||
WorkDir string `json:"work_dir"` // working directory used during execution
|
||||
}
|
||||
|
||||
func (h *Handler) CompleteTask(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
@ -375,7 +364,32 @@ func (h *Handler) CompleteTask(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// Store per-task token usage (best-effort, don't fail the request).
|
||||
slog.Info("task completed", "task_id", taskID, "agent_id", uuidToString(task.AgentID))
|
||||
writeJSON(w, http.StatusOK, taskToResponse(*task))
|
||||
}
|
||||
|
||||
// ReportTaskUsage stores per-task token usage. Called independently of
|
||||
// complete/fail so usage is captured even when tasks fail or are blocked.
|
||||
type TaskUsagePayload struct {
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||
CacheWriteTokens int64 `json:"cache_write_tokens"`
|
||||
}
|
||||
|
||||
func (h *Handler) ReportTaskUsage(w http.ResponseWriter, r *http.Request) {
|
||||
taskID := chi.URLParam(r, "taskId")
|
||||
|
||||
var req struct {
|
||||
Usage []TaskUsagePayload `json:"usage"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
for _, u := range req.Usage {
|
||||
if err := h.Queries.UpsertTaskUsage(r.Context(), db.UpsertTaskUsageParams{
|
||||
TaskID: parseUUID(taskID),
|
||||
|
|
@ -390,8 +404,7 @@ func (h *Handler) CompleteTask(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
slog.Info("task completed", "task_id", taskID, "agent_id", uuidToString(task.AgentID))
|
||||
writeJSON(w, http.StatusOK, taskToResponse(*task))
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
// GetTaskStatus returns the current status of a task.
|
||||
|
|
|
|||
|
|
@ -220,11 +220,25 @@ func (b *codexBackend) Execute(ctx context.Context, prompt string, opts ExecOpti
|
|||
finalOutput := output.String()
|
||||
outputMu.Unlock()
|
||||
|
||||
// Build usage map from accumulated codex usage.
|
||||
var usageMap map[string]TokenUsage
|
||||
c.usageMu.Lock()
|
||||
u := c.usage
|
||||
c.usageMu.Unlock()
|
||||
if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 || u.CacheWriteTokens > 0 {
|
||||
model := opts.Model
|
||||
if model == "" {
|
||||
model = "unknown"
|
||||
}
|
||||
usageMap = map[string]TokenUsage{model: u}
|
||||
}
|
||||
|
||||
resCh <- Result{
|
||||
Status: finalStatus,
|
||||
Output: finalOutput,
|
||||
Error: finalError,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Usage: usageMap,
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
@ -247,6 +261,9 @@ type codexClient struct {
|
|||
notificationProtocol string // "unknown", "legacy", "raw"
|
||||
turnStarted bool
|
||||
completedTurnIDs map[string]bool
|
||||
|
||||
usageMu sync.Mutex
|
||||
usage TokenUsage // accumulated from turn events
|
||||
}
|
||||
|
||||
type pendingRPC struct {
|
||||
|
|
@ -498,6 +515,8 @@ func (c *codexClient) handleEvent(msg map[string]any) {
|
|||
})
|
||||
}
|
||||
case "task_complete":
|
||||
// Extract usage from legacy task_complete if present.
|
||||
c.extractUsageFromMap(msg)
|
||||
if c.onTurnDone != nil {
|
||||
c.onTurnDone(false)
|
||||
}
|
||||
|
|
@ -535,6 +554,11 @@ func (c *codexClient) handleRawNotification(method string, params map[string]any
|
|||
c.completedTurnIDs[turnID] = true
|
||||
}
|
||||
|
||||
// Extract usage from turn/completed if present (e.g. params.turn.usage).
|
||||
if turn, ok := params["turn"].(map[string]any); ok {
|
||||
c.extractUsageFromMap(turn)
|
||||
}
|
||||
|
||||
if c.onTurnDone != nil {
|
||||
c.onTurnDone(aborted)
|
||||
}
|
||||
|
|
@ -618,6 +642,48 @@ func (c *codexClient) handleItemNotification(method string, params map[string]an
|
|||
}
|
||||
}
|
||||
|
||||
// extractUsageFromMap extracts token usage from a map that may contain
|
||||
// "usage", "token_usage", or "tokens" fields. Handles various Codex formats.
|
||||
func (c *codexClient) extractUsageFromMap(data map[string]any) {
|
||||
// Try common field names for usage data.
|
||||
var usageMap map[string]any
|
||||
for _, key := range []string{"usage", "token_usage", "tokens"} {
|
||||
if v, ok := data[key].(map[string]any); ok {
|
||||
usageMap = v
|
||||
break
|
||||
}
|
||||
}
|
||||
if usageMap == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.usageMu.Lock()
|
||||
defer c.usageMu.Unlock()
|
||||
|
||||
// Try various key conventions.
|
||||
c.usage.InputTokens += codexInt64(usageMap, "input_tokens", "input", "prompt_tokens")
|
||||
c.usage.OutputTokens += codexInt64(usageMap, "output_tokens", "output", "completion_tokens")
|
||||
c.usage.CacheReadTokens += codexInt64(usageMap, "cache_read_tokens", "cache_read_input_tokens")
|
||||
c.usage.CacheWriteTokens += codexInt64(usageMap, "cache_write_tokens", "cache_creation_input_tokens")
|
||||
}
|
||||
|
||||
// codexInt64 returns the first non-zero int64 value from the map for the given keys.
|
||||
func codexInt64(m map[string]any, keys ...string) int64 {
|
||||
for _, key := range keys {
|
||||
switch v := m[key].(type) {
|
||||
case float64:
|
||||
if v != 0 {
|
||||
return int64(v)
|
||||
}
|
||||
case int64:
|
||||
if v != 0 {
|
||||
return v
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// ── Helpers ──
|
||||
|
||||
func extractThreadID(result json.RawMessage) string {
|
||||
|
|
|
|||
|
|
@ -96,12 +96,25 @@ func (b *openclawBackend) Execute(ctx context.Context, prompt string, opts ExecO
|
|||
|
||||
b.cfg.Logger.Info("openclaw finished", "pid", cmd.Process.Pid, "status", scanResult.status, "duration", duration.Round(time.Millisecond).String())
|
||||
|
||||
// Build usage map. OpenClaw doesn't report model per-step, so we
|
||||
// attribute all usage to the configured model (or "unknown").
|
||||
var usage map[string]TokenUsage
|
||||
u := scanResult.usage
|
||||
if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 || u.CacheWriteTokens > 0 {
|
||||
model := opts.Model
|
||||
if model == "" {
|
||||
model = "unknown"
|
||||
}
|
||||
usage = map[string]TokenUsage{model: u}
|
||||
}
|
||||
|
||||
resCh <- Result{
|
||||
Status: scanResult.status,
|
||||
Output: scanResult.output,
|
||||
Error: scanResult.errMsg,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
SessionID: scanResult.sessionID,
|
||||
Usage: usage,
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
@ -116,6 +129,7 @@ type openclawEventResult struct {
|
|||
errMsg string
|
||||
output string
|
||||
sessionID string
|
||||
usage TokenUsage
|
||||
}
|
||||
|
||||
// processEvents reads NDJSON lines from r, dispatches events to ch, and returns
|
||||
|
|
@ -123,6 +137,7 @@ type openclawEventResult struct {
|
|||
func (b *openclawBackend) processEvents(r io.Reader, ch chan<- Message) openclawEventResult {
|
||||
var output strings.Builder
|
||||
var sessionID string
|
||||
var usage TokenUsage
|
||||
finalStatus := "completed"
|
||||
var finalError string
|
||||
|
||||
|
|
@ -160,7 +175,13 @@ func (b *openclawBackend) processEvents(r io.Reader, ch chan<- Message) openclaw
|
|||
case "step_start":
|
||||
trySend(ch, Message{Type: MessageStatus, Status: "running"})
|
||||
case "step_end":
|
||||
// Captures final session ID from step_end if present.
|
||||
// Accumulate token usage from step_end events if present.
|
||||
if event.Data != nil {
|
||||
usage.InputTokens += openclawInt64(event.Data, "inputTokens")
|
||||
usage.OutputTokens += openclawInt64(event.Data, "outputTokens")
|
||||
usage.CacheReadTokens += openclawInt64(event.Data, "cacheReadTokens")
|
||||
usage.CacheWriteTokens += openclawInt64(event.Data, "cacheWriteTokens")
|
||||
}
|
||||
case "result":
|
||||
// The result event only updates status on explicit failure. A
|
||||
// "completed" result is a no-op because finalStatus defaults to
|
||||
|
|
@ -193,6 +214,24 @@ func (b *openclawBackend) processEvents(r io.Reader, ch chan<- Message) openclaw
|
|||
errMsg: finalError,
|
||||
output: output.String(),
|
||||
sessionID: sessionID,
|
||||
usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
// openclawInt64 safely extracts an int64 from a JSON-decoded map value (which
|
||||
// may be float64 due to Go's JSON number handling).
|
||||
func openclawInt64(data map[string]any, key string) int64 {
|
||||
v, ok := data[key]
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return int64(n)
|
||||
case int64:
|
||||
return n
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -99,12 +99,25 @@ func (b *opencodeBackend) Execute(ctx context.Context, prompt string, opts ExecO
|
|||
|
||||
b.cfg.Logger.Info("opencode finished", "pid", cmd.Process.Pid, "status", scanResult.status, "duration", duration.Round(time.Millisecond).String())
|
||||
|
||||
// Build usage map. OpenCode doesn't report model per-step, so we
|
||||
// attribute all usage to the configured model (or "unknown").
|
||||
var usage map[string]TokenUsage
|
||||
u := scanResult.usage
|
||||
if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 || u.CacheWriteTokens > 0 {
|
||||
model := opts.Model
|
||||
if model == "" {
|
||||
model = "unknown"
|
||||
}
|
||||
usage = map[string]TokenUsage{model: u}
|
||||
}
|
||||
|
||||
resCh <- Result{
|
||||
Status: scanResult.status,
|
||||
Output: scanResult.output,
|
||||
Error: scanResult.errMsg,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
SessionID: scanResult.sessionID,
|
||||
Usage: usage,
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
@ -119,6 +132,7 @@ type eventResult struct {
|
|||
errMsg string
|
||||
output string
|
||||
sessionID string
|
||||
usage TokenUsage // accumulated token usage across all steps
|
||||
}
|
||||
|
||||
// processEvents reads JSON lines from r, dispatches events to ch, and returns
|
||||
|
|
@ -126,6 +140,7 @@ type eventResult struct {
|
|||
func (b *opencodeBackend) processEvents(r io.Reader, ch chan<- Message) eventResult {
|
||||
var output strings.Builder
|
||||
var sessionID string
|
||||
var usage TokenUsage
|
||||
finalStatus := "completed"
|
||||
var finalError string
|
||||
|
||||
|
|
@ -157,7 +172,15 @@ func (b *opencodeBackend) processEvents(r io.Reader, ch chan<- Message) eventRes
|
|||
case "step_start":
|
||||
trySend(ch, Message{Type: MessageStatus, Status: "running"})
|
||||
case "step_finish":
|
||||
// Captures final session ID from step_finish if present.
|
||||
// Accumulate token usage from step_finish events.
|
||||
if t := event.Part.Tokens; t != nil {
|
||||
usage.InputTokens += t.Input
|
||||
usage.OutputTokens += t.Output
|
||||
if t.Cache != nil {
|
||||
usage.CacheReadTokens += t.Cache.Read
|
||||
usage.CacheWriteTokens += t.Cache.Write
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -175,6 +198,7 @@ func (b *opencodeBackend) processEvents(r io.Reader, ch chan<- Message) eventRes
|
|||
errMsg: finalError,
|
||||
output: output.String(),
|
||||
sessionID: sessionID,
|
||||
usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -281,6 +305,21 @@ type opencodeEventPart struct {
|
|||
Tool string `json:"tool,omitempty"`
|
||||
CallID string `json:"callID,omitempty"`
|
||||
State *opencodeToolState `json:"state,omitempty"`
|
||||
|
||||
// step_finish token usage
|
||||
Tokens *opencodeTokens `json:"tokens,omitempty"`
|
||||
}
|
||||
|
||||
// opencodeTokens represents token usage in a step_finish event.
|
||||
type opencodeTokens struct {
|
||||
Input int64 `json:"input"`
|
||||
Output int64 `json:"output"`
|
||||
Cache *opencodeCacheTokens `json:"cache,omitempty"`
|
||||
}
|
||||
|
||||
type opencodeCacheTokens struct {
|
||||
Read int64 `json:"read"`
|
||||
Write int64 `json:"write"`
|
||||
}
|
||||
|
||||
// opencodeToolState represents the state of a tool invocation.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue