diff --git a/server/cmd/server/router.go b/server/cmd/server/router.go
index a7500007..0faa51c6 100644
--- a/server/cmd/server/router.go
+++ b/server/cmd/server/router.go
@@ -221,6 +221,12 @@ func NewRouter(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus) chi.Route
 			})
 		})

+		// Usage: workspace token usage, broken down per day (/daily) and totalled per model (/summary)
+		r.Route("/api/usage", func(r chi.Router) {
+			r.Get("/daily", h.GetWorkspaceUsageByDay)
+			r.Get("/summary", h.GetWorkspaceUsageSummary)
+		})
+
 		// Runtimes
 		r.Route("/api/runtimes", func(r chi.Router) {
 			r.Get("/", h.ListAgentRuntimes)
diff --git a/server/internal/daemon/client.go b/server/internal/daemon/client.go
index 5a50ae92..eb55ebf4 100644
--- a/server/internal/daemon/client.go
+++ b/server/internal/daemon/client.go
@@ -99,7 +99,7 @@ func (c *Client) ReportTaskMessages(ctx context.Context, taskID string, messages
 	}, nil)
 }

-func (c *Client) CompleteTask(ctx context.Context, taskID, output, branchName, sessionID, workDir string) error {
+func (c *Client) CompleteTask(ctx context.Context, taskID, output, branchName, sessionID, workDir string, usage []TaskUsageEntry) error {
 	body := map[string]any{"output": output}
 	if branchName != "" {
 		body["branch_name"] = branchName
@@ -110,6 +110,9 @@ func (c *Client) CompleteTask(ctx context.Context, taskID, output, branchName, s
 	if workDir != "" {
 		body["work_dir"] = workDir
 	}
+	if len(usage) > 0 { // nil or empty usage: leave the "usage" key out of the request body
+		body["usage"] = usage
+	}
 	return c.postJSON(ctx, fmt.Sprintf("/api/daemon/tasks/%s/complete", taskID), body, nil)
 }

diff --git a/server/internal/daemon/daemon.go b/server/internal/daemon/daemon.go
index b264d04c..2470dcc9 100644
--- a/server/internal/daemon/daemon.go
+++ b/server/internal/daemon/daemon.go
@@ -844,7 +844,7 @@ func (d *Daemon) handleTask(ctx context.Context, task Task) {
 		}
 	default:
 		taskLog.Info("task completed", "status", result.Status)
-		if err := d.client.CompleteTask(ctx, task.ID, result.Comment, result.BranchName, result.SessionID, result.WorkDir); err != nil {
+		if err := d.client.CompleteTask(ctx, task.ID, result.Comment, result.BranchName, result.SessionID, result.WorkDir, result.Usage); err != nil {
 			taskLog.Error("complete task failed, falling back to fail", "error", err)
 			if failErr := d.client.FailTask(ctx, task.ID, fmt.Sprintf("complete task failed: %s", err.Error())); failErr != nil {
 				taskLog.Error("fail task fallback also failed", "error", failErr)
@@ -1105,6 +1105,22 @@ func (d *Daemon) runTask(ctx context.Context, task Task, provider string, taskLo
 		"tools", toolCount.Load(),
 	)

+	// Convert the agent's per-model usage map into task usage entries, dropping models whose counters are all zero. Entry order follows Go's unspecified map iteration order. The "timeout" case below returns an empty TaskResult, so its usage is not reported.
+	var usageEntries []TaskUsageEntry
+	for model, u := range result.Usage {
+		if u.InputTokens == 0 && u.OutputTokens == 0 && u.CacheReadTokens == 0 && u.CacheWriteTokens == 0 {
+			continue
+		}
+		usageEntries = append(usageEntries, TaskUsageEntry{
+			Provider:         provider,
+			Model:            model,
+			InputTokens:      u.InputTokens,
+			OutputTokens:     u.OutputTokens,
+			CacheReadTokens:  u.CacheReadTokens,
+			CacheWriteTokens: u.CacheWriteTokens,
+		})
+	}
+
 	switch result.Status {
 	case "completed":
 		if result.Output == "" {
@@ -1115,6 +1131,7 @@ func (d *Daemon) runTask(ctx context.Context, task Task, provider string, taskLo
 			Comment:   result.Output,
 			SessionID: result.SessionID,
 			WorkDir:   env.WorkDir,
+			Usage:     usageEntries, // nil when no model reported non-zero usage
 		}, nil
 	case "timeout":
 		return TaskResult{}, fmt.Errorf("%s timed out after %s", provider, d.cfg.AgentTimeout)
@@ -1123,7 +1140,7 @@ func (d *Daemon) runTask(ctx context.Context, task Task, provider string, taskLo
 		if errMsg == "" {
 			errMsg = fmt.Sprintf("%s execution %s", provider, result.Status)
 		}
-		return TaskResult{Status: "blocked", Comment: errMsg}, nil
+		return TaskResult{Status: "blocked", Comment: errMsg, Usage: usageEntries}, nil // usage is attached even though the run did not complete
 	}
 }

diff --git a/server/internal/daemon/types.go b/server/internal/daemon/types.go
index accafc3f..bede3f5c 100644
--- a/server/internal/daemon/types.go
+++ b/server/internal/daemon/types.go
@@ -56,12 +56,23 @@ type SkillFileData struct {
 	Content string `json:"content"`
 }

+// TaskUsageEntry is the token usage of one model during a single task execution; the client sends a slice of these in the "usage" field of the complete-task request.
+type TaskUsageEntry struct {
+	Provider         string `json:"provider"` // provider name the daemon ran the task with
+	Model            string `json:"model"`    // model name, taken from the key of the agent's usage map
+	InputTokens      int64  `json:"input_tokens"`
+	OutputTokens     int64  `json:"output_tokens"`
+	CacheReadTokens  int64  `json:"cache_read_tokens"`
+	CacheWriteTokens int64  `json:"cache_write_tokens"`
+}
+
 // TaskResult is the outcome of executing a task.
 type TaskResult struct {
-	Status     string `json:"status"`
-	Comment    string `json:"comment"`
-	BranchName string `json:"branch_name,omitempty"`
-	EnvType    string `json:"env_type,omitempty"`
-	SessionID  string `json:"session_id,omitempty"` // Claude session ID for future resumption
-	WorkDir    string `json:"work_dir,omitempty"`   // working directory used during execution
+	Status     string           `json:"status"`
+	Comment    string           `json:"comment"`
+	BranchName string           `json:"branch_name,omitempty"`
+	EnvType    string           `json:"env_type,omitempty"`
+	SessionID  string           `json:"session_id,omitempty"` // Claude session ID for future resumption
+	WorkDir    string           `json:"work_dir,omitempty"`   // working directory used during execution
+	Usage      []TaskUsageEntry `json:"usage,omitempty"`      // per-model token usage; omitted from JSON when empty
 }
diff --git a/server/internal/handler/daemon.go b/server/internal/handler/daemon.go
index bab4ce4d..d3c24690 100644
--- a/server/internal/handler/daemon.go
+++ b/server/internal/handler/daemon.go
@@ -341,10 +341,21 @@ func (h *Handler) ReportTaskProgress(w http.ResponseWriter, r *http.Request) {

 // CompleteTask marks a running task as completed.
 type TaskCompleteRequest struct {
-	PRURL     string `json:"pr_url"`
-	Output    string `json:"output"`
-	SessionID string `json:"session_id"` // Claude session ID for future resumption
-	WorkDir   string `json:"work_dir"`   // working directory used during execution
+	PRURL     string             `json:"pr_url"`
+	Output    string             `json:"output"`
+	SessionID string             `json:"session_id"` // Claude session ID for future resumption
+	WorkDir   string             `json:"work_dir"`   // working directory used during execution
+	Usage     []TaskUsagePayload `json:"usage,omitempty"` // optional per-model token usage
+}
+
+// TaskUsagePayload is the per-model token usage reported by the daemon; its JSON field names match daemon.TaskUsageEntry.
+type TaskUsagePayload struct {
+	Provider         string `json:"provider"`
+	Model            string `json:"model"`
+	InputTokens      int64  `json:"input_tokens"`
+	OutputTokens     int64  `json:"output_tokens"`
+	CacheReadTokens  int64  `json:"cache_read_tokens"`
+	CacheWriteTokens int64  `json:"cache_write_tokens"`
 }

 func (h *Handler) CompleteTask(w http.ResponseWriter, r *http.Request) {
@@ -364,6 +375,21 @@ func (h *Handler) CompleteTask(w http.ResponseWriter, r *http.Request) {
 		return
 	}

+	// Store per-task token usage, one upsert per reported entry (best-effort, don't fail the request).
+	for _, u := range req.Usage { // one row per (task, provider, model); a re-sent entry overwrites the stored counters
+		if err := h.Queries.UpsertTaskUsage(r.Context(), db.UpsertTaskUsageParams{
+			TaskID:           parseUUID(taskID), // same task ID for every entry
+			Provider:         u.Provider,
+			Model:            u.Model,
+			InputTokens:      u.InputTokens,
+			OutputTokens:     u.OutputTokens,
+			CacheReadTokens:  u.CacheReadTokens,
+			CacheWriteTokens: u.CacheWriteTokens,
+		}); err != nil { // log and move on to the next entry; the response below is still 200
+			slog.Warn("upsert task usage failed", "task_id", taskID, "model", u.Model, "error", err)
+		}
+	}
+
 	slog.Info("task completed", "task_id", taskID, "agent_id", uuidToString(task.AgentID))
 	writeJSON(w, http.StatusOK, taskToResponse(*task))
 }
diff --git a/server/internal/handler/runtime.go b/server/internal/handler/runtime.go
index e9aa5263..d85e4494 100644
--- a/server/internal/handler/runtime.go
+++ b/server/internal/handler/runtime.go
@@ -192,6 +192,96 @@ func (h *Handler) GetRuntimeTaskActivity(w http.ResponseWriter, r *http.Request)
 	writeJSON(w, http.StatusOK, resp)
 }

+// GetWorkspaceUsageByDay returns the workspace's token usage grouped by day and model for the last "days" days (query parameter, default 30).
+func (h *Handler) GetWorkspaceUsageByDay(w http.ResponseWriter, r *http.Request) {
+	wsID := resolveWorkspaceID(r)
+	cutoff := parseSinceParam(r, 30) // look-back window defaults to 30 days
+	params := db.GetWorkspaceUsageByDayParams{
+		WorkspaceID: parseUUID(wsID),
+		Since:       cutoff,
+	}
+	usage, err := h.Queries.GetWorkspaceUsageByDay(r.Context(), params)
+	if err != nil {
+		writeError(w, http.StatusInternalServerError, "failed to get usage")
+		return
+	}
+
+	type dailyUsage struct {
+		Date                  string `json:"date"`
+		Model                 string `json:"model"`
+		TotalInputTokens      int64  `json:"total_input_tokens"`
+		TotalOutputTokens     int64  `json:"total_output_tokens"`
+		TotalCacheReadTokens  int64  `json:"total_cache_read_tokens"`
+		TotalCacheWriteTokens int64  `json:"total_cache_write_tokens"`
+		TaskCount             int32  `json:"task_count"`
+	}
+
+	days := make([]dailyUsage, 0, len(usage)) // one entry per (day, model) row, in query order
+	for _, u := range usage {
+		days = append(days, dailyUsage{
+			Date:                  u.Date.Time.Format("2006-01-02"),
+			Model:                 u.Model,
+			TotalInputTokens:      u.TotalInputTokens,
+			TotalOutputTokens:     u.TotalOutputTokens,
+			TotalCacheReadTokens:  u.TotalCacheReadTokens,
+			TotalCacheWriteTokens: u.TotalCacheWriteTokens,
+			TaskCount:             u.TaskCount,
+		})
+	}
+
+	writeJSON(w, http.StatusOK, days)
+}
+
+// GetWorkspaceUsageSummary reports the workspace's token totals per model over the requested look-back window.
+func (h *Handler) GetWorkspaceUsageSummary(w http.ResponseWriter, r *http.Request) {
+	wsID := resolveWorkspaceID(r)
+	cutoff := parseSinceParam(r, 30) // look-back window defaults to 30 days
+	params := db.GetWorkspaceUsageSummaryParams{
+		WorkspaceID: parseUUID(wsID),
+		Since:       cutoff,
+	}
+	totals, err := h.Queries.GetWorkspaceUsageSummary(r.Context(), params)
+	if err != nil {
+		writeError(w, http.StatusInternalServerError, "failed to get usage summary")
+		return
+	}
+
+	type modelUsage struct {
+		Model                 string `json:"model"`
+		TotalInputTokens      int64  `json:"total_input_tokens"`
+		TotalOutputTokens     int64  `json:"total_output_tokens"`
+		TotalCacheReadTokens  int64  `json:"total_cache_read_tokens"`
+		TotalCacheWriteTokens int64  `json:"total_cache_write_tokens"`
+		TaskCount             int32  `json:"task_count"`
+	}
+
+	summary := make([]modelUsage, 0, len(totals)) // one entry per model, in query order
+	for _, m := range totals {
+		summary = append(summary, modelUsage{
+			Model:                 m.Model,
+			TotalInputTokens:      m.TotalInputTokens,
+			TotalOutputTokens:     m.TotalOutputTokens,
+			TotalCacheReadTokens:  m.TotalCacheReadTokens,
+			TotalCacheWriteTokens: m.TotalCacheWriteTokens,
+			TaskCount:             m.TaskCount,
+		})
+	}
+
+	writeJSON(w, http.StatusOK, summary)
+}
+
+// parseSinceParam turns the "days" query parameter into a cutoff timestamp; a missing, non-numeric, non-positive or >365 value falls back to defaultDays.
+func parseSinceParam(r *http.Request, defaultDays int) pgtype.Timestamptz {
+	days := defaultDays
+	if d := r.URL.Query().Get("days"); d != "" {
+		if parsed, err := strconv.Atoi(d); err == nil && parsed > 0 && parsed <= 365 { // anything outside 1..365 keeps the default
+			days = parsed
+		}
+	}
+	t := time.Now().AddDate(0, 0, -days) // cutoff is "now minus N calendar days"
+	return pgtype.Timestamptz{Time: t, Valid: true}
+}
+
 func (h *Handler) ListAgentRuntimes(w http.ResponseWriter, r *http.Request) {
 	workspaceID := resolveWorkspaceID(r)

diff --git a/server/migrations/032_task_usage.down.sql b/server/migrations/032_task_usage.down.sql
new file mode 100644
index 00000000..3ce18482
--- /dev/null
+++ b/server/migrations/032_task_usage.down.sql
@@ -0,0 +1 @@
+DROP TABLE IF EXISTS task_usage;
diff --git a/server/migrations/032_task_usage.up.sql b/server/migrations/032_task_usage.up.sql
new file mode 100644
index 00000000..1c61c0e0
--- /dev/null
+++ b/server/migrations/032_task_usage.up.sql
@@ -0,0 +1,14 @@
+CREATE TABLE task_usage (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    task_id UUID NOT NULL REFERENCES agent_task_queue(id) ON DELETE CASCADE,
+    provider TEXT NOT NULL DEFAULT '',
+    model TEXT NOT NULL,
+    input_tokens BIGINT NOT NULL DEFAULT 0,
+    output_tokens BIGINT NOT NULL DEFAULT 0,
+    cache_read_tokens BIGINT NOT NULL DEFAULT 0,
+    cache_write_tokens BIGINT NOT NULL DEFAULT 0,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
+    UNIQUE (task_id, provider, model)
+);
+
+-- No separate index on task_id: the index behind UNIQUE (task_id, provider, model) has task_id as its leading column and already serves task_id lookups.
diff --git a/server/pkg/agent/agent.go b/server/pkg/agent/agent.go
index 5636617d..08ab57f7 100644
--- a/server/pkg/agent/agent.go
+++ b/server/pkg/agent/agent.go
@@ -62,6 +62,14 @@ type Message struct {
 	Level string // log level (Log)
 }

+// TokenUsage tracks token consumption for a single model; Result.Usage holds one TokenUsage per model name.
+type TokenUsage struct {
+	InputTokens      int64
+	OutputTokens     int64
+	CacheReadTokens  int64 // Claude backend: summed from usage.cache_read_input_tokens
+	CacheWriteTokens int64 // Claude backend: summed from usage.cache_creation_input_tokens
+}
+
 // Result is the final outcome after an agent session completes.
 type Result struct {
 	Status     string // "completed", "failed", "aborted", "timeout"
@@ -69,6 +77,7 @@ type Result struct {
 	Error      string // error message if failed
 	DurationMs int64
 	SessionID  string
+	Usage      map[string]TokenUsage // keyed by model name
 }

 // Config configures a Backend instance.
diff --git a/server/pkg/agent/claude.go b/server/pkg/agent/claude.go
index c1b78406..4c29cd1e 100644
--- a/server/pkg/agent/claude.go
+++ b/server/pkg/agent/claude.go
@@ -91,6 +91,7 @@ func (b *claudeBackend) Execute(ctx context.Context, prompt string, opts ExecOpt
 		var sessionID string
 		finalStatus := "completed"
 		var finalError string
+		usage := make(map[string]TokenUsage) // per-model totals, filled in by handleAssistant and returned in Result.Usage

 		scanner := bufio.NewScanner(stdout)
 		scanner.Buffer(make([]byte, 0, 1024*1024), 10*1024*1024)
@@ -108,7 +109,7 @@ func (b *claudeBackend) Execute(ctx context.Context, prompt string, opts ExecOpt

 			switch msg.Type {
 			case "assistant":
-				b.handleAssistant(msg, msgCh, &output)
+				b.handleAssistant(msg, msgCh, &output, usage)
 			case "user":
 				b.handleUser(msg, msgCh)
 			case "system":
@@ -162,18 +163,29 @@ func (b *claudeBackend) Execute(ctx context.Context, prompt string, opts ExecOpt
 			Error:      finalError,
 			DurationMs: duration.Milliseconds(),
 			SessionID:  sessionID,
+			Usage:      usage, // always non-nil; empty when no assistant message carried usage
 		}
 	}()

 	return &Session{Messages: msgCh, Result: resCh}, nil
 }

-func (b *claudeBackend) handleAssistant(msg claudeSDKMessage, ch chan<- Message, output *strings.Builder) {
+func (b *claudeBackend) handleAssistant(msg claudeSDKMessage, ch chan<- Message, output *strings.Builder, usage map[string]TokenUsage) {
 	var content claudeMessageContent
 	if err := json.Unmarshal(msg.Message, &content); err != nil {
 		return
 	}

+	// Accumulate token usage per model. NOTE(review): this adds the usage of every assistant message; confirm the stream never repeats a message's usage or reports cumulative figures, otherwise totals are inflated.
+	if content.Usage != nil && content.Model != "" { // usage without a model name cannot be attributed and is skipped
+		u := usage[content.Model] // zero value the first time a model is seen
+		u.InputTokens += content.Usage.InputTokens
+		u.OutputTokens += content.Usage.OutputTokens
+		u.CacheReadTokens += content.Usage.CacheReadInputTokens
+		u.CacheWriteTokens += content.Usage.CacheCreationInputTokens
+		usage[content.Model] = u // struct map values are copies, so store the updated one back
+	}
+
 	for _, block := range content.Content {
 		switch block.Type {
 		case "text":
@@ -287,8 +299,17 @@ type claudeLogEntry struct {
 }

 type claudeMessageContent struct {
-	Role    string               `json:"role"`
+	Role    string               `json:"role"`
+	Model   string               `json:"model"` // key under which this message's usage is accumulated
 	Content []claudeContentBlock `json:"content"`
+	Usage   *claudeUsage         `json:"usage,omitempty"` // nil when the message has no "usage" object
+}
+
+type claudeUsage struct { // token counters decoded from an assistant message's "usage" object
+	InputTokens              int64 `json:"input_tokens"`
+	OutputTokens             int64 `json:"output_tokens"`
+	CacheReadInputTokens     int64 `json:"cache_read_input_tokens"`     // stored as TokenUsage.CacheReadTokens
+	CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"` // stored as TokenUsage.CacheWriteTokens
 }

 type claudeContentBlock struct {
diff --git a/server/pkg/agent/claude_test.go b/server/pkg/agent/claude_test.go
index a6c36210..f74561cc 100644
--- a/server/pkg/agent/claude_test.go
+++ b/server/pkg/agent/claude_test.go
@@ -25,7 +25,7 @@ func TestClaudeHandleAssistantText(t *testing.T) {
 		}),
 	}

-	b.handleAssistant(msg, ch, &output)
+	b.handleAssistant(msg, ch, &output, make(map[string]TokenUsage)) // throwaway usage map for the new parameter

 	if output.String() != "Hello world" {
 		t.Fatalf("expected output 'Hello world', got %q", output.String())
@@ -62,7 +62,7 @@ func TestClaudeHandleAssistantToolUse(t *testing.T) {
 		}),
 	}

-	b.handleAssistant(msg, ch, &output)
+	b.handleAssistant(msg, ch, &output, make(map[string]TokenUsage)) // throwaway usage map for the new parameter

 	if output.String() != "" {
 		t.Fatalf("tool_use should not add to output, got %q", output.String())
@@ -162,7 +162,7 @@ func TestClaudeHandleAssistantInvalidJSON(t *testing.T) {
 	}

 	// Should not panic
-	b.handleAssistant(msg, ch, &output)
+	b.handleAssistant(msg, ch, &output, make(map[string]TokenUsage)) // throwaway usage map for the new parameter

 	if output.String() != "" {
 		t.Fatalf("expected empty output for invalid JSON, got %q", output.String())
diff --git a/server/pkg/db/generated/models.go b/server/pkg/db/generated/models.go
index 1aa34fe7..02b33f3e 100644
--- a/server/pkg/db/generated/models.go
+++ b/server/pkg/db/generated/models.go
@@ -281,6 +281,18 @@ type TaskMessage struct {
 	CreatedAt pgtype.Timestamptz `json:"created_at"`
 }

+type TaskUsage struct {
+	ID               pgtype.UUID        `json:"id"`
+	TaskID           pgtype.UUID        `json:"task_id"`
+	Provider         string             `json:"provider"`
+	Model            string             `json:"model"`
+	InputTokens      int64              `json:"input_tokens"`
+	OutputTokens     int64              `json:"output_tokens"`
+	CacheReadTokens  int64              `json:"cache_read_tokens"`
+	CacheWriteTokens int64              `json:"cache_write_tokens"`
+	CreatedAt        pgtype.Timestamptz `json:"created_at"`
+}
+
 type User struct {
 	ID   pgtype.UUID `json:"id"`
 	Name string      `json:"name"`
diff --git a/server/pkg/db/generated/task_usage.sql.go b/server/pkg/db/generated/task_usage.sql.go
new file mode 100644
index 00000000..16d3c0f3
--- /dev/null
+++ b/server/pkg/db/generated/task_usage.sql.go
@@ -0,0 +1,201 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+//   sqlc v1.30.0
+// source: task_usage.sql
+
+package db
+
+import (
+	"context"
+
+	"github.com/jackc/pgx/v5/pgtype"
+)
+
+const getTaskUsage = `-- name: GetTaskUsage :many
+SELECT id, task_id, provider, model, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, created_at FROM task_usage
+WHERE task_id = $1
+ORDER BY model
+`
+
+func (q *Queries) GetTaskUsage(ctx context.Context, taskID pgtype.UUID) ([]TaskUsage, error) {
+	rows, err := q.db.Query(ctx, getTaskUsage, taskID)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	items := []TaskUsage{}
+	for rows.Next() {
+		var i TaskUsage
+		if err := rows.Scan(
+			&i.ID,
+			&i.TaskID,
+			&i.Provider,
+			&i.Model,
+			&i.InputTokens,
+			&i.OutputTokens,
+			&i.CacheReadTokens,
+			&i.CacheWriteTokens,
+			&i.CreatedAt,
+		); err != nil {
+			return nil, err
+		}
+		items = append(items, i)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return items, nil
+}
+
+const getWorkspaceUsageByDay = `-- name: GetWorkspaceUsageByDay :many
+SELECT
+    DATE(atq.created_at) AS date,
+    tu.model,
+    SUM(tu.input_tokens)::bigint AS total_input_tokens,
+    SUM(tu.output_tokens)::bigint AS total_output_tokens,
+    SUM(tu.cache_read_tokens)::bigint AS total_cache_read_tokens,
+    SUM(tu.cache_write_tokens)::bigint AS total_cache_write_tokens,
+    COUNT(DISTINCT tu.task_id)::int AS task_count
+FROM task_usage tu
+JOIN agent_task_queue atq ON atq.id = tu.task_id
+JOIN agent a ON a.id = atq.agent_id
+WHERE a.workspace_id = $1
+  AND atq.created_at >= $2::timestamptz
+GROUP BY DATE(atq.created_at), tu.model
+ORDER BY DATE(atq.created_at) DESC, tu.model
+`
+
+type GetWorkspaceUsageByDayParams struct {
+	WorkspaceID pgtype.UUID        `json:"workspace_id"`
+	Since       pgtype.Timestamptz `json:"since"`
+}
+
+type GetWorkspaceUsageByDayRow struct {
+	Date                  pgtype.Date `json:"date"`
+	Model                 string      `json:"model"`
+	TotalInputTokens      int64       `json:"total_input_tokens"`
+	TotalOutputTokens     int64       `json:"total_output_tokens"`
+	TotalCacheReadTokens  int64       `json:"total_cache_read_tokens"`
+	TotalCacheWriteTokens int64       `json:"total_cache_write_tokens"`
+	TaskCount             int32       `json:"task_count"`
+}
+
+func (q *Queries) GetWorkspaceUsageByDay(ctx context.Context, arg GetWorkspaceUsageByDayParams) ([]GetWorkspaceUsageByDayRow, error) {
+	rows, err := q.db.Query(ctx, getWorkspaceUsageByDay, arg.WorkspaceID, arg.Since)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	items := []GetWorkspaceUsageByDayRow{}
+	for rows.Next() {
+		var i GetWorkspaceUsageByDayRow
+		if err := rows.Scan(
+			&i.Date,
+			&i.Model,
+			&i.TotalInputTokens,
+			&i.TotalOutputTokens,
+			&i.TotalCacheReadTokens,
+			&i.TotalCacheWriteTokens,
+			&i.TaskCount,
+		); err != nil {
+			return nil, err
+		}
+		items = append(items, i)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return items, nil
+}
+
+const getWorkspaceUsageSummary = `-- name: GetWorkspaceUsageSummary :many
+SELECT
+    tu.model,
+    SUM(tu.input_tokens)::bigint AS total_input_tokens,
+    SUM(tu.output_tokens)::bigint AS total_output_tokens,
+    SUM(tu.cache_read_tokens)::bigint AS total_cache_read_tokens,
+    SUM(tu.cache_write_tokens)::bigint AS total_cache_write_tokens,
+    COUNT(DISTINCT tu.task_id)::int AS task_count
+FROM task_usage tu
+JOIN agent_task_queue atq ON atq.id = tu.task_id
+JOIN agent a ON a.id = atq.agent_id
+WHERE a.workspace_id = $1
+  AND atq.created_at >= $2::timestamptz
+GROUP BY tu.model
+ORDER BY (SUM(tu.input_tokens) + SUM(tu.output_tokens)) DESC
+`
+
+type GetWorkspaceUsageSummaryParams struct {
+	WorkspaceID pgtype.UUID        `json:"workspace_id"`
+	Since       pgtype.Timestamptz `json:"since"`
+}
+
+type GetWorkspaceUsageSummaryRow struct {
+	Model                 string `json:"model"`
+	TotalInputTokens      int64  `json:"total_input_tokens"`
+	TotalOutputTokens     int64  `json:"total_output_tokens"`
+	TotalCacheReadTokens  int64  `json:"total_cache_read_tokens"`
+	TotalCacheWriteTokens int64  `json:"total_cache_write_tokens"`
+	TaskCount             int32  `json:"task_count"`
+}
+
+func (q *Queries) GetWorkspaceUsageSummary(ctx context.Context, arg GetWorkspaceUsageSummaryParams) ([]GetWorkspaceUsageSummaryRow, error) {
+	rows, err := q.db.Query(ctx, getWorkspaceUsageSummary, arg.WorkspaceID, arg.Since)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	items := []GetWorkspaceUsageSummaryRow{}
+	for rows.Next() {
+		var i GetWorkspaceUsageSummaryRow
+		if err := rows.Scan(
+			&i.Model,
+			&i.TotalInputTokens,
+			&i.TotalOutputTokens,
+			&i.TotalCacheReadTokens,
+			&i.TotalCacheWriteTokens,
+			&i.TaskCount,
+		); err != nil {
+			return nil, err
+		}
+		items = append(items, i)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return items, nil
+}
+
+const upsertTaskUsage = `-- name: UpsertTaskUsage :exec
+INSERT INTO task_usage (task_id, provider, model, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens)
+VALUES ($1, $2, $3, $4, $5, $6, $7)
+ON CONFLICT (task_id, provider, model)
+DO UPDATE SET
+    input_tokens = EXCLUDED.input_tokens,
+    output_tokens = EXCLUDED.output_tokens,
+    cache_read_tokens = EXCLUDED.cache_read_tokens,
+    cache_write_tokens = EXCLUDED.cache_write_tokens
+`
+
+type UpsertTaskUsageParams struct {
+	TaskID           pgtype.UUID `json:"task_id"`
+	Provider         string      `json:"provider"`
+	Model            string      `json:"model"`
+	InputTokens      int64       `json:"input_tokens"`
+	OutputTokens     int64       `json:"output_tokens"`
+	CacheReadTokens  int64       `json:"cache_read_tokens"`
+	CacheWriteTokens int64       `json:"cache_write_tokens"`
+}
+
+func (q *Queries) UpsertTaskUsage(ctx context.Context, arg UpsertTaskUsageParams) error {
+	_, err := q.db.Exec(ctx, upsertTaskUsage,
+		arg.TaskID,
+		arg.Provider,
+		arg.Model,
+		arg.InputTokens,
+		arg.OutputTokens,
+		arg.CacheReadTokens,
+		arg.CacheWriteTokens,
+	)
+	return err
+}
# NOTE(review): queries/task_usage.sql below is the sqlc input for generated/task_usage.sql.go above (its header records "source: task_usage.sql"). Change queries in the .sql file and regenerate; do not hand-edit the generated Go.
diff --git a/server/pkg/db/queries/task_usage.sql b/server/pkg/db/queries/task_usage.sql
new file mode 100644
index 00000000..569b131d
--- /dev/null
+++ b/server/pkg/db/queries/task_usage.sql
@@ -0,0 +1,47 @@
+-- name: UpsertTaskUsage :exec
+INSERT INTO task_usage (task_id, provider, model, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens)
+VALUES ($1, $2, $3, $4, $5, $6, $7)
+ON CONFLICT (task_id, provider, model)
+DO UPDATE SET
+    input_tokens = EXCLUDED.input_tokens,
+    output_tokens = EXCLUDED.output_tokens,
+    cache_read_tokens = EXCLUDED.cache_read_tokens,
+    cache_write_tokens = EXCLUDED.cache_write_tokens;
+
+-- name: GetTaskUsage :many
+SELECT * FROM task_usage
+WHERE task_id = $1
+ORDER BY model;
+
+-- name: GetWorkspaceUsageByDay :many
+SELECT
+    DATE(atq.created_at) AS date,
+    tu.model,
+    SUM(tu.input_tokens)::bigint AS total_input_tokens,
+    SUM(tu.output_tokens)::bigint AS total_output_tokens,
+    SUM(tu.cache_read_tokens)::bigint AS total_cache_read_tokens,
+    SUM(tu.cache_write_tokens)::bigint AS total_cache_write_tokens,
+    COUNT(DISTINCT tu.task_id)::int AS task_count
+FROM task_usage tu
+JOIN agent_task_queue atq ON atq.id = tu.task_id
+JOIN agent a ON a.id = atq.agent_id
+WHERE a.workspace_id = $1
+  AND atq.created_at >= @since::timestamptz
+GROUP BY DATE(atq.created_at), tu.model
+ORDER BY DATE(atq.created_at) DESC, tu.model;
+
+-- name: GetWorkspaceUsageSummary :many
+SELECT
+    tu.model,
+    SUM(tu.input_tokens)::bigint AS total_input_tokens,
+    SUM(tu.output_tokens)::bigint AS total_output_tokens,
+    SUM(tu.cache_read_tokens)::bigint AS total_cache_read_tokens,
+    SUM(tu.cache_write_tokens)::bigint AS total_cache_write_tokens,
+    COUNT(DISTINCT tu.task_id)::int AS task_count
+FROM task_usage tu
+JOIN agent_task_queue atq ON atq.id = tu.task_id
+JOIN agent a ON a.id = atq.agent_id
+WHERE a.workspace_id = $1
+  AND atq.created_at >= @since::timestamptz
+GROUP BY tu.model
+ORDER BY (SUM(tu.input_tokens) + SUM(tu.output_tokens)) DESC;