3
3
import logging
4
4
import threading
5
5
from enum import Enum
6
- from typing import Callable , Dict
6
+ from typing import Callable , Dict , Optional
7
7
from uuid import uuid4
8
8
9
9
import pika
@@ -37,18 +37,21 @@ def from_workflow_type(workflow_type: WorkFlowType) -> "Queue":
37
37
38
38
39
39
class RabbitmqClient (threading .Thread ):
40
+ rabbitmq_callbacks : Dict [Queue , PikaCallback ]
40
41
rabbitmq_is_running : bool
41
42
rabbitmq_config : RabbitmqConfig
42
43
rabbitmq_exchange : str
43
- connection : pika .BlockingConnection
44
- channel : BlockingChannel
45
- queue : str
44
+ rabbitmq_connection : Optional [pika .BlockingConnection ]
45
+ rabbitmq_channel : Optional [BlockingChannel ]
46
46
47
47
def __init__ (self , config : RabbitmqConfig ):
48
48
super ().__init__ ()
49
+ self .rabbitmq_callbacks = {}
49
50
self .rabbitmq_is_running = False
50
51
self .rabbitmq_config = config
51
52
self .rabbitmq_exchange = config .exchange_name
53
+ self .rabbitmq_connection = None
54
+ self .rabbitmq_channel = None
52
55
53
56
def _connect_rabbitmq (self ):
54
57
# initialize rabbitmq connection
@@ -69,25 +72,33 @@ def _connect_rabbitmq(self):
69
72
connection_attempts = 10 ,
70
73
)
71
74
72
- self .connection = pika .BlockingConnection (parameters )
73
-
74
- self .channel = self .connection .channel ()
75
- self .channel .basic_qos (prefetch_size = 0 , prefetch_count = 1 )
76
- self .channel .exchange_declare (exchange = self .rabbitmq_exchange , exchange_type = "topic" )
77
- for queue_item in Queue :
78
- queue = self .channel .queue_declare (queue_item .value , exclusive = False ).method .queue
79
- self .channel .queue_bind (queue , self .rabbitmq_exchange , routing_key = queue_item .value )
75
+ if not self .rabbitmq_connection or self .rabbitmq_connection .is_closed :
76
+ LOGGER .info ("Setting up a new connection to RabbitMQ." )
77
+ self .rabbitmq_connection = pika .BlockingConnection (parameters )
78
+
79
+ if not self .rabbitmq_channel or self .rabbitmq_channel .is_closed :
80
+ LOGGER .info ("Setting up a new channel to RabbitMQ." )
81
+ self .rabbitmq_channel = self .rabbitmq_connection .channel ()
82
+ self .rabbitmq_channel .basic_qos (prefetch_size = 0 , prefetch_count = 1 )
83
+ self .rabbitmq_channel .exchange_declare (exchange = self .rabbitmq_exchange , exchange_type = "topic" )
84
+ for queue_item in Queue :
85
+ queue = self .rabbitmq_channel .queue_declare (queue_item .value , exclusive = False ).method .queue
86
+ self .rabbitmq_channel .queue_bind (queue , self .rabbitmq_exchange , routing_key = queue_item .value )
87
+
88
+ for queue , callback in self .rabbitmq_callbacks .items ():
89
+ self .rabbitmq_channel .basic_consume (queue = queue .value , on_message_callback = callback , auto_ack = False )
80
90
LOGGER .info ("Connected to RabbitMQ" )
81
91
82
92
def _start_rabbitmq (self ):
83
93
self ._connect_rabbitmq ()
84
94
self .start ()
85
95
86
96
def set_callbacks (self , callbacks : Dict [Queue , PikaCallback ]):
97
+ self .rabbitmq_callbacks .update (callbacks )
87
98
for queue , callback in callbacks .items ():
88
- self .connection .add_callback_threadsafe (
99
+ self .rabbitmq_connection .add_callback_threadsafe (
89
100
functools .partial (
90
- self .channel .basic_consume , queue = queue .value , on_message_callback = callback , auto_ack = False
101
+ self .rabbitmq_channel .basic_consume , queue = queue .value , on_message_callback = callback , auto_ack = False
91
102
)
92
103
)
93
104
@@ -98,9 +109,12 @@ def run(self):
98
109
try :
99
110
LOGGER .info ("Waiting for input..." )
100
111
while self .rabbitmq_is_running :
101
- self .connection . process_data_events (time_limit = 1 )
112
+ self .rabbitmq_channel . _process_data_events (time_limit = 1 )
102
113
except pika .exceptions .ConnectionClosedByBroker as exc :
103
114
LOGGER .info ('Connection was closed by broker. Reason: "%s". Shutting down...' , exc .reply_text )
115
+ except pika .exceptions .ChannelClosedByBroker as exc :
116
+ LOGGER .info ('Channel was closed by broker. Reason: "%s". retrying...' , exc .reply_text )
117
+ self ._connect_rabbitmq ()
104
118
except pika .exceptions .AMQPConnectionError :
105
119
LOGGER .info ("Connection was lost, retrying..." )
106
120
self ._connect_rabbitmq ()
@@ -113,13 +127,13 @@ def _send_start_work_flow(self, job_id: uuid4, work_flow_type: WorkFlowType):
113
127
114
128
def _send_output (self , queue : Queue , message : str ):
115
129
body : bytes = message .encode ("utf-8" )
116
- self .connection .add_callback_threadsafe (
130
+ self .rabbitmq_connection .add_callback_threadsafe (
117
131
functools .partial (
118
- self .channel .basic_publish , exchange = self .rabbitmq_exchange , routing_key = queue .value , body = body
132
+ self .rabbitmq_channel .basic_publish , exchange = self .rabbitmq_exchange , routing_key = queue .value , body = body
119
133
)
120
134
)
121
135
122
136
def _stop_rabbitmq (self ):
123
137
self .rabbitmq_is_running = False
124
- if self .connection :
125
- self .connection .add_callback_threadsafe (self .connection .close )
138
+ if self .rabbitmq_connection :
139
+ self .rabbitmq_connection .add_callback_threadsafe (self .rabbitmq_connection .close )
0 commit comments