2
2
3
3
import logging
4
4
from enum import Enum
5
- import signal
6
5
from typing import Callable , Dict
7
6
from uuid import uuid4
8
7
9
8
import pika
9
+ import pika .exceptions
10
10
import json
11
11
12
12
from pika .adapters .blocking_connection import BlockingChannel
@@ -33,49 +33,74 @@ def from_workflow_type(workflow_type: WorkFlowType) -> "Queue":
33
33
34
34
35
35
class RabbitmqClient :
36
+ rabbitmq_is_running : bool
37
+ rabbitmq_config : RabbitmqConfig
36
38
rabbitmq_exchange : str
39
+ connection : pika .BlockingConnection
37
40
channel : BlockingChannel
38
41
queue : str
39
42
40
43
def __init__ (self , config : RabbitmqConfig ):
44
+ self .rabbitmq_is_running = False
45
+ self .rabbitmq_config = config
41
46
self .rabbitmq_exchange = config .exchange_name
42
47
48
+ def _connect_rabbitmq (self ):
43
49
# initialize rabbitmq connection
44
- LOGGER .info ("Connecting to RabbitMQ at %s:%s as user %s" , config .host , config .port , config .user_name )
45
- credentials = pika .PlainCredentials (config .user_name , config .password )
50
+ LOGGER .info (
51
+ "Connecting to RabbitMQ at %s:%s as user %s" ,
52
+ self .rabbitmq_config .host ,
53
+ self .rabbitmq_config .port ,
54
+ self .rabbitmq_config .user_name ,
55
+ )
56
+ credentials = pika .PlainCredentials (self .rabbitmq_config .user_name , self .rabbitmq_config .password )
46
57
parameters = pika .ConnectionParameters (
47
- config .host , config .port , "/" , credentials , heartbeat = 3600 , blocked_connection_timeout = 3600
58
+ self .rabbitmq_config .host ,
59
+ self .rabbitmq_config .port ,
60
+ "/" ,
61
+ credentials ,
62
+ heartbeat = 3600 ,
63
+ blocked_connection_timeout = 3600 ,
64
+ connection_attempts = 10 ,
48
65
)
49
66
50
- connection = pika .BlockingConnection (parameters )
67
+ self . connection = pika .BlockingConnection (parameters )
51
68
52
- self .channel = connection .channel ()
69
+ self .channel = self . connection .channel ()
53
70
self .channel .basic_qos (prefetch_size = 0 , prefetch_count = 1 )
54
71
self .channel .exchange_declare (exchange = self .rabbitmq_exchange , exchange_type = "topic" )
55
72
self .queue = self .channel .queue_declare (Queue .StartWorkflowOptimizer .value , exclusive = False ).method .queue
56
73
self .channel .queue_bind (self .queue , self .rabbitmq_exchange , routing_key = Queue .StartWorkflowOptimizer .value )
57
74
LOGGER .info ("Connected to RabbitMQ" )
58
75
59
- def wait_for_data (self , callbacks : Dict [Queue , PikaCallback ]):
60
- for queue , callback in callbacks .items ():
61
- self .channel .basic_consume (queue = queue .value , on_message_callback = callback , auto_ack = False )
62
-
63
- def stop (signal , frame ):
64
- LOGGER .info ("Received signal %s. Stopping.." , signal )
65
- self .channel .stop_consuming ()
66
-
67
- signal .signal (signal .SIGINT , stop )
68
- signal .signal (signal .SIGTERM , stop )
69
-
70
- LOGGER .info ("Waiting for input..." )
71
- self .channel .start_consuming ()
72
-
73
- def send_start_work_flow (self , job_id : uuid4 , work_flow_type : WorkFlowType ):
76
+ def wait_for_work (self , callbacks : Dict [Queue , PikaCallback ]):
77
+ self .rabbitmq_is_running = True
78
+
79
+ while self .rabbitmq_is_running :
80
+ try :
81
+ for queue , callback in callbacks .items ():
82
+ self .channel .basic_consume (queue = queue .value , on_message_callback = callback , auto_ack = False )
83
+ LOGGER .info ("Waiting for input..." )
84
+ self .channel .start_consuming ()
85
+ except pika .exceptions .ConnectionClosedByBroker as exc :
86
+ LOGGER .info ('Connection was closed by broker. Reason: "%s". Shutting down...' , exc .reply_text )
87
+ except pika .exceptions .AMQPConnectionError :
88
+ LOGGER .info ("Connection was lost, retrying..." )
89
+ self ._connect_rabbitmq ()
90
+
91
+ def _send_start_work_flow (self , job_id : uuid4 , work_flow_type : WorkFlowType ):
74
92
# TODO convert to protobuf
75
93
# TODO job_id converted to string for json
76
94
body = json .dumps ({"job_id" : str (job_id )})
77
- self .send_output (Queue .from_workflow_type (work_flow_type ), body )
95
+ self ._send_output (Queue .from_workflow_type (work_flow_type ), body )
78
96
79
- def send_output (self , queue : Queue , message : str ):
97
+ def _send_output (self , queue : Queue , message : str ):
80
98
body : bytes = message .encode ("utf-8" )
81
99
self .channel .basic_publish (exchange = self .rabbitmq_exchange , routing_key = queue .value , body = body )
100
+
101
+ def _stop_rabbitmq (self ):
102
+ self .rabbitmq_is_running = False
103
+ if self .channel :
104
+ self .channel .stop_consuming ()
105
+ if self .connection :
106
+ self .connection .close ()
0 commit comments