diff --git a/server/cmd/multica/cmd_daemon.go b/server/cmd/multica/cmd_daemon.go index ca3939ab..69fdfaf5 100644 --- a/server/cmd/multica/cmd_daemon.go +++ b/server/cmd/multica/cmd_daemon.go @@ -61,6 +61,7 @@ func init() { f.Duration("poll-interval", 0, "Task poll interval (env: MULTICA_DAEMON_POLL_INTERVAL)") f.Duration("heartbeat-interval", 0, "Heartbeat interval (env: MULTICA_DAEMON_HEARTBEAT_INTERVAL)") f.Duration("agent-timeout", 0, "Per-task timeout (env: MULTICA_AGENT_TIMEOUT)") + f.Int("max-concurrent-tasks", 0, "Max tasks running in parallel (env: MULTICA_DAEMON_MAX_CONCURRENT_TASKS)") daemonLogsCmd.Flags().BoolP("follow", "f", false, "Follow log output") daemonLogsCmd.Flags().IntP("lines", "n", 50, "Number of lines to show") @@ -187,6 +188,9 @@ func buildDaemonStartArgs(cmd *cobra.Command) []string { if d, _ := cmd.Flags().GetDuration("agent-timeout"); d > 0 { args = append(args, "--agent-timeout", d.String()) } + if n, _ := cmd.Flags().GetInt("max-concurrent-tasks"); n > 0 { + args = append(args, "--max-concurrent-tasks", strconv.Itoa(n)) + } // Forward global persistent flags. if v, _ := cmd.Flags().GetString("server-url"); v != "" { @@ -212,6 +216,9 @@ func runDaemonForeground(cmd *cobra.Command) error { if d, _ := cmd.Flags().GetDuration("agent-timeout"); d > 0 { overrides.AgentTimeout = d } + if n, _ := cmd.Flags().GetInt("max-concurrent-tasks"); n > 0 { + overrides.MaxConcurrentTasks = n + } cfg, err := daemon.LoadConfig(overrides) if err != nil { diff --git a/server/cmd/server/runtime_sweeper.go b/server/cmd/server/runtime_sweeper.go index f3b637a6..2e58c0bf 100644 --- a/server/cmd/server/runtime_sweeper.go +++ b/server/cmd/server/runtime_sweeper.go @@ -12,16 +12,24 @@ import ( ) const ( - // sweepInterval is how often we check for stale runtimes. + // sweepInterval is how often we check for stale runtimes and tasks. sweepInterval = 30 * time.Second // staleThresholdSeconds marks runtimes offline if no heartbeat for this long. // The daemon heartbeat interval is 15s, so 45s = 3 missed heartbeats. staleThresholdSeconds = 45.0 + // dispatchTimeoutSeconds fails tasks stuck in 'dispatched' beyond this. + // The dispatched→running transition should be near-instant, so 5 minutes + // means something went wrong (e.g. StartTask API call failed silently). + dispatchTimeoutSeconds = 300.0 + // runningTimeoutSeconds fails tasks stuck in 'running' beyond this. + // The default agent timeout is 2h, so 2.5h gives a generous buffer. + runningTimeoutSeconds = 9000.0 ) // runRuntimeSweeper periodically marks runtimes as offline if their -// last_seen_at exceeds the stale threshold. This handles cases where the -// daemon crashes or is killed without calling the deregister endpoint. +// last_seen_at exceeds the stale threshold, and fails orphaned tasks. +// This handles cases where the daemon crashes, is killed without calling +// the deregister endpoint, or leaves tasks in a non-terminal state. func runRuntimeSweeper(ctx context.Context, queries *db.Queries, bus *events.Bus) { ticker := time.NewTicker(sweepInterval) defer ticker.Stop() @@ -31,55 +39,95 @@ func runRuntimeSweeper(ctx context.Context, queries *db.Queries, bus *events.Bus case <-ctx.Done(): return case <-ticker.C: - staleRows, err := queries.MarkStaleRuntimesOffline(ctx, staleThresholdSeconds) - if err != nil { - slog.Warn("runtime sweeper: failed to mark stale runtimes offline", "error", err) - continue - } - if len(staleRows) == 0 { - continue - } - - // Collect unique workspace IDs to notify. - workspaces := make(map[string]bool) - for _, row := range staleRows { - wsID := util.UUIDToString(row.WorkspaceID) - workspaces[wsID] = true - } - - slog.Info("runtime sweeper: marked stale runtimes offline", "count", len(staleRows), "workspaces", len(workspaces)) - - // Fail orphaned tasks (dispatched/running) whose runtimes just went offline. - failedTasks, err := queries.FailTasksForOfflineRuntimes(ctx) - if err != nil { - slog.Warn("runtime sweeper: failed to clean up stale tasks", "error", err) - } else if len(failedTasks) > 0 { - slog.Info("runtime sweeper: failed orphaned tasks", "count", len(failedTasks)) - for _, ft := range failedTasks { - bus.Publish(events.Event{ - Type: protocol.EventTaskFailed, - ActorType: "system", - Payload: map[string]any{ - "task_id": util.UUIDToString(ft.ID), - "agent_id": util.UUIDToString(ft.AgentID), - "issue_id": util.UUIDToString(ft.IssueID), - "status": "failed", - }, - }) - } - } - - // Notify frontend clients so they re-fetch runtime list. - for wsID := range workspaces { - bus.Publish(events.Event{ - Type: protocol.EventDaemonRegister, - WorkspaceID: wsID, - ActorType: "system", - Payload: map[string]any{ - "action": "stale_sweep", - }, - }) - } + sweepStaleRuntimes(ctx, queries, bus) + sweepStaleTasks(ctx, queries, bus) } } } + +// sweepStaleRuntimes marks runtimes offline if they haven't heartbeated, +// then fails any tasks belonging to those offline runtimes. +func sweepStaleRuntimes(ctx context.Context, queries *db.Queries, bus *events.Bus) { + staleRows, err := queries.MarkStaleRuntimesOffline(ctx, staleThresholdSeconds) + if err != nil { + slog.Warn("runtime sweeper: failed to mark stale runtimes offline", "error", err) + return + } + if len(staleRows) == 0 { + return + } + + // Collect unique workspace IDs to notify. + workspaces := make(map[string]bool) + for _, row := range staleRows { + wsID := util.UUIDToString(row.WorkspaceID) + workspaces[wsID] = true + } + + slog.Info("runtime sweeper: marked stale runtimes offline", "count", len(staleRows), "workspaces", len(workspaces)) + + // Fail orphaned tasks (dispatched/running) whose runtimes just went offline. + failedTasks, err := queries.FailTasksForOfflineRuntimes(ctx) + if err != nil { + slog.Warn("runtime sweeper: failed to clean up stale tasks", "error", err) + } else if len(failedTasks) > 0 { + slog.Info("runtime sweeper: failed orphaned tasks", "count", len(failedTasks)) + for _, ft := range failedTasks { + bus.Publish(events.Event{ + Type: protocol.EventTaskFailed, + ActorType: "system", + Payload: map[string]any{ + "task_id": util.UUIDToString(ft.ID), + "agent_id": util.UUIDToString(ft.AgentID), + "issue_id": util.UUIDToString(ft.IssueID), + "status": "failed", + }, + }) + } + } + + // Notify frontend clients so they re-fetch runtime list. + for wsID := range workspaces { + bus.Publish(events.Event{ + Type: protocol.EventDaemonRegister, + WorkspaceID: wsID, + ActorType: "system", + Payload: map[string]any{ + "action": "stale_sweep", + }, + }) + } +} + +// sweepStaleTasks fails tasks stuck in dispatched/running for too long, +// even when the runtime is still online. This handles cases where: +// - The agent process hangs and the daemon is still heartbeating +// - The daemon failed to report task completion/failure +// - A server restart left tasks in a non-terminal state +func sweepStaleTasks(ctx context.Context, queries *db.Queries, bus *events.Bus) { + failedTasks, err := queries.FailStaleTasks(ctx, db.FailStaleTasksParams{ + DispatchTimeoutSecs: dispatchTimeoutSeconds, + RunningTimeoutSecs: runningTimeoutSeconds, + }) + if err != nil { + slog.Warn("task sweeper: failed to clean up stale tasks", "error", err) + return + } + if len(failedTasks) == 0 { + return + } + + slog.Info("task sweeper: failed stale tasks", "count", len(failedTasks)) + for _, ft := range failedTasks { + bus.Publish(events.Event{ + Type: protocol.EventTaskFailed, + ActorType: "system", + Payload: map[string]any{ + "task_id": util.UUIDToString(ft.ID), + "agent_id": util.UUIDToString(ft.AgentID), + "issue_id": util.UUIDToString(ft.IssueID), + "status": "failed", + }, + }) + } +} diff --git a/server/internal/daemon/config.go b/server/internal/daemon/config.go index 37d444b1..f9e942ad 100644 --- a/server/internal/daemon/config.go +++ b/server/internal/daemon/config.go @@ -18,34 +18,37 @@ const ( DefaultRuntimeName = "Local Agent" DefaultConfigReloadInterval = 5 * time.Second DefaultHealthPort = 19514 + DefaultMaxConcurrentTasks = 20 ) // Config holds all daemon configuration. type Config struct { - ServerBaseURL string - DaemonID string - DeviceName string - RuntimeName string - Agents map[string]AgentEntry // "claude" -> entry, "codex" -> entry - WorkspacesRoot string // base path for execution envs (default: ~/multica_workspaces) - KeepEnvAfterTask bool // preserve env after task for debugging - HealthPort int // local HTTP port for health checks (default: 19514) - PollInterval time.Duration - HeartbeatInterval time.Duration - AgentTimeout time.Duration + ServerBaseURL string + DaemonID string + DeviceName string + RuntimeName string + Agents map[string]AgentEntry // "claude" -> entry, "codex" -> entry + WorkspacesRoot string // base path for execution envs (default: ~/multica_workspaces) + KeepEnvAfterTask bool // preserve env after task for debugging + HealthPort int // local HTTP port for health checks (default: 19514) + MaxConcurrentTasks int // max tasks running in parallel (default: 20) + PollInterval time.Duration + HeartbeatInterval time.Duration + AgentTimeout time.Duration } // Overrides allows CLI flags to override environment variables and defaults. // Zero values are ignored and the env/default value is used instead. type Overrides struct { - ServerURL string - WorkspacesRoot string - PollInterval time.Duration - HeartbeatInterval time.Duration - AgentTimeout time.Duration - DaemonID string - DeviceName string - RuntimeName string + ServerURL string + WorkspacesRoot string + PollInterval time.Duration + HeartbeatInterval time.Duration + AgentTimeout time.Duration + MaxConcurrentTasks int + DaemonID string + DeviceName string + RuntimeName string } // LoadConfig builds the daemon configuration from environment variables @@ -112,6 +115,14 @@ func LoadConfig(overrides Overrides) (Config, error) { agentTimeout = overrides.AgentTimeout } + maxConcurrentTasks, err := intFromEnv("MULTICA_DAEMON_MAX_CONCURRENT_TASKS", DefaultMaxConcurrentTasks) + if err != nil { + return Config{}, err + } + if overrides.MaxConcurrentTasks > 0 { + maxConcurrentTasks = overrides.MaxConcurrentTasks + } + // String overrides daemonID := envOrDefault("MULTICA_DAEMON_ID", host) if overrides.DaemonID != "" { @@ -149,17 +160,18 @@ func LoadConfig(overrides Overrides) (Config, error) { keepEnv := os.Getenv("MULTICA_KEEP_ENV_AFTER_TASK") == "true" || os.Getenv("MULTICA_KEEP_ENV_AFTER_TASK") == "1" return Config{ - ServerBaseURL: serverBaseURL, - DaemonID: daemonID, - DeviceName: deviceName, - RuntimeName: runtimeName, - Agents: agents, - WorkspacesRoot: workspacesRoot, - KeepEnvAfterTask: keepEnv, - HealthPort: DefaultHealthPort, - PollInterval: pollInterval, - HeartbeatInterval: heartbeatInterval, - AgentTimeout: agentTimeout, + ServerBaseURL: serverBaseURL, + DaemonID: daemonID, + DeviceName: deviceName, + RuntimeName: runtimeName, + Agents: agents, + WorkspacesRoot: workspacesRoot, + KeepEnvAfterTask: keepEnv, + HealthPort: DefaultHealthPort, + MaxConcurrentTasks: maxConcurrentTasks, + PollInterval: pollInterval, + HeartbeatInterval: heartbeatInterval, + AgentTimeout: agentTimeout, }, nil } diff --git a/server/internal/daemon/daemon.go b/server/internal/daemon/daemon.go index e7b71fc6..90bf863b 100644 --- a/server/internal/daemon/daemon.go +++ b/server/internal/daemon/daemon.go @@ -503,11 +503,22 @@ func (d *Daemon) usageScanLoop(ctx context.Context) { } func (d *Daemon) pollLoop(ctx context.Context) error { + sem := make(chan struct{}, d.cfg.MaxConcurrentTasks) + var wg sync.WaitGroup + pollOffset := 0 pollCount := 0 for { select { case <-ctx.Done(): + d.logger.Info("poll loop stopping, waiting for in-flight tasks", "max_wait", "30s") + waitDone := make(chan struct{}) + go func() { wg.Wait(); close(waitDone) }() + select { + case <-waitDone: + case <-time.After(30 * time.Second): + d.logger.Warn("timed out waiting for in-flight tasks") + } return ctx.Err() default: } @@ -515,6 +526,7 @@ func (d *Daemon) pollLoop(ctx context.Context) error { runtimeIDs := d.allRuntimeIDs() if len(runtimeIDs) == 0 { if err := sleepWithContext(ctx, d.cfg.PollInterval); err != nil { + wg.Wait() return err } continue @@ -523,21 +535,40 @@ func (d *Daemon) pollLoop(ctx context.Context) error { claimed := false n := len(runtimeIDs) for i := 0; i < n; i++ { + // Check if we have capacity before claiming. + select { + case sem <- struct{}{}: + // Acquired a slot. + default: + // All slots occupied, stop trying to claim. + d.logger.Debug("poll: at capacity", "running", d.cfg.MaxConcurrentTasks) + goto sleep + } + rid := runtimeIDs[(pollOffset+i)%n] task, err := d.client.ClaimTask(ctx, rid) if err != nil { + <-sem // Release the slot. d.logger.Warn("claim task failed", "runtime_id", rid, "error", err) continue } if task != nil { d.logger.Info("task received", "task_id", task.ID, "issue_id", task.IssueID) - d.handleTask(ctx, *task) + wg.Add(1) + go func(t Task) { + defer wg.Done() + defer func() { <-sem }() + d.handleTask(ctx, t) + }(*task) claimed = true pollOffset = (pollOffset + i + 1) % n break } + // No task for this runtime, release the slot and try next. + <-sem } + sleep: if !claimed { pollCount++ if pollCount%20 == 1 { @@ -545,6 +576,7 @@ func (d *Daemon) pollLoop(ctx context.Context) error { } pollOffset = (pollOffset + 1) % n if err := sleepWithContext(ctx, d.cfg.PollInterval); err != nil { + wg.Wait() return err } } else { @@ -562,6 +594,9 @@ func (d *Daemon) handleTask(ctx context.Context, task Task) { if err := d.client.StartTask(ctx, task.ID); err != nil { d.logger.Error("start task failed", "task_id", task.ID, "error", err) + if failErr := d.client.FailTask(ctx, task.ID, fmt.Sprintf("start task failed: %s", err.Error())); failErr != nil { + d.logger.Error("fail task after start error", "task_id", task.ID, "error", failErr) + } return } @@ -594,7 +629,10 @@ func (d *Daemon) handleTask(ctx context.Context, task Task) { default: d.logger.Info("task completed", "task_id", task.ID, "status", result.Status) if err := d.client.CompleteTask(ctx, task.ID, result.Comment, result.BranchName, result.SessionID, result.WorkDir); err != nil { - d.logger.Error("complete task failed", "task_id", task.ID, "error", err) + d.logger.Error("complete task failed, falling back to fail", "task_id", task.ID, "error", err) + if failErr := d.client.FailTask(ctx, task.ID, fmt.Sprintf("complete task failed: %s", err.Error())); failErr != nil { + d.logger.Error("fail task fallback also failed", "task_id", task.ID, "error", failErr) + } } } } diff --git a/server/internal/daemon/helpers.go b/server/internal/daemon/helpers.go index a7de9b9e..f3e20d77 100644 --- a/server/internal/daemon/helpers.go +++ b/server/internal/daemon/helpers.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "strconv" "strings" "time" ) @@ -28,6 +29,18 @@ func durationFromEnv(key string, fallback time.Duration) (time.Duration, error) return d, nil } +func intFromEnv(key string, fallback int) (int, error) { + value := strings.TrimSpace(os.Getenv(key)) + if value == "" { + return fallback, nil + } + n, err := strconv.Atoi(value) + if err != nil { + return 0, fmt.Errorf("%s: invalid integer %q: %w", key, value, err) + } + return n, nil +} + func sleepWithContext(ctx context.Context, d time.Duration) error { timer := time.NewTimer(d) defer timer.Stop() diff --git a/server/internal/handler/agent.go b/server/internal/handler/agent.go index 9b6f6f1f..d5303285 100644 --- a/server/internal/handler/agent.go +++ b/server/internal/handler/agent.go @@ -251,7 +251,7 @@ func (h *Handler) CreateAgent(w http.ResponseWriter, r *http.Request) { req.Visibility = "workspace" } if req.MaxConcurrentTasks == 0 { - req.MaxConcurrentTasks = 1 + req.MaxConcurrentTasks = 6 } runtime, err := h.Queries.GetAgentRuntimeForWorkspace(r.Context(), db.GetAgentRuntimeForWorkspaceParams{ diff --git a/server/internal/service/task.go b/server/internal/service/task.go index 9473dccb..d7da701b 100644 --- a/server/internal/service/task.go +++ b/server/internal/service/task.go @@ -209,7 +209,7 @@ func (s *TaskService) FailTask(ctx context.Context, taskID pgtype.UUID, errMsg s }) if err != nil { if existing, lookupErr := s.Queries.GetAgentTask(ctx, taskID); lookupErr == nil { - slog.Warn("fail task failed: task not in running state", + slog.Warn("fail task failed: task not in dispatched/running state", "task_id", util.UUIDToString(taskID), "current_status", existing.Status, "issue_id", util.UUIDToString(existing.IssueID), diff --git a/server/migrations/023_agent_concurrency_default.down.sql b/server/migrations/023_agent_concurrency_default.down.sql new file mode 100644 index 00000000..6b7b43b5 --- /dev/null +++ b/server/migrations/023_agent_concurrency_default.down.sql @@ -0,0 +1 @@ +ALTER TABLE agent ALTER COLUMN max_concurrent_tasks SET DEFAULT 1; diff --git a/server/migrations/023_agent_concurrency_default.up.sql b/server/migrations/023_agent_concurrency_default.up.sql new file mode 100644 index 00000000..1a9acd94 --- /dev/null +++ b/server/migrations/023_agent_concurrency_default.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE agent ALTER COLUMN max_concurrent_tasks SET DEFAULT 6; +UPDATE agent SET max_concurrent_tasks = 6 WHERE max_concurrent_tasks = 1; diff --git a/server/pkg/db/generated/agent.sql.go b/server/pkg/db/generated/agent.sql.go index d3d5edc8..0679cfdf 100644 --- a/server/pkg/db/generated/agent.sql.go +++ b/server/pkg/db/generated/agent.sql.go @@ -229,7 +229,7 @@ func (q *Queries) DeleteAgent(ctx context.Context, id pgtype.UUID) error { const failAgentTask = `-- name: FailAgentTask :one UPDATE agent_task_queue SET status = 'failed', completed_at = now(), error = $2 -WHERE id = $1 AND status = 'running' +WHERE id = $1 AND status IN ('dispatched', 'running') RETURNING id, agent_id, issue_id, status, priority, dispatched_at, started_at, completed_at, result, error, created_at, context, runtime_id, session_id, work_dir ` @@ -261,6 +261,48 @@ func (q *Queries) FailAgentTask(ctx context.Context, arg FailAgentTaskParams) (A return i, err } +const failStaleTasks = `-- name: FailStaleTasks :many +UPDATE agent_task_queue +SET status = 'failed', completed_at = now(), error = 'task timed out' +WHERE (status = 'dispatched' AND dispatched_at < now() - make_interval(secs => $1::double precision)) + OR (status = 'running' AND started_at < now() - make_interval(secs => $2::double precision)) +RETURNING id, agent_id, issue_id +` + +type FailStaleTasksParams struct { + DispatchTimeoutSecs float64 `json:"dispatch_timeout_secs"` + RunningTimeoutSecs float64 `json:"running_timeout_secs"` +} + +type FailStaleTasksRow struct { + ID pgtype.UUID `json:"id"` + AgentID pgtype.UUID `json:"agent_id"` + IssueID pgtype.UUID `json:"issue_id"` +} + +// Fails tasks stuck in dispatched/running beyond the given thresholds. +// Handles cases where the daemon is alive but the task is orphaned +// (e.g. agent process hung, daemon failed to report completion). +func (q *Queries) FailStaleTasks(ctx context.Context, arg FailStaleTasksParams) ([]FailStaleTasksRow, error) { + rows, err := q.db.Query(ctx, failStaleTasks, arg.DispatchTimeoutSecs, arg.RunningTimeoutSecs) + if err != nil { + return nil, err + } + defer rows.Close() + items := []FailStaleTasksRow{} + for rows.Next() { + var i FailStaleTasksRow + if err := rows.Scan(&i.ID, &i.AgentID, &i.IssueID); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getAgent = `-- name: GetAgent :one SELECT id, workspace_id, name, avatar_url, runtime_mode, runtime_config, visibility, status, max_concurrent_tasks, owner_id, created_at, updated_at, description, tools, triggers, runtime_id, instructions FROM agent WHERE id = $1 diff --git a/server/pkg/db/queries/agent.sql b/server/pkg/db/queries/agent.sql index 70542292..72084294 100644 --- a/server/pkg/db/queries/agent.sql +++ b/server/pkg/db/queries/agent.sql @@ -90,9 +90,19 @@ LIMIT 1; -- name: FailAgentTask :one UPDATE agent_task_queue SET status = 'failed', completed_at = now(), error = $2 -WHERE id = $1 AND status = 'running' +WHERE id = $1 AND status IN ('dispatched', 'running') RETURNING *; +-- name: FailStaleTasks :many +-- Fails tasks stuck in dispatched/running beyond the given thresholds. +-- Handles cases where the daemon is alive but the task is orphaned +-- (e.g. agent process hung, daemon failed to report completion). +UPDATE agent_task_queue +SET status = 'failed', completed_at = now(), error = 'task timed out' +WHERE (status = 'dispatched' AND dispatched_at < now() - make_interval(secs => @dispatch_timeout_secs::double precision)) + OR (status = 'running' AND started_at < now() - make_interval(secs => @running_timeout_secs::double precision)) +RETURNING id, agent_id, issue_id; + -- name: CountRunningTasks :one SELECT count(*) FROM agent_task_queue WHERE agent_id = $1 AND status IN ('dispatched', 'running');