Skip to content

Commit f6a8ed8

Browse files
committed
improve context management
1 parent 5d759c3 commit f6a8ed8

File tree

5 files changed

+59
-174
lines changed

5 files changed

+59
-174
lines changed

cadence/worker/_base_task_handler.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,26 +34,19 @@ async def handle_task(self, task: T) -> None:
3434
Handle a single task.
3535
3636
This method provides the base implementation for task handling that includes:
37-
- Context propagation
3837
- Error handling
3938
- Cleanup
4039
4140
Args:
4241
task: The task to handle
4342
"""
4443
try:
45-
# Propagate context from task parameters
46-
await self._propagate_context(task)
47-
48-
# Handle the task
44+
# Handle the task implementation
4945
await self._handle_task_implementation(task)
5046

5147
except Exception as e:
5248
logger.exception(f"Error handling task: {e}")
5349
await self.handle_task_failure(task, e)
54-
finally:
55-
# Clean up context
56-
await self._unset_current_context()
5750

5851
@abstractmethod
5952
async def _handle_task_implementation(self, task: T) -> None:
@@ -75,20 +68,3 @@ async def handle_task_failure(self, task: T, error: Exception) -> None:
7568
error: The exception that occurred
7669
"""
7770
pass
78-
79-
async def _propagate_context(self, task: T) -> None:
80-
"""
81-
Propagate context from task parameters.
82-
83-
Args:
84-
task: The task containing context information
85-
"""
86-
# Default implementation - subclasses should override if needed
87-
pass
88-
89-
async def _unset_current_context(self) -> None:
90-
"""
91-
Unset the current context after task completion.
92-
"""
93-
# Default implementation - subclasses should override if needed
94-
pass

cadence/worker/_decision_task_handler.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from cadence.client import Client
1212
from cadence.worker._base_task_handler import BaseTaskHandler
1313
from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult
14+
from cadence._internal.workflow.context import Context
1415
from cadence.workflow import WorkflowInfo
1516
from cadence.worker._registry import Registry
1617

@@ -52,8 +53,7 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -
5253

5354
if not workflow_execution or not workflow_type:
5455
logger.error("Decision task missing workflow execution or type")
55-
await self.handle_task_failure(task, ValueError("Missing workflow execution or type"))
56-
return
56+
raise ValueError("Missing workflow execution or type")
5757

5858
workflow_id = workflow_execution.workflow_id
5959
run_id = workflow_execution.run_id
@@ -69,8 +69,7 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -
6969
workflow_func = self._registry.get_workflow(workflow_type_name)
7070
except KeyError:
7171
logger.error(f"Workflow type '{workflow_type_name}' not found in registry")
72-
await self.handle_task_failure(task, KeyError(f"Workflow type '{workflow_type_name}' not found"))
73-
return
72+
raise KeyError(f"Workflow type '{workflow_type_name}' not found")
7473

7574
# Create workflow info and engine
7675
workflow_info = WorkflowInfo(
@@ -86,12 +85,22 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -
8685
workflow_func=workflow_func
8786
)
8887

89-
# Process the decision using the workflow engine
88+
# Create workflow context and execute with it active
9089
workflow_engine = self._workflow_engines[engine_key]
91-
decision_result = await workflow_engine.process_decision(task)
90+
workflow_info = WorkflowInfo(
91+
workflow_type=workflow_type_name,
92+
workflow_domain=self._client.domain,
93+
workflow_id=workflow_id,
94+
workflow_run_id=run_id
95+
)
9296

93-
# Respond with the decisions
94-
await self._respond_decision_task_completed(task, decision_result)
97+
context = Context(client=self._client, info=workflow_info)
98+
with context._activate():
99+
# Process the decision using the workflow engine
100+
decision_result = await workflow_engine.process_decision(task)
101+
102+
# Respond with the decisions
103+
await self._respond_decision_task_completed(task, decision_result)
95104

96105
logger.info(f"Successfully processed decision task for workflow {workflow_id}")
97106

tests/cadence/worker/test_base_task_handler.py

Lines changed: 1 addition & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ def __init__(self, client, task_list: str, identity: str, **options):
1616
super().__init__(client, task_list, identity, **options)
1717
self._handle_task_implementation_called = False
1818
self._handle_task_failure_called = False
19-
self._propagate_context_called = False
20-
self._unset_current_context_called = False
2119
self._last_task: str = ""
2220
self._last_error: Exception | None = None
2321

@@ -33,15 +31,6 @@ async def handle_task_failure(self, task: str, error: Exception) -> None:
3331
self._handle_task_failure_called = True
3432
self._last_task = task
3533
self._last_error = error
36-
37-
async def _propagate_context(self, task: str) -> None:
38-
"""Test implementation of context propagation."""
39-
self._propagate_context_called = True
40-
self._last_task = task
41-
42-
async def _unset_current_context(self) -> None:
43-
"""Test implementation of context cleanup."""
44-
self._unset_current_context_called = True
4534

4635

4736
class TestBaseTaskHandler:
@@ -71,10 +60,8 @@ async def test_handle_task_success(self):
7160

7261
await handler.handle_task("test_task")
7362

74-
# Verify all methods were called in correct order
75-
assert handler._propagate_context_called
63+
# Verify implementation was called
7664
assert handler._handle_task_implementation_called
77-
assert handler._unset_current_context_called
7865
assert not handler._handle_task_failure_called
7966
assert handler._last_task == "test_task"
8067
assert handler._last_error is None
@@ -88,72 +75,12 @@ async def test_handle_task_failure(self):
8875
await handler.handle_task("raise_error")
8976

9077
# Verify error handling was called
91-
assert handler._propagate_context_called
9278
assert handler._handle_task_implementation_called
9379
assert handler._handle_task_failure_called
94-
assert handler._unset_current_context_called
9580
assert handler._last_task == "raise_error"
9681
assert isinstance(handler._last_error, ValueError)
9782
assert str(handler._last_error) == "Test error"
9883

99-
@pytest.mark.asyncio
100-
async def test_handle_task_with_context_propagation_error(self):
101-
"""Test task handling when context propagation fails."""
102-
client = Mock()
103-
handler = ConcreteTaskHandler(client, "test_task_list", "test_identity")
104-
105-
# Override _propagate_context to raise an error
106-
async def failing_propagate_context(task):
107-
raise RuntimeError("Context propagation failed")
108-
109-
# Use setattr to avoid mypy error about method assignment
110-
setattr(handler, '_propagate_context', failing_propagate_context)
111-
112-
await handler.handle_task("test_task")
113-
114-
# Verify error handling was called
115-
assert handler._handle_task_failure_called
116-
assert handler._unset_current_context_called
117-
assert isinstance(handler._last_error, RuntimeError)
118-
assert str(handler._last_error) == "Context propagation failed"
119-
120-
@pytest.mark.asyncio
121-
async def test_handle_task_with_cleanup_error(self):
122-
"""Test task handling when cleanup fails."""
123-
client = Mock()
124-
handler = ConcreteTaskHandler(client, "test_task_list", "test_identity")
125-
126-
# Override _unset_current_context to raise an error
127-
async def failing_unset_context():
128-
raise RuntimeError("Cleanup failed")
129-
130-
# Use setattr to avoid mypy error about method assignment
131-
setattr(handler, '_unset_current_context', failing_unset_context)
132-
133-
# Cleanup errors in finally block will propagate
134-
with pytest.raises(RuntimeError, match="Cleanup failed"):
135-
await handler.handle_task("test_task")
136-
137-
@pytest.mark.asyncio
138-
async def test_handle_task_with_implementation_and_cleanup_errors(self):
139-
"""Test task handling when both implementation and cleanup fail."""
140-
client = Mock()
141-
handler = ConcreteTaskHandler(client, "test_task_list", "test_identity")
142-
143-
# Override _unset_current_context to raise an error
144-
async def failing_unset_context():
145-
raise RuntimeError("Cleanup failed")
146-
147-
# Use setattr to avoid mypy error about method assignment
148-
setattr(handler, '_unset_current_context', failing_unset_context)
149-
150-
# The implementation error should be handled, but cleanup error will propagate
151-
with pytest.raises(RuntimeError, match="Cleanup failed"):
152-
await handler.handle_task("raise_error")
153-
154-
# Verify the implementation error was handled before cleanup error
155-
assert handler._handle_task_failure_called
156-
assert isinstance(handler._last_error, ValueError)
15784

15885
@pytest.mark.asyncio
15986
async def test_abstract_methods_not_implemented(self):
@@ -175,17 +102,6 @@ async def handle_task_failure(self, task: str, error: Exception) -> None:
175102
with pytest.raises(NotImplementedError):
176103
await handler.handle_task_failure("test", Exception("test"))
177104

178-
@pytest.mark.asyncio
179-
async def test_default_context_methods(self):
180-
"""Test default implementations of context methods."""
181-
client = Mock()
182-
handler = ConcreteTaskHandler(client, "test_task_list", "test_identity")
183-
184-
# Test default _propagate_context (should not raise)
185-
await handler._propagate_context("test_task")
186-
187-
# Test default _unset_current_context (should not raise)
188-
await handler._unset_current_context()
189105

190106
@pytest.mark.asyncio
191107
async def test_generic_type_parameter(self):

tests/cadence/worker/test_decision_task_handler.py

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,8 @@ async def test_handle_task_implementation_missing_workflow_execution(self, handl
114114
task.workflow_type = Mock()
115115
task.workflow_type.name = "TestWorkflow"
116116

117-
with patch.object(handler, 'handle_task_failure', new_callable=AsyncMock) as mock_handle_failure:
117+
with pytest.raises(ValueError, match="Missing workflow execution or type"):
118118
await handler._handle_task_implementation(task)
119-
120-
mock_handle_failure.assert_called_once()
121-
args = mock_handle_failure.call_args[0]
122-
assert args[0] == task
123-
assert isinstance(args[1], ValueError)
124-
assert "Missing workflow execution or type" in str(args[1])
125119

126120
@pytest.mark.asyncio
127121
async def test_handle_task_implementation_missing_workflow_type(self, handler):
@@ -133,28 +127,16 @@ async def test_handle_task_implementation_missing_workflow_type(self, handler):
133127
task.workflow_execution.run_id = "test_run_id"
134128
task.workflow_type = None
135129

136-
with patch.object(handler, 'handle_task_failure', new_callable=AsyncMock) as mock_handle_failure:
130+
with pytest.raises(ValueError, match="Missing workflow execution or type"):
137131
await handler._handle_task_implementation(task)
138-
139-
mock_handle_failure.assert_called_once()
140-
args = mock_handle_failure.call_args[0]
141-
assert args[0] == task
142-
assert isinstance(args[1], ValueError)
143-
assert "Missing workflow execution or type" in str(args[1])
144132

145133
@pytest.mark.asyncio
146134
async def test_handle_task_implementation_workflow_not_found(self, handler, sample_decision_task, mock_registry):
147135
"""Test decision task handling when workflow is not found in registry."""
148136
mock_registry.get_workflow.side_effect = KeyError("Workflow not found")
149137

150-
with patch.object(handler, 'handle_task_failure', new_callable=AsyncMock) as mock_handle_failure:
138+
with pytest.raises(KeyError, match="Workflow type 'TestWorkflow' not found"):
151139
await handler._handle_task_implementation(sample_decision_task)
152-
153-
mock_handle_failure.assert_called_once()
154-
args = mock_handle_failure.call_args[0]
155-
assert args[0] == sample_decision_task
156-
assert isinstance(args[1], KeyError)
157-
assert "Workflow type 'TestWorkflow' not found" in str(args[1])
158140

159141
@pytest.mark.asyncio
160142
async def test_handle_task_implementation_reuses_existing_engine(self, handler, sample_decision_task, mock_registry):
@@ -340,13 +322,15 @@ async def test_workflow_engine_creation_with_workflow_info(self, handler, sample
340322
with patch('cadence.worker._decision_task_handler.WorkflowInfo') as mock_workflow_info_class:
341323
await handler._handle_task_implementation(sample_decision_task)
342324

343-
# Verify WorkflowInfo was created with correct parameters
344-
mock_workflow_info_class.assert_called_once_with(
345-
workflow_type="TestWorkflow",
346-
workflow_domain="test_domain",
347-
workflow_id="test_workflow_id",
348-
workflow_run_id="test_run_id"
349-
)
325+
# Verify WorkflowInfo was created with correct parameters (called twice - once for engine, once for context)
326+
assert mock_workflow_info_class.call_count == 2
327+
for call in mock_workflow_info_class.call_args_list:
328+
assert call[1] == {
329+
'workflow_type': "TestWorkflow",
330+
'workflow_domain': "test_domain",
331+
'workflow_id': "test_workflow_id",
332+
'workflow_run_id': "test_run_id"
333+
}
350334

351335
# Verify WorkflowEngine was created with correct parameters
352336
mock_workflow_engine_class.assert_called_once()

tests/cadence/worker/test_task_handler_integration.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import pytest
7+
from contextlib import contextmanager
78
from unittest.mock import Mock, AsyncMock, patch, PropertyMock
89

910
from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse
@@ -100,8 +101,8 @@ async def test_full_task_handling_flow_with_error(self, handler, sample_decision
100101
assert call_args.identity == handler._identity
101102

102103
@pytest.mark.asyncio
103-
async def test_context_propagation_integration(self, handler, sample_decision_task, mock_registry):
104-
"""Test that context propagation works correctly in the integration."""
104+
async def test_context_activation_integration(self, handler, sample_decision_task, mock_registry):
105+
"""Test that context activation works correctly in the integration."""
105106
# Mock workflow function
106107
mock_workflow_func = Mock()
107108
mock_registry.get_workflow.return_value = mock_workflow_func
@@ -114,27 +115,23 @@ async def test_context_propagation_integration(self, handler, sample_decision_ta
114115
mock_decision_result.query_results = {}
115116
mock_engine.process_decision = AsyncMock(return_value=mock_decision_result)
116117

117-
# Track if context methods are called
118-
context_propagated = False
119-
context_unset = False
118+
# Track if context is activated
119+
context_activated = False
120120

121-
async def track_propagate_context(task):
122-
nonlocal context_propagated
123-
context_propagated = True
124-
125-
async def track_unset_current_context():
126-
nonlocal context_unset
127-
context_unset = True
128-
129-
handler._propagate_context = track_propagate_context
130-
handler._unset_current_context = track_unset_current_context
121+
def track_context_activation():
122+
nonlocal context_activated
123+
context_activated = True
131124

132125
with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine):
133-
await handler.handle_task(sample_decision_task)
134-
135-
# Verify context methods were called
136-
assert context_propagated
137-
assert context_unset
126+
with patch('cadence.worker._decision_task_handler.Context') as mock_context_class:
127+
mock_context = Mock()
128+
mock_context._activate = Mock(return_value=contextmanager(lambda: track_context_activation())())
129+
mock_context_class.return_value = mock_context
130+
131+
await handler.handle_task(sample_decision_task)
132+
133+
# Verify context was activated
134+
assert context_activated
138135

139136
@pytest.mark.asyncio
140137
async def test_multiple_workflow_executions(self, handler, mock_registry):
@@ -235,19 +232,22 @@ async def test_error_handling_with_context_cleanup(self, handler, sample_decisio
235232
mock_engine.process_decision = AsyncMock(side_effect=RuntimeError("Workflow processing failed"))
236233

237234
# Track context cleanup
238-
context_unset = False
239-
240-
async def track_unset_current_context():
241-
nonlocal context_unset
242-
context_unset = True
235+
context_cleaned_up = False
243236

244-
handler._unset_current_context = track_unset_current_context
237+
def track_context_cleanup():
238+
nonlocal context_cleaned_up
239+
context_cleaned_up = True
245240

246241
with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine):
247-
await handler.handle_task(sample_decision_task)
242+
with patch('cadence.worker._decision_task_handler.Context') as mock_context_class:
243+
mock_context = Mock()
244+
mock_context._activate = Mock(return_value=contextmanager(lambda: track_context_cleanup())())
245+
mock_context_class.return_value = mock_context
246+
247+
await handler.handle_task(sample_decision_task)
248248

249249
# Verify context was cleaned up even after error
250-
assert context_unset
250+
assert context_cleaned_up
251251

252252
# Verify error was handled
253253
handler._client.worker_stub.RespondDecisionTaskFailed.assert_called_once()

0 commit comments

Comments
 (0)