Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions internal/types/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ import (
type MessageType string

const (
MessageTypeTaskAssignment MessageType = "task_assignment"
MessageTypeTaskClaimed MessageType = "task_claimed"
MessageTypeTaskCompleted MessageType = "task_completed"
MessageTypeTaskFailed MessageType = "task_failed"
MessageTypeTaskRejected MessageType = "task_rejected"
MessageTypeHeartbeat MessageType = "heartbeat"
MessageTypeTaskAssignment MessageType = "task_assignment"
MessageTypeTaskClaimed MessageType = "task_claimed"
MessageTypeTaskCompleted MessageType = "task_completed"
MessageTypeTaskFailed MessageType = "task_failed"
MessageTypeTaskRejected MessageType = "task_rejected"
MessageTypeTaskCancellation MessageType = "task_cancellation"
MessageTypeHeartbeat MessageType = "heartbeat"
)

// WebSocketMessage is the base structure for all WebSocket messages
Expand Down Expand Up @@ -51,14 +52,16 @@ type TaskClaimedMessage struct {

// TaskCompletedMessage tells the server to end the active run execution after a successful agent process exit.
type TaskCompletedMessage struct {
TaskID string `json:"task_id"`
Message string `json:"message"`
TaskID string `json:"task_id"`
Message string `json:"message"`
TaskState *TaskState `json:"task_state,omitempty"`
}

// TaskFailedMessage is sent from worker to server if task launch fails
type TaskFailedMessage struct {
TaskID string `json:"task_id"`
Message string `json:"message"`
TaskID string `json:"task_id"`
Message string `json:"message"`
TaskState *TaskState `json:"task_state,omitempty"`
}

// TaskRejectedMessage is sent from worker to server when the worker cannot accept the task
Expand All @@ -68,6 +71,18 @@ type TaskRejectedMessage struct {
Reason string `json:"reason"`
}

// TaskCancellationMessage is sent from server to worker to cancel an active task.
type TaskCancellationMessage struct {
TaskID string `json:"task_id"`
}

// TaskState is the serialized terminal task state accepted by warp-server.
type TaskState string

const (
TaskStateCancelled TaskState = "CANCELLED"
)

type TaskDefinition struct {
Prompt string `json:"prompt"`
}
Expand Down
6 changes: 6 additions & 0 deletions internal/worker/direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ func (b *DirectBackend) ExecuteTask(ctx context.Context, params *TaskParams) err

log.Infof(ctx, "Running setup command: %s", b.config.SetupCommand)
if err := b.runCommand(ctx, b.config.SetupCommand, workspaceDir, setupEnv); err != nil {
if ctx.Err() != nil {
return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonTaskCancelled, ctx.Err())
}
return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonSetupCommand, fmt.Errorf("setup command failed: %w", err))
}
}
Expand Down Expand Up @@ -183,6 +186,9 @@ func (b *DirectBackend) ExecuteTask(ctx context.Context, params *TaskParams) err
log.Debugf(ctx, "Command: %s %s", b.ozPath, strings.Join(params.BaseArgs, " "))

if err := cmd.Run(); err != nil {
if ctx.Err() != nil {
return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonTaskCancelled, ctx.Err())
}
return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonAgentInvocation, fmt.Errorf("oz agent exited with error: %w", err))
}

Expand Down
4 changes: 3 additions & 1 deletion internal/worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ func (b *DockerBackend) ExecuteTask(ctx context.Context, params *TaskParams) err

defer func() {
if containerID != "" && !b.config.NoCleanup {
if removeErr := dockerClient.ContainerRemove(ctx, containerID, container.RemoveOptions{Force: true}); removeErr != nil {
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), BackendShutdownTimeout)
defer cleanupCancel()
if removeErr := dockerClient.ContainerRemove(cleanupCtx, containerID, container.RemoveOptions{Force: true}); removeErr != nil {
log.Debugf(ctx, "Container %s already removed or removal failed: %v", containerID, removeErr)
}
}
Expand Down
85 changes: 80 additions & 5 deletions internal/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package worker
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/url"
"sync"
Expand Down Expand Up @@ -59,12 +60,17 @@ type Worker struct {
reconnectDelay time.Duration
lastHeartbeat time.Time
sendChan chan []byte
activeTasks map[string]context.CancelFunc
activeTasks map[string]activeTask
tasksMutex sync.Mutex
backend Backend
taskSemaphore *semaphore.Weighted // nil when unlimited
}

type activeTask struct {
ctx context.Context
cancel context.CancelFunc
}

func New(ctx context.Context, config Config) (*Worker, error) {
workerCtx, cancel := context.WithCancel(ctx)

Expand Down Expand Up @@ -109,7 +115,7 @@ func New(ctx context.Context, config Config) (*Worker, error) {
cancel: cancel,
reconnectDelay: InitialReconnectDelay,
sendChan: make(chan []byte, 256),
activeTasks: make(map[string]context.CancelFunc),
activeTasks: make(map[string]activeTask),
backend: backend,
taskSemaphore: taskSemaphore,
}, nil
Expand Down Expand Up @@ -322,11 +328,36 @@ func (w *Worker) handleMessage(message []byte) {
}
w.handleTaskAssignment(&assignment)

case types.MessageTypeTaskCancellation:
var cancellation types.TaskCancellationMessage
if err := json.Unmarshal(msg.Data, &cancellation); err != nil {
log.Errorf(w.ctx, "Failed to unmarshal task cancellation: %v", err)
return
}
w.handleTaskCancellation(&cancellation)

default:
log.Warnf(w.ctx, "Unknown message type: %s", msg.Type)
}
}

func (w *Worker) handleTaskCancellation(cancellation *types.TaskCancellationMessage) {
w.tasksMutex.Lock()
task, ok := w.activeTasks[cancellation.TaskID]
w.tasksMutex.Unlock()
if !ok {
log.Warnf(w.ctx, "Received cancellation for inactive task: taskID=%s", cancellation.TaskID)
return
}

log.Infof(w.ctx, "Cancelling task from server request: taskID=%s", cancellation.TaskID)
metrics.AddTaskEvent(task.ctx, "task.cancellation_requested",
attribute.String("source", "server"),
attribute.String("task.id", cancellation.TaskID),
)
task.cancel()
}

func (w *Worker) handleTaskAssignment(assignment *types.TaskAssignmentMessage) {
receivedAt := time.Now()
log.Infof(w.ctx, "Received task assignment: taskID=%s, title=%s", assignment.TaskID, assignment.Task.Title)
Expand Down Expand Up @@ -364,7 +395,10 @@ func (w *Worker) handleTaskAssignment(assignment *types.TaskAssignmentMessage) {
taskCtx, taskCancel := context.WithCancel(taskCtx)

w.tasksMutex.Lock()
w.activeTasks[assignment.TaskID] = taskCancel
w.activeTasks[assignment.TaskID] = activeTask{
ctx: taskCtx,
cancel: taskCancel,
}
w.tasksMutex.Unlock()
go w.executeTask(taskCtx, span, assignment, receivedAt)
}
Expand Down Expand Up @@ -489,6 +523,17 @@ func (w *Worker) executeTask(ctx context.Context, span trace.Span, assignment *t

err := w.backend.ExecuteTask(ctx, params)
if err != nil {
if errors.Is(err, context.Canceled) {
result = metrics.TaskResultCancelled
metrics.AddTaskEvent(ctx, "task.cancelled")
span.SetStatus(codes.Ok, "task cancelled")
log.Infof(ctx, "Task execution cancelled: taskID=%s", taskID)
if statusErr := w.sendTaskCancelled(taskID, "Task cancelled."); statusErr != nil {
log.Errorf(ctx, "Failed to send task cancelled message: %v", statusErr)
}
return
}

result = metrics.TaskResultFailed
phase, reason := taskFailureLabels(err)
metrics.RecordTaskFailure(phase, reason)
Expand Down Expand Up @@ -538,6 +583,32 @@ func (w *Worker) sendTaskClaimed(taskID string) error {
return w.sendMessage(msgBytes)
}

func (w *Worker) sendTaskCancelled(taskID, message string) error {
taskState := types.TaskStateCancelled
completedMsg := types.TaskCompletedMessage{
TaskID: taskID,
Message: message,
TaskState: &taskState,
}

data, err := json.Marshal(completedMsg)
if err != nil {
return fmt.Errorf("failed to marshal task cancelled message: %w", err)
}

msg := types.WebSocketMessage{
Type: types.MessageTypeTaskCompleted,
Data: data,
}

msgBytes, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("failed to marshal websocket message: %w", err)
}

return w.sendMessage(msgBytes)
}

func (w *Worker) sendTaskRejected(taskID, reason string) error {
rejectedMsg := types.TaskRejectedMessage{
TaskID: taskID,
Expand Down Expand Up @@ -628,9 +699,13 @@ func (w *Worker) Shutdown() {
activeTaskCount := len(w.activeTasks)
if activeTaskCount > 0 {
log.Infof(w.ctx, "Cancelling %d active tasks", activeTaskCount)
for taskID, cancel := range w.activeTasks {
for taskID, task := range w.activeTasks {
log.Debugf(w.ctx, "Cancelling task: %s", taskID)
cancel()
metrics.AddTaskEvent(task.ctx, "task.cancellation_requested",
attribute.String("source", "signal"),
attribute.String("task.id", taskID),
)
task.cancel()
}
}
w.tasksMutex.Unlock()
Expand Down
74 changes: 71 additions & 3 deletions internal/worker/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,80 @@
}
}

func TestExecuteTaskReportsTaskCancelledOnContextCancellation(t *testing.T) {
w := &Worker{
ctx: context.Background(),
config: Config{},
sendChan: make(chan []byte, 1),
activeTasks: map[string]activeTask{"task-1": {cancel: func() {}}},
backend: &recordingBackend{err: context.Canceled},
}

w.executeTask(context.Background(), trace.SpanFromContext(context.Background()), &types.TaskAssignmentMessage{
TaskID: "task-1",
Task: &types.Task{ID: "task-1", Title: "test task"},
}, time.Now())

msg := readWebSocketMessage(t, w.sendChan)
if msg.Type != types.MessageTypeTaskCompleted {
t.Fatalf("message type = %q, want %q", msg.Type, types.MessageTypeTaskCompleted)
}

var completed types.TaskCompletedMessage
if err := json.Unmarshal(msg.Data, &completed); err != nil {
t.Fatalf("failed to unmarshal task completed message: %v", err)
}
if completed.TaskID != "task-1" {
t.Errorf("task ID = %q, want %q", completed.TaskID, "task-1")
}
if completed.TaskState == nil || *completed.TaskState != types.TaskStateCancelled {
t.Fatalf("task state = %v, want %q", completed.TaskState, types.TaskStateCancelled)
}
if _, ok := w.activeTasks["task-1"]; ok {
t.Fatal("task should be removed from active tasks")
}
}

func TestHandleMessageCancelsActiveTask(t *testing.T) {
taskCtx, taskCancel := context.WithCancel(context.Background())
defer taskCancel()

w := &Worker{
ctx: context.Background(),
sendChan: make(chan []byte, 1),
activeTasks: map[string]activeTask{
"task-1": {
ctx: taskCtx,
cancel: taskCancel,
},
},
}

data, err := json.Marshal(types.TaskCancellationMessage{TaskID: "task-1"})
if err != nil {
t.Fatalf("failed to marshal cancellation message: %v", err)
}
message, err := json.Marshal(types.WebSocketMessage{
Type: types.MessageTypeTaskCancellation,
Data: data,
})
if err != nil {
t.Fatalf("failed to marshal websocket message: %v", err)
}

w.handleMessage(message)

if taskCtx.Err() != context.Canceled {
t.Fatalf("task context error = %v, want %v", taskCtx.Err(), context.Canceled)
}
}

func TestExecuteTaskReportsTaskCompletedOnSuccess(t *testing.T) {
w := &Worker{
ctx: context.Background(),
config: Config{},
sendChan: make(chan []byte, 1),
activeTasks: map[string]context.CancelFunc{"task-1": func() {}},
activeTasks: map[string]activeTask{"task-1": {cancel: func() {}}},
backend: &recordingBackend{},
}

Expand Down Expand Up @@ -123,7 +191,7 @@
ctx: context.Background(),
config: Config{},
sendChan: make(chan []byte, 1),
activeTasks: map[string]context.CancelFunc{"task-1": func() {}},
activeTasks: map[string]activeTask{"task-1": {cancel: func() {}}},
backend: &recordingBackend{err: errors.New("boom")},
}

Expand Down Expand Up @@ -157,7 +225,7 @@

select {
case msgBytes := <-messages:
var msg types.WebSocketMessage

Check failure on line 228 in internal/worker/worker_test.go

View workflow job for this annotation

GitHub Actions / test

cannot use map[string]context.CancelFunc{…} (value of type map[string]context.CancelFunc) as map[string]activeTask value in struct literal
if err := json.Unmarshal(msgBytes, &msg); err != nil {
t.Fatalf("failed to unmarshal websocket message: %v", err)
}
Expand Down Expand Up @@ -186,7 +254,7 @@

envID := "env-123"

t.Run("server-provided image wins over default_image", func(t *testing.T) {

Check failure on line 257 in internal/worker/worker_test.go

View workflow job for this annotation

GitHub Actions / test

cannot use map[string]context.CancelFunc{…} (value of type map[string]context.CancelFunc) as map[string]activeTask value in struct literal
w := newWorker("my-registry.io/default:v1")
got := w.defaultImageForTask("server-image:latest", &types.Task{})
if got != "server-image:latest" {
Expand Down Expand Up @@ -363,7 +431,7 @@
w := &Worker{
ctx: workerCtx,
cancel: cancel,
activeTasks: make(map[string]context.CancelFunc),
activeTasks: make(map[string]activeTask),
backend: backend,
}

Expand Down
Loading