diff --git a/apps/web/features/issues/components/agent-live-card.tsx b/apps/web/features/issues/components/agent-live-card.tsx index 4c9e0b15..6917fc0f 100644 --- a/apps/web/features/issues/components/agent-live-card.tsx +++ b/apps/web/features/issues/components/agent-live-card.tsx @@ -1,10 +1,10 @@ "use client"; import { useState, useEffect, useCallback, useRef } from "react"; -import { Bot, ChevronRight, Loader2, ArrowDown, Brain, AlertCircle, Clock, CheckCircle2, XCircle } from "lucide-react"; +import { Bot, ChevronRight, Loader2, ArrowDown, Brain, AlertCircle, Clock, CheckCircle2, XCircle, Square } from "lucide-react"; import { api } from "@/shared/api"; import { useWSEvent } from "@/features/realtime"; -import type { TaskMessagePayload, TaskCompletedPayload, TaskFailedPayload } from "@/shared/types/events"; +import type { TaskMessagePayload, TaskCompletedPayload, TaskFailedPayload, TaskCancelledPayload } from "@/shared/types/events"; import type { AgentTask } from "@/shared/types/agent"; import { cn } from "@/lib/utils"; import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; @@ -106,6 +106,7 @@ export function AgentLiveCard({ issueId, agentName }: AgentLiveCardProps) { const [items, setItems] = useState([]); const [elapsed, setElapsed] = useState(""); const [autoScroll, setAutoScroll] = useState(true); + const [cancelling, setCancelling] = useState(false); const scrollRef = useRef(null); const seenSeqs = useRef(new Set()); @@ -165,6 +166,7 @@ export function AgentLiveCard({ issueId, agentName }: AgentLiveCardProps) { setActiveTask(null); setItems([]); seenSeqs.current.clear(); + setCancelling(false); }, [issueId]), ); @@ -176,6 +178,19 @@ export function AgentLiveCard({ issueId, agentName }: AgentLiveCardProps) { setActiveTask(null); setItems([]); seenSeqs.current.clear(); + setCancelling(false); + }, [issueId]), + ); + + useWSEvent( + "task:cancelled", + useCallback((payload: unknown) => { + const p = payload as TaskCancelledPayload; + if (p.issue_id !== issueId) return; + setActiveTask(null); + setItems([]); + seenSeqs.current.clear(); + setCancelling(false); }, [issueId]), ); @@ -215,6 +230,16 @@ export function AgentLiveCard({ issueId, agentName }: AgentLiveCardProps) { setAutoScroll(scrollHeight - scrollTop - clientHeight < 40); }, []); + const handleCancel = useCallback(async () => { + if (!activeTask || cancelling) return; + setCancelling(true); + try { + await api.cancelTask(issueId, activeTask.id); + } catch { + setCancelling(false); + } + }, [activeTask, issueId, cancelling]); + if (!activeTask) return null; const toolCount = items.filter((i) => i.type === "tool_use").length; @@ -236,6 +261,19 @@ export function AgentLiveCard({ issueId, agentName }: AgentLiveCardProps) { {toolCount} tool {toolCount === 1 ? "call" : "calls"} )} + {/* Timeline content */} @@ -302,7 +340,17 @@ export function TaskRunHistory({ issueId }: TaskRunHistoryProps) { }, [issueId]), ); - const completedTasks = tasks.filter((t) => t.status === "completed" || t.status === "failed"); + // Refresh when a task is cancelled + useWSEvent( + "task:cancelled", + useCallback((payload: unknown) => { + const p = payload as TaskCancelledPayload; + if (p.issue_id !== issueId) return; + api.listTasksByIssue(issueId).then(setTasks).catch(() => {}); + }, [issueId]), + ); + + const completedTasks = tasks.filter((t) => t.status === "completed" || t.status === "failed" || t.status === "cancelled"); if (completedTasks.length === 0) return null; return ( diff --git a/apps/web/next-env.d.ts b/apps/web/next-env.d.ts index 9edff1c7..c4b7818f 100644 --- a/apps/web/next-env.d.ts +++ b/apps/web/next-env.d.ts @@ -1,6 +1,6 @@ /// /// -import "./.next/types/routes.d.ts"; +import "./.next/dev/types/routes.d.ts"; // NOTE: This file should not be edited // see https://nextjs.org/docs/app/api-reference/config/typescript for more information. diff --git a/apps/web/shared/api/client.ts b/apps/web/shared/api/client.ts index 8c778806..fb309433 100644 --- a/apps/web/shared/api/client.ts +++ b/apps/web/shared/api/client.ts @@ -360,6 +360,12 @@ export class ApiClient { return this.fetch(`/api/issues/${issueId}/task-runs`); } + async cancelTask(issueId: string, taskId: string): Promise { + return this.fetch(`/api/issues/${issueId}/tasks/${taskId}/cancel`, { + method: "POST", + }); + } + // Inbox async listInbox(): Promise { return this.fetch("/api/inbox"); diff --git a/apps/web/shared/types/events.ts b/apps/web/shared/types/events.ts index 304a28ee..8b3a5fc6 100644 --- a/apps/web/shared/types/events.ts +++ b/apps/web/shared/types/events.ts @@ -21,6 +21,7 @@ export type WSEventType = | "task:completed" | "task:failed" | "task:message" + | "task:cancelled" | "inbox:new" | "inbox:read" | "inbox:archived" @@ -179,6 +180,13 @@ export interface TaskFailedPayload { status: string; } +export interface TaskCancelledPayload { + task_id: string; + agent_id: string; + issue_id: string; + status: string; +} + export interface ReactionAddedPayload { reaction: Reaction; issue_id: string; diff --git a/server/cmd/server/router.go b/server/cmd/server/router.go index 400d3c40..a792b381 100644 --- a/server/cmd/server/router.go +++ b/server/cmd/server/router.go @@ -169,6 +169,7 @@ func NewRouter(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus) chi.Route r.Post("/subscribe", h.SubscribeToIssue) r.Post("/unsubscribe", h.UnsubscribeFromIssue) r.Get("/active-task", h.GetActiveTaskForIssue) + r.Post("/tasks/{taskId}/cancel", h.CancelTask) r.Get("/task-runs", h.ListTasksByIssue) r.Post("/reactions", h.AddIssueReaction) r.Delete("/reactions", h.RemoveIssueReaction) diff --git a/server/internal/daemon/daemon.go b/server/internal/daemon/daemon.go index 13f0c5da..6ab5ac93 100644 --- a/server/internal/daemon/daemon.go +++ b/server/internal/daemon/daemon.go @@ -682,7 +682,41 @@ func (d *Daemon) handleTask(ctx context.Context, task Task) { _ = d.client.ReportProgress(ctx, task.ID, fmt.Sprintf("Launching %s", provider), 1, 2) - result, err := d.runTask(ctx, task, provider, taskLog) + // Create a cancellable context so we can interrupt the running agent + // when the server-side task status changes to 'cancelled'. + runCtx, runCancel := context.WithCancel(ctx) + defer runCancel() + + // Poll for cancellation every 5 seconds while the task is running. + cancelledByPoll := make(chan struct{}) + go func() { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + for { + select { + case <-runCtx.Done(): + return + case <-ticker.C: + if status, err := d.client.GetTaskStatus(ctx, task.ID); err == nil && status == "cancelled" { + taskLog.Info("task cancelled by server, interrupting agent") + runCancel() + close(cancelledByPoll) + return + } + } + } + }() + + result, err := d.runTask(runCtx, task, provider, taskLog) + + // Check if we were cancelled by the polling goroutine. + select { + case <-cancelledByPoll: + taskLog.Info("task cancelled during execution, discarding result") + return + default: + } + if err != nil { taskLog.Error("task failed", "error", err) if failErr := d.client.FailTask(ctx, task.ID, err.Error()); failErr != nil { diff --git a/server/internal/handler/daemon.go b/server/internal/handler/daemon.go index d08f080d..d0051766 100644 --- a/server/internal/handler/daemon.go +++ b/server/internal/handler/daemon.go @@ -525,6 +525,21 @@ func (h *Handler) GetActiveTaskForIssue(w http.ResponseWriter, r *http.Request) writeJSON(w, http.StatusOK, map[string]any{"task": taskToResponse(tasks[0])}) } +// CancelTask cancels a running or queued task by ID. +func (h *Handler) CancelTask(w http.ResponseWriter, r *http.Request) { + taskID := chi.URLParam(r, "taskId") + + task, err := h.TaskService.CancelTask(r.Context(), parseUUID(taskID)) + if err != nil { + slog.Warn("cancel task failed", "task_id", taskID, "error", err) + writeError(w, http.StatusBadRequest, err.Error()) + return + } + + slog.Info("task cancelled by user", "task_id", taskID, "issue_id", uuidToString(task.IssueID)) + writeJSON(w, http.StatusOK, taskToResponse(*task)) +} + // ListTasksByIssue returns all tasks (any status) for an issue — used for execution history. func (h *Handler) ListTasksByIssue(w http.ResponseWriter, r *http.Request) { issueID := chi.URLParam(r, "id") diff --git a/server/internal/service/task.go b/server/internal/service/task.go index b09d56d3..f46d0e14 100644 --- a/server/internal/service/task.go +++ b/server/internal/service/task.go @@ -104,6 +104,25 @@ func (s *TaskService) CancelTasksForIssue(ctx context.Context, issueID pgtype.UU return s.Queries.CancelAgentTasksByIssue(ctx, issueID) } +// CancelTask cancels a single task by ID. It broadcasts a task:cancelled event +// so frontends can update immediately. +func (s *TaskService) CancelTask(ctx context.Context, taskID pgtype.UUID) (*db.AgentTaskQueue, error) { + task, err := s.Queries.CancelAgentTask(ctx, taskID) + if err != nil { + return nil, fmt.Errorf("cancel task: %w", err) + } + + slog.Info("task cancelled", "task_id", util.UUIDToString(task.ID), "issue_id", util.UUIDToString(task.IssueID)) + + // Reconcile agent status + s.ReconcileAgentStatus(ctx, task.AgentID) + + // Broadcast cancellation as a task:failed event so frontends clear the live card + s.broadcastTaskEvent(ctx, protocol.EventTaskCancelled, task) + + return &task, nil +} + // ClaimTask atomically claims the next queued task for an agent, // respecting max_concurrent_tasks. func (s *TaskService) ClaimTask(ctx context.Context, agentID pgtype.UUID) (*db.AgentTaskQueue, error) { diff --git a/server/pkg/db/generated/agent.sql.go b/server/pkg/db/generated/agent.sql.go index a951d44e..befca00e 100644 --- a/server/pkg/db/generated/agent.sql.go +++ b/server/pkg/db/generated/agent.sql.go @@ -11,6 +11,37 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +const cancelAgentTask = `-- name: CancelAgentTask :one +UPDATE agent_task_queue +SET status = 'cancelled', completed_at = now() +WHERE id = $1 AND status IN ('queued', 'dispatched', 'running') +RETURNING id, agent_id, issue_id, status, priority, dispatched_at, started_at, completed_at, result, error, created_at, context, runtime_id, session_id, work_dir, trigger_comment_id +` + +func (q *Queries) CancelAgentTask(ctx context.Context, id pgtype.UUID) (AgentTaskQueue, error) { + row := q.db.QueryRow(ctx, cancelAgentTask, id) + var i AgentTaskQueue + err := row.Scan( + &i.ID, + &i.AgentID, + &i.IssueID, + &i.Status, + &i.Priority, + &i.DispatchedAt, + &i.StartedAt, + &i.CompletedAt, + &i.Result, + &i.Error, + &i.CreatedAt, + &i.Context, + &i.RuntimeID, + &i.SessionID, + &i.WorkDir, + &i.TriggerCommentID, + ) + return i, err +} + const cancelAgentTasksByIssue = `-- name: CancelAgentTasksByIssue :exec UPDATE agent_task_queue SET status = 'cancelled' diff --git a/server/pkg/db/queries/agent.sql b/server/pkg/db/queries/agent.sql index 2b581204..4511200b 100644 --- a/server/pkg/db/queries/agent.sql +++ b/server/pkg/db/queries/agent.sql @@ -107,6 +107,12 @@ WHERE (status = 'dispatched' AND dispatched_at < now() - make_interval(secs => @ OR (status = 'running' AND started_at < now() - make_interval(secs => @running_timeout_secs::double precision)) RETURNING id, agent_id, issue_id; +-- name: CancelAgentTask :one +UPDATE agent_task_queue +SET status = 'cancelled', completed_at = now() +WHERE id = $1 AND status IN ('queued', 'dispatched', 'running') +RETURNING *; + -- name: CountRunningTasks :one SELECT count(*) FROM agent_task_queue WHERE agent_id = $1 AND status IN ('dispatched', 'running'); diff --git a/server/pkg/protocol/events.go b/server/pkg/protocol/events.go index 6d9ec027..7a8636ec 100644 --- a/server/pkg/protocol/events.go +++ b/server/pkg/protocol/events.go @@ -27,6 +27,7 @@ const ( EventTaskCompleted = "task:completed" EventTaskFailed = "task:failed" EventTaskMessage = "task:message" + EventTaskCancelled = "task:cancelled" // Inbox events EventInboxNew = "inbox:new"