Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ Example config:

```yaml
worker_id: "my-worker"
task_timeout: "2h"
backend:
direct:
workspace_root: "/var/lib/oz/workspaces"
Expand All @@ -57,6 +58,7 @@ Example config:

```yaml
worker_id: "my-worker"
task_timeout: "2h"
backend:
kubernetes:
kubeconfig: "/path/to/kubeconfig"
Expand All @@ -76,6 +78,7 @@ backend:

Notes:

- `task_timeout` is a top-level worker setting that limits the wall-clock runtime of each task across Docker, direct, and Kubernetes backends; omit it or set `0s` for unlimited runtime
- `default_image` sets the Docker image for task Jobs when no Warp environment is configured on the run; this lets you skip creating a Warp environment entirely if all your tasks use the same base image (precedence: Warp environment image > `default_image` > `ubuntu:22.04`)
- `namespace` selects the namespace inside the chosen cluster; it does not choose the cluster itself, and defaults to `default` when omitted
- `unschedulable_timeout` controls how long a Pod may remain unschedulable before the task is failed early; it defaults to `30s`, and `0s` disables that fail-fast behavior
Expand Down
3 changes: 3 additions & 0 deletions charts/oz-agent-worker/templates/configmap.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ data:
worker_id: {{ required "worker.workerId is required" .Values.worker.workerId | quote }}
cleanup: {{ .Values.worker.cleanup }}
max_concurrent_tasks: {{ .Values.worker.maxConcurrentTasks }}
{{- if .Values.worker.taskTimeout }}
task_timeout: {{ .Values.worker.taskTimeout | quote }}
{{- end }}
{{- if .Values.worker.idleOnComplete }}
idle_on_complete: {{ .Values.worker.idleOnComplete | quote }}
{{- end }}
Expand Down
3 changes: 3 additions & 0 deletions charts/oz-agent-worker/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ worker:
logLevel: info
cleanup: true
maxConcurrentTasks: 0
# Maximum wall-clock runtime for each task. Empty or "0s" means unlimited.
# Examples: "2h", "90m", "30m".
taskTimeout: ""
idleOnComplete: ""
extraArgs: []
extraEnv: []
Expand Down
4 changes: 4 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ type FileConfig struct {
WorkerID string `yaml:"worker_id"`
Cleanup *bool `yaml:"cleanup"`
MaxConcurrentTasks *int `yaml:"max_concurrent_tasks"`
// TaskTimeout controls the maximum wall-clock runtime for each task.
// Uses Go duration format (e.g. "2h", "90m", "0s"). When nil or "0s",
// tasks may run indefinitely unless another backend-specific timeout applies.
TaskTimeout *string `yaml:"task_timeout"`
// IdleOnComplete controls how long the oz CLI process stays alive after a task's
// conversation finishes, to allow follow-up interactions via the shared session.
// Uses humantime format (e.g. "45m", "10m", "0s"). When nil, the oz CLI default
Expand Down
32 changes: 32 additions & 0 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -602,3 +602,35 @@ worker_id: "test"
}
})
}

func TestLoadTaskTimeout(t *testing.T) {
t.Run("parses task_timeout when set", func(t *testing.T) {
path := writeTestConfig(t, `
worker_id: "test"
task_timeout: "2h"
`)
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.TaskTimeout == nil {
t.Fatal("expected task_timeout to be set")
}
if *cfg.TaskTimeout != "2h" {
t.Errorf("task_timeout = %q, want %q", *cfg.TaskTimeout, "2h")
}
})

t.Run("task_timeout is nil when not set", func(t *testing.T) {
path := writeTestConfig(t, `
worker_id: "test"
`)
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.TaskTimeout != nil {
t.Errorf("expected task_timeout to be nil, got %q", *cfg.TaskTimeout)
}
})
}
12 changes: 11 additions & 1 deletion internal/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ type Config struct {
LogLevel string
BackendType string // "docker", "direct", or "kubernetes"
MaxConcurrentTasks int // 0 means unlimited
// TaskTimeout is the maximum wall-clock runtime for each task. 0 means unlimited.
TaskTimeout time.Duration
// IdleOnComplete is passed to the oz CLI's --idle-on-complete flag for every task.
// Empty string means use the oz CLI default (45m). Use "0s" to disable idle.
IdleOnComplete string
Expand Down Expand Up @@ -361,7 +363,15 @@ func (w *Worker) handleTaskAssignment(assignment *types.TaskAssignmentMessage) {
metrics.RecordTaskClaim()
metrics.AddTaskEvent(taskCtx, "task.claimed")
metrics.IncTasksActive()
taskCtx, taskCancel := context.WithCancel(taskCtx)
var taskCancel context.CancelFunc
if w.config.TaskTimeout > 0 {
taskCtx, taskCancel = context.WithTimeout(taskCtx, w.config.TaskTimeout)
metrics.AddTaskEvent(taskCtx, "task.timeout_configured",
attribute.Int64("timeout.ms", w.config.TaskTimeout.Milliseconds()),
)
} else {
taskCtx, taskCancel = context.WithCancel(taskCtx)
}

w.tasksMutex.Lock()
w.activeTasks[assignment.TaskID] = taskCancel
Expand Down
79 changes: 78 additions & 1 deletion internal/worker/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/warpdotdev/oz-agent-worker/internal/metrics"
"testing"
"time"

"github.com/warpdotdev/oz-agent-worker/internal/metrics"
"github.com/warpdotdev/oz-agent-worker/internal/types"
"go.opentelemetry.io/otel/trace"
)
Expand Down Expand Up @@ -37,6 +37,19 @@ func (b *recordingBackend) ExecuteTask(context.Context, *TaskParams) error {

// Shutdown does nothing: the body is intentionally empty and the context
// argument is ignored.
// NOTE(review): presumably present so recordingBackend satisfies the worker's
// backend interface — confirm against the interface definition.
func (b *recordingBackend) Shutdown(context.Context) {}

// contextRecordingBackend is a test double that captures the context handed
// to ExecuteTask and signals that the call happened by closing called.
type contextRecordingBackend struct {
	// ctx is the context received by ExecuteTask; it stays nil until then.
	ctx context.Context
	// called is closed by ExecuteTask. It must be non-nil, and ExecuteTask may
	// run at most once per instance because closing a channel twice panics.
	called chan struct{}
}

// ExecuteTask stores ctx on the backend, closes called, and reports success.
// The task parameters are ignored.
func (b *contextRecordingBackend) ExecuteTask(ctx context.Context, _ *TaskParams) error {
	// The deferred close runs after the assignment below, so a goroutine that
	// waits on called is guaranteed to observe the stored context.
	defer close(b.called)
	b.ctx = ctx
	return nil
}

// Shutdown does nothing; the context argument is ignored.
func (b *contextRecordingBackend) Shutdown(context.Context) {}

func TestTaskFailureLabels(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -152,6 +165,70 @@ func TestExecuteTaskReportsTaskFailedOnBackendError(t *testing.T) {
}
}

// TestHandleTaskAssignmentAppliesTaskTimeout verifies that a positive
// Config.TaskTimeout reaches the backend as a context deadline of that length.
//
// The deadline is checked against a window bracketing the call rather than
// merely "some time within the next hour", so a deadline that is too short
// (for example from a unit mistake) fails the test instead of slipping through.
func TestHandleTaskAssignmentAppliesTaskTimeout(t *testing.T) {
	const timeout = time.Hour

	backend := &contextRecordingBackend{called: make(chan struct{})}
	w := &Worker{
		ctx:         context.Background(),
		config:      Config{TaskTimeout: timeout},
		sendChan:    make(chan []byte, 4),
		activeTasks: make(map[string]context.CancelFunc),
		backend:     backend,
	}

	// The task context is created somewhere between these two timestamps, so
	// its deadline must fall between earliest and latest.
	earliest := time.Now().Add(timeout)

	w.handleTaskAssignment(&types.TaskAssignmentMessage{
		TaskID: "task-1",
		Task:   &types.Task{ID: "task-1", Title: "test task"},
	})

	waitForBackendCall(t, backend.called)

	latest := time.Now().Add(timeout)

	if backend.ctx == nil {
		t.Fatal("expected backend to receive a context")
	}
	deadline, ok := backend.ctx.Deadline()
	if !ok {
		t.Fatal("expected backend context to have a deadline")
	}
	if deadline.Before(earliest) || deadline.After(latest) {
		t.Fatalf("deadline = %v, want between %v and %v (timeout %v)", deadline, earliest, latest, timeout)
	}
}

// TestHandleTaskAssignmentLeavesTaskUnlimitedWhenTimeoutUnset verifies that
// with a zero-value Config (TaskTimeout == 0) the context passed to the
// backend carries no deadline.
func TestHandleTaskAssignmentLeavesTaskUnlimitedWhenTimeoutUnset(t *testing.T) {
	rec := &contextRecordingBackend{called: make(chan struct{})}
	subject := &Worker{
		ctx:         context.Background(),
		config:      Config{},
		sendChan:    make(chan []byte, 4),
		activeTasks: make(map[string]context.CancelFunc),
		backend:     rec,
	}

	assignment := &types.TaskAssignmentMessage{
		TaskID: "task-1",
		Task:   &types.Task{ID: "task-1", Title: "test task"},
	}
	subject.handleTaskAssignment(assignment)

	// Block until the backend has run so reading rec.ctx below is safe.
	waitForBackendCall(t, rec.called)

	got := rec.ctx
	if got == nil {
		t.Fatal("expected backend to receive a context")
	}
	_, hasDeadline := got.Deadline()
	if hasDeadline {
		t.Fatal("expected backend context to have no deadline")
	}
}

// waitForBackendCall blocks until called is closed or receives a value, and
// fails the test fatally if that has not happened within one second.
func waitForBackendCall(t *testing.T, called <-chan struct{}) {
	t.Helper()

	limit := time.NewTimer(time.Second)
	defer limit.Stop()

	select {
	case <-limit.C:
		t.Fatal("timed out waiting for backend execution")
	case <-called:
	}
}

func readWebSocketMessage(t *testing.T, messages <-chan []byte) types.WebSocketMessage {
t.Helper()

Expand Down
18 changes: 18 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ var CLI struct {
Volumes []string `help:"Volume mounts for task containers (format: HOST_PATH:CONTAINER_PATH or HOST_PATH:CONTAINER_PATH:MODE)" short:"v"`
Env []string `help:"Environment variables for task containers (format: KEY=VALUE or KEY to pass through from host)" short:"e"`
MaxConcurrentTasks int `help:"Maximum number of tasks to run concurrently (0 for unlimited)" default:"0"`
TaskTimeout string `help:"Maximum wall-clock runtime for each task (e.g. 2h, 90m, 0s for unlimited)"`
IdleOnComplete string `help:"How long to keep the oz agent alive after a task completes, for follow-ups (e.g. 45m, 10m, 0s). Defaults to 45m when not set."`
SessionSharingServerURL string `help:"Session sharing server WebSocket URL to pass through to the oz CLI (e.g. ws://127.0.0.1:8081)" hidden:""`
}
Expand Down Expand Up @@ -163,6 +164,22 @@ func mergeConfig(fileConfig *config.FileConfig) (worker.Config, error) {
maxConcurrentTasks = *fileConfig.MaxConcurrentTasks
}

// Resolve task_timeout: CLI (non-empty) > config file > 0 (unlimited).
taskTimeoutRaw := CLI.TaskTimeout
if taskTimeoutRaw == "" && fileConfig != nil && fileConfig.TaskTimeout != nil {
taskTimeoutRaw = *fileConfig.TaskTimeout
}
var taskTimeout time.Duration
if taskTimeoutRaw != "" {
taskTimeout, err = time.ParseDuration(taskTimeoutRaw)
if err != nil {
return worker.Config{}, fmt.Errorf("invalid task_timeout %q: %w", taskTimeoutRaw, err)
}
if taskTimeout < 0 {
return worker.Config{}, fmt.Errorf("invalid task_timeout %q: must be non-negative", taskTimeoutRaw)
}
}

// Resolve idle_on_complete: CLI (non-empty) > config file > "" (oz CLI default = 45m).
idleOnComplete := CLI.IdleOnComplete
if idleOnComplete == "" && fileConfig != nil && fileConfig.IdleOnComplete != nil {
Expand All @@ -177,6 +194,7 @@ func mergeConfig(fileConfig *config.FileConfig) (worker.Config, error) {
LogLevel: CLI.LogLevel,
BackendType: backendType,
MaxConcurrentTasks: maxConcurrentTasks,
TaskTimeout: taskTimeout,
IdleOnComplete: idleOnComplete,
SessionSharingServerURL: CLI.SessionSharingServerURL,
}
Expand Down
92 changes: 92 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ func resetCLIForTest() {
CLI.Volumes = nil
CLI.Env = nil
CLI.MaxConcurrentTasks = 0
CLI.TaskTimeout = ""
CLI.IdleOnComplete = ""
CLI.TargetDir = ""
CLI.SessionSharingServerURL = ""
}

func boolPtr(v bool) *bool {
Expand Down Expand Up @@ -196,3 +199,92 @@ func TestMergeConfigKubernetesAllowsZeroUnschedulableTimeout(t *testing.T) {
t.Fatalf("UnschedulableTimeout = %v, want 0", wc.Kubernetes.UnschedulableTimeout)
}
}

// TestMergeConfigTaskTimeout exercises how mergeConfig resolves the task
// timeout: the file value is used by default, a non-empty CLI value wins over
// the file, "0s" yields a zero duration, and unparseable or negative values
// are rejected with an error.
func TestMergeConfigTaskTimeout(t *testing.T) {
	cases := []struct {
		name         string
		cliWorkerID  string // applied to CLI.WorkerID only when non-empty
		cliTimeout   string // applied to CLI.TaskTimeout only when non-empty
		fileWorkerID string
		fileTimeout  string
		want         time.Duration
		wantText     string // how the expected value is spelled in failure output
		wantErr      bool
	}{
		{
			name:         "uses file task_timeout",
			fileWorkerID: "worker-123",
			fileTimeout:  "2h",
			want:         2 * time.Hour,
			wantText:     "2h",
		},
		{
			name:         "cli overrides file task_timeout",
			cliWorkerID:  "worker-123",
			cliTimeout:   "30m",
			fileWorkerID: "file-worker",
			fileTimeout:  "2h",
			want:         30 * time.Minute,
			wantText:     "30m",
		},
		{
			name:         "zero disables task_timeout",
			fileWorkerID: "worker-123",
			fileTimeout:  "0s",
			want:         0,
			wantText:     "0",
		},
		{
			name:         "rejects invalid task_timeout",
			fileWorkerID: "worker-123",
			fileTimeout:  "not-a-duration",
			wantErr:      true,
		},
		{
			name:         "rejects negative task_timeout",
			fileWorkerID: "worker-123",
			fileTimeout:  "-1s",
			wantErr:      true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// CLI is package-level state: start from a clean slate and restore
			// it afterwards so subtests cannot leak flags into each other.
			resetCLIForTest()
			t.Cleanup(resetCLIForTest)

			if tc.cliWorkerID != "" {
				CLI.WorkerID = tc.cliWorkerID
			}
			if tc.cliTimeout != "" {
				CLI.TaskTimeout = tc.cliTimeout
			}

			fileConfig := &config.FileConfig{
				WorkerID:    tc.fileWorkerID,
				TaskTimeout: stringPtr(tc.fileTimeout),
			}

			wc, err := mergeConfig(fileConfig)
			if tc.wantErr {
				if err == nil {
					t.Fatal("expected error")
				}
				return
			}
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if wc.TaskTimeout != tc.want {
				t.Fatalf("TaskTimeout = %v, want %s", wc.TaskTimeout, tc.wantText)
			}
		})
	}
}
Loading