diff --git a/server/cmd/server/router.go b/server/cmd/server/router.go index dc488e83..07b9d170 100644 --- a/server/cmd/server/router.go +++ b/server/cmd/server/router.go @@ -94,6 +94,7 @@ func NewRouter(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus) chi.Route r.Post("/runtimes/{runtimeId}/usage", h.ReportRuntimeUsage) r.Post("/runtimes/{runtimeId}/ping/{pingId}/result", h.ReportPingResult) + r.Get("/tasks/{taskId}/status", h.GetTaskStatus) r.Post("/tasks/{taskId}/start", h.StartTask) r.Post("/tasks/{taskId}/progress", h.ReportTaskProgress) r.Post("/tasks/{taskId}/complete", h.CompleteTask) diff --git a/server/cmd/server/runtime_sweeper.go b/server/cmd/server/runtime_sweeper.go index e1d47a53..f3b637a6 100644 --- a/server/cmd/server/runtime_sweeper.go +++ b/server/cmd/server/runtime_sweeper.go @@ -49,6 +49,26 @@ func runRuntimeSweeper(ctx context.Context, queries *db.Queries, bus *events.Bus slog.Info("runtime sweeper: marked stale runtimes offline", "count", len(staleRows), "workspaces", len(workspaces)) + // Fail orphaned tasks (dispatched/running) whose runtimes just went offline. + failedTasks, err := queries.FailTasksForOfflineRuntimes(ctx) + if err != nil { + slog.Warn("runtime sweeper: failed to clean up stale tasks", "error", err) + } else if len(failedTasks) > 0 { + slog.Info("runtime sweeper: failed orphaned tasks", "count", len(failedTasks)) + for _, ft := range failedTasks { + bus.Publish(events.Event{ + Type: protocol.EventTaskFailed, + ActorType: "system", + Payload: map[string]any{ + "task_id": util.UUIDToString(ft.ID), + "agent_id": util.UUIDToString(ft.AgentID), + "issue_id": util.UUIDToString(ft.IssueID), + "status": "failed", + }, + }) + } + } + // Notify frontend clients so they re-fetch runtime list. for wsID := range workspaces { bus.Publish(events.Event{ diff --git a/server/internal/daemon/client.go b/server/internal/daemon/client.go index e29a0977..865ae177 100644 --- a/server/internal/daemon/client.go +++ b/server/internal/daemon/client.go @@ -103,6 +103,18 @@ func (c *Client) FailTask(ctx context.Context, taskID, errMsg string) error { }, nil) } +// GetTaskStatus returns the current status of a task. Used by the daemon to +// detect if a task was cancelled while it was executing. +func (c *Client) GetTaskStatus(ctx context.Context, taskID string) (string, error) { + var resp struct { + Status string `json:"status"` + } + if err := c.getJSON(ctx, fmt.Sprintf("/api/daemon/tasks/%s/status", taskID), &resp); err != nil { + return "", err + } + return resp.Status, nil +} + func (c *Client) ReportUsage(ctx context.Context, runtimeID string, entries []map[string]any) error { return c.postJSON(ctx, fmt.Sprintf("/api/daemon/runtimes/%s/usage", runtimeID), map[string]any{ "entries": entries, diff --git a/server/internal/daemon/daemon.go b/server/internal/daemon/daemon.go index ad3b50a1..22b7aecb 100644 --- a/server/internal/daemon/daemon.go +++ b/server/internal/daemon/daemon.go @@ -557,6 +557,14 @@ func (d *Daemon) handleTask(ctx context.Context, task Task) { _ = d.client.ReportProgress(ctx, task.ID, "Finishing task", 2, 2) + // Check if the task was cancelled while it was running (e.g. issue + // was reassigned). If so, skip reporting results — the server already + // moved the task to 'cancelled' so complete/fail would fail anyway. + if status, err := d.client.GetTaskStatus(ctx, task.ID); err == nil && status == "cancelled" { + d.logger.Info("task was cancelled during execution, discarding result", "task_id", task.ID) + return + } + switch result.Status { case "blocked": if err := d.client.FailTask(ctx, task.ID, result.Comment); err != nil { diff --git a/server/internal/handler/daemon.go b/server/internal/handler/daemon.go index 392917b9..a00d998f 100644 --- a/server/internal/handler/daemon.go +++ b/server/internal/handler/daemon.go @@ -325,6 +325,18 @@ func (h *Handler) CompleteTask(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, taskToResponse(*task)) } +// GetTaskStatus returns the current status of a task. +// Used by the daemon to check whether a task was cancelled mid-execution. +func (h *Handler) GetTaskStatus(w http.ResponseWriter, r *http.Request) { + taskID := chi.URLParam(r, "taskId") + task, err := h.Queries.GetAgentTask(r.Context(), parseUUID(taskID)) + if err != nil { + writeError(w, http.StatusNotFound, "task not found") + return + } + writeJSON(w, http.StatusOK, map[string]string{"status": task.Status}) +} + // FailTask marks a running task as failed. type TaskFailRequest struct { Error string `json:"error"` diff --git a/server/internal/handler/issue.go b/server/internal/handler/issue.go index f6f1f736..21a99495 100644 --- a/server/internal/handler/issue.go +++ b/server/internal/handler/issue.go @@ -389,9 +389,11 @@ func (h *Handler) shouldEnqueueOnComment(ctx context.Context, issue db.Issue) bo if !h.isAgentTriggerEnabled(ctx, issue, "on_comment") { return false } - // Don't enqueue if there's already an active task for this issue. - hasActive, err := h.Queries.HasActiveTaskForIssue(ctx, issue.ID) - if err != nil || hasActive { + // Coalescing queue: allow enqueue when a task is running (so the agent + // picks up new comments on the next cycle) but skip if a pending task + // already exists (natural dedup for rapid-fire comments). + hasPending, err := h.Queries.HasPendingTaskForIssue(ctx, issue.ID) + if err != nil || hasPending { return false } return true diff --git a/server/internal/service/task.go b/server/internal/service/task.go index fae0ae32..7f67ca60 100644 --- a/server/internal/service/task.go +++ b/server/internal/service/task.go @@ -104,6 +104,8 @@ func (s *TaskService) ClaimTask(ctx context.Context, agentID pgtype.UUID) (*db.A // ClaimTaskForRuntime claims the next runnable task for a runtime while // still respecting each agent's max_concurrent_tasks limit. +// Tasks whose issues are in a terminal status (done/cancelled) are +// automatically cancelled and skipped. func (s *TaskService) ClaimTaskForRuntime(ctx context.Context, runtimeID pgtype.UUID) (*db.AgentTaskQueue, error) { tasks, err := s.Queries.ListPendingTasksByRuntime(ctx, runtimeID) if err != nil { @@ -112,6 +114,15 @@ func (s *TaskService) ClaimTaskForRuntime(ctx context.Context, runtimeID pgtype. triedAgents := map[string]struct{}{} for _, candidate := range tasks { + // Skip tasks whose issues have reached a terminal status. + if issue, err := s.Queries.GetIssue(ctx, candidate.IssueID); err == nil { + if issue.Status == "done" || issue.Status == "cancelled" { + slog.Info("skipping task for terminal issue", "task_id", util.UUIDToString(candidate.ID), "issue_status", issue.Status) + _ = s.Queries.CancelAgentTasksByIssue(ctx, candidate.IssueID) + continue + } + } + agentKey := util.UUIDToString(candidate.AgentID) if _, seen := triedAgents[agentKey]; seen { continue diff --git a/server/migrations/022_task_lifecycle_guards.down.sql b/server/migrations/022_task_lifecycle_guards.down.sql new file mode 100644 index 00000000..50a369cd --- /dev/null +++ b/server/migrations/022_task_lifecycle_guards.down.sql @@ -0,0 +1 @@ +DROP INDEX IF EXISTS idx_one_pending_task_per_issue; diff --git a/server/migrations/022_task_lifecycle_guards.up.sql b/server/migrations/022_task_lifecycle_guards.up.sql new file mode 100644 index 00000000..4778cc16 --- /dev/null +++ b/server/migrations/022_task_lifecycle_guards.up.sql @@ -0,0 +1,5 @@ +-- Prevent duplicate pending tasks for the same issue (coalescing queue safety net). +-- At most one queued/dispatched task per issue at any time. +CREATE UNIQUE INDEX idx_one_pending_task_per_issue + ON agent_task_queue (issue_id) + WHERE status IN ('queued', 'dispatched'); diff --git a/server/pkg/db/generated/agent.sql.go b/server/pkg/db/generated/agent.sql.go index bb420c06..d3d5edc8 100644 --- a/server/pkg/db/generated/agent.sql.go +++ b/server/pkg/db/generated/agent.sql.go @@ -358,6 +358,22 @@ func (q *Queries) HasActiveTaskForIssue(ctx context.Context, issueID pgtype.UUID return has_active, err } +const hasPendingTaskForIssue = `-- name: HasPendingTaskForIssue :one +SELECT count(*) > 0 AS has_pending FROM agent_task_queue +WHERE issue_id = $1 AND status IN ('queued', 'dispatched') +` + +// Returns true if there is a queued or dispatched (but not yet running) task for the issue. +// Used by the coalescing queue: allow enqueue when a task is running (so +// the agent picks up new comments on the next cycle) but skip if a pending +// task already exists (natural dedup). +func (q *Queries) HasPendingTaskForIssue(ctx context.Context, issueID pgtype.UUID) (bool, error) { + row := q.db.QueryRow(ctx, hasPendingTaskForIssue, issueID) + var has_pending bool + err := row.Scan(&has_pending) + return has_pending, err +} + const listAgentTasks = `-- name: ListAgentTasks :many SELECT id, agent_id, issue_id, status, priority, dispatched_at, started_at, completed_at, result, error, created_at, context, runtime_id, session_id, work_dir FROM agent_task_queue WHERE agent_id = $1 diff --git a/server/pkg/db/generated/runtime.sql.go b/server/pkg/db/generated/runtime.sql.go index d871d72d..162520ae 100644 --- a/server/pkg/db/generated/runtime.sql.go +++ b/server/pkg/db/generated/runtime.sql.go @@ -11,6 +11,44 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +const failTasksForOfflineRuntimes = `-- name: FailTasksForOfflineRuntimes :many +UPDATE agent_task_queue +SET status = 'failed', completed_at = now(), error = 'runtime went offline' +WHERE status IN ('dispatched', 'running') + AND runtime_id IN ( + SELECT id FROM agent_runtime WHERE status = 'offline' + ) +RETURNING id, agent_id, issue_id +` + +type FailTasksForOfflineRuntimesRow struct { + ID pgtype.UUID `json:"id"` + AgentID pgtype.UUID `json:"agent_id"` + IssueID pgtype.UUID `json:"issue_id"` +} + +// Marks dispatched/running tasks as failed when their runtime is offline. +// This cleans up orphaned tasks after a daemon crash or network partition. +func (q *Queries) FailTasksForOfflineRuntimes(ctx context.Context) ([]FailTasksForOfflineRuntimesRow, error) { + rows, err := q.db.Query(ctx, failTasksForOfflineRuntimes) + if err != nil { + return nil, err + } + defer rows.Close() + items := []FailTasksForOfflineRuntimesRow{} + for rows.Next() { + var i FailTasksForOfflineRuntimesRow + if err := rows.Scan(&i.ID, &i.AgentID, &i.IssueID); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getAgentRuntime = `-- name: GetAgentRuntime :one SELECT id, workspace_id, daemon_id, name, runtime_mode, provider, status, device_info, metadata, last_seen_at, created_at, updated_at FROM agent_runtime WHERE id = $1 diff --git a/server/pkg/db/queries/agent.sql b/server/pkg/db/queries/agent.sql index f72a353d..70542292 100644 --- a/server/pkg/db/queries/agent.sql +++ b/server/pkg/db/queries/agent.sql @@ -102,6 +102,14 @@ WHERE agent_id = $1 AND status IN ('dispatched', 'running'); SELECT count(*) > 0 AS has_active FROM agent_task_queue WHERE issue_id = $1 AND status IN ('queued', 'dispatched', 'running'); +-- name: HasPendingTaskForIssue :one +-- Returns true if there is a queued or dispatched (but not yet running) task for the issue. +-- Used by the coalescing queue: allow enqueue when a task is running (so +-- the agent picks up new comments on the next cycle) but skip if a pending +-- task already exists (natural dedup). +SELECT count(*) > 0 AS has_pending FROM agent_task_queue +WHERE issue_id = $1 AND status IN ('queued', 'dispatched'); + -- name: ListPendingTasksByRuntime :many SELECT * FROM agent_task_queue WHERE runtime_id = $1 AND status IN ('queued', 'dispatched') diff --git a/server/pkg/db/queries/runtime.sql b/server/pkg/db/queries/runtime.sql index 6aabb657..415fa63c 100644 --- a/server/pkg/db/queries/runtime.sql +++ b/server/pkg/db/queries/runtime.sql @@ -51,3 +51,14 @@ SET status = 'offline', updated_at = now() WHERE status = 'online' AND last_seen_at < now() - make_interval(secs => @stale_seconds::double precision) RETURNING id, workspace_id; + +-- name: FailTasksForOfflineRuntimes :many +-- Marks dispatched/running tasks as failed when their runtime is offline. +-- This cleans up orphaned tasks after a daemon crash or network partition. +UPDATE agent_task_queue +SET status = 'failed', completed_at = now(), error = 'runtime went offline' +WHERE status IN ('dispatched', 'running') + AND runtime_id IN ( + SELECT id FROM agent_runtime WHERE status = 'offline' + ) +RETURNING id, agent_id, issue_id;