Skip to content

Commit 6e53b52

Browse files
committed
Use HeartBeatTable to trace host state
1 parent df8de52 commit 6e53b52

File tree

4 files changed

+202
-75
lines changed

4 files changed

+202
-75
lines changed

heartbeat_table.py

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import threading
2+
from collections import namedtuple
3+
from datetime import datetime
4+
from datetime import timedelta
5+
import logging_aux
6+
7+
8+
class HeartBeatTable(object):
9+
10+
def __init__(self, provisioning_timeout=timedelta(minutes=15), heartbeat_timeout=timedelta(minutes=3)):
11+
self._table = {}
12+
self.logger = logging_aux.init_logger_aux(
13+
"hpcframework.heartbeat", "hpcframework.heartbeat.log")
14+
self.on_host_running = []
15+
self._table_lock = threading.Lock()
16+
self._provisioning_timeout = provisioning_timeout
17+
self._heartbeat_timeout = heartbeat_timeout
18+
19+
def add_slaveinfo(self, hostname, agent_id, task_id, cpus, last_heartbeat=datetime.utcnow()):
20+
if hostname in self._table and self._table[hostname].state != HpcState.Closed:
21+
self.logger.warn("Heart beat entry of {} existed. old value: {}.".format(
22+
hostname, str(self._table[hostname])))
23+
slaveinfo = SlaveInfo(hostname, agent_id, task_id,
24+
cpus, last_heartbeat, HpcState.Provisioning)
25+
self._table[hostname] = slaveinfo
26+
self.logger.info("Heart beat entry added: {}".format(str(slaveinfo)))
27+
28+
def on_slave_heartbeat(self, hostname):
29+
if hostname in self._table:
30+
self._table[hostname].last_heartbeat = datetime.utcnow()
31+
self.logger.info("Heatbeat from host {}".format(hostname))
32+
if self._table[hostname].state == HpcState.Provisioning:
33+
with self._table_lock: # to ensure we only run running callback once per entry
34+
if self._table[hostname].state == HpcState.Provisioning:
35+
self._table[hostname].state = HpcState.Running
36+
self.__exec_callback(self.on_host_running)
37+
self.logger.info(
38+
"Host {} start running".format(hostname))
39+
else:
40+
self.logger.error(
41+
"Host {} is not recognized. Heartbeat ignored.".format(hostname))
42+
43+
def on_slave_close(self, hostname):
44+
if hostname in self._table:
45+
self._table[hostname].state = HpcState.Closed
46+
self.logger.info("Host {} closed".format(hostname))
47+
else:
48+
self.logger.error(
49+
"Host {} is not recognized. Close event ignored.".format(hostname))
50+
51+
def get_task_info(self, hostname):
52+
if hostname in self._table:
53+
entry = self._table[hostname]
54+
return (entry.task_id, entry.agent_id)
55+
else:
56+
self.logger.error(
57+
"Host {} is not recognized. Failed to get task info.".format(hostname))
58+
59+
def __exec_callback(self, callbacks):
60+
for callback in callbacks:
61+
try:
62+
self.logger.debug(
63+
'Callback %s on %s' % (callback.__name__)
64+
)
65+
callback()
66+
except Exception as e:
67+
self.logger.exception(
68+
'Error in %s callback: %s' % (callback.__name__, str(e))
69+
)
70+
71+
def check_timeout(self, now=datetime.utcnow()):
72+
provision_timeout_list = []
73+
heartbeat_timeout_list = []
74+
running_list = []
75+
for host in dict(self._table):
76+
if host.state == HpcState.Provisioning and now - host.heartbeat_timeout >= self._provisioning_timeout:
77+
provision_timeout_list.append(host)
78+
elif host.state == HpcState.Running:
79+
if now - host.heartbeat_timeout >= self._heartbeat_timeout:
80+
heartbeat_timeout_list.append(host)
81+
else:
82+
running_list.append(host)
83+
return (provision_timeout_list, heartbeat_timeout_list, running_list)
84+
85+
def get_cores_in_provisioning(self):
86+
cores = 0.0
87+
for host in dict(self._table):
88+
if host.state == HpcState.Provisioning:
89+
cores += host.cpus
90+
return cores
91+
92+
93+
SlaveInfo = namedtuple(
94+
"SlaveInfo", "hostname agent_id task_id cpus last_heartbeat state")
95+
96+
97+
class HpcState:
98+
Provisioning, Running, Closed = range(3)

hpcframework.py

+54-48
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
1+
import base64
2+
import codecs
13
import json
2-
import datetime
3-
import time
4-
import os
5-
import sys
6-
import threading
74
import logging
5+
import os
86
import signal
97
import sys
8+
import threading
9+
import time
1010
import uuid
11-
import base64
12-
import codecs
13-
import restserver
14-
import restclient
15-
import logging_aux
11+
from collections import namedtuple
12+
from datetime import datetime
1613

1714
from mesoshttp.client import MesosClient
15+
from mesoshttp.offers import Offer
16+
17+
import logging_aux
18+
import restclient
19+
import restserver
1820
from restclient import AutoScaleRestClient
21+
import heartbeat_table
1922

20-
from mesoshttp.offers import Offer
2123

2224
class Test(object):
2325
class MesosFramework(threading.Thread):
@@ -39,15 +41,17 @@ def __init__(self):
3941
# signal.signal(signal.SIGINT, signal.SIG_IGN)
4042
logging.getLogger('mesoshttp').setLevel(logging.DEBUG)
4143

42-
self.hpc_client = AutoScaleRestClient()
43-
self.core_provisioning = 0.0
44+
self.heartbeat_table = heartbeat_table.HeartBeatTable()
45+
46+
self.hpc_client = AutoScaleRestClient()
47+
self.core_provisioning = 0.0
4448
with open("setupscript.ps1") as scriptfile:
4549
hpc_setup_ps1 = scriptfile.read()
46-
self.logger.info("Loaded HPC setup script:/n{}".format(hpc_setup_ps1))
50+
self.logger.info("Loaded HPC setup script:\n{}".format(hpc_setup_ps1))
4751
hpc_setup_ps1_utf16 = hpc_setup_ps1.encode('utf-16')
48-
hpc_setup_ps1_utf16_nobom = hpc_setup_ps1_utf16[2:] if hpc_setup_ps1_utf16[0:2] == codecs.BOM_UTF16 else hpc_setup_ps1_utf16
52+
hpc_setup_ps1_utf16_nobom = hpc_setup_ps1_utf16[2:] if hpc_setup_ps1_utf16[
53+
0:2] == codecs.BOM_UTF16 else hpc_setup_ps1_utf16
4954
self.hpc_setup_ps1_b64 = base64.b64encode(hpc_setup_ps1_utf16_nobom)
50-
5155

5256
self.driver = None # type: MesosClient.SchedulerDriver
5357
self.mesos_client = MesosClient(mesos_urls=['http://172.16.1.4:5050'])
@@ -57,6 +61,10 @@ def __init__(self):
5761
self.mesos_client.on(MesosClient.UPDATE, self.status_update)
5862
self.th = Test.MesosFramework(self.mesos_client)
5963
self.th.start()
64+
65+
self.heartbeat_server = restserver.RestServer(self.heartbeat_table ,8088)
66+
self.heartbeat_server.start()
67+
6068
while True and self.th.isAlive():
6169
try:
6270
self.th.join(1)
@@ -65,11 +73,12 @@ def __init__(self):
6573
break
6674

6775
def shutdown(self):
68-
print('Stop requested by user, stopping framework....')
76+
print 'Stop requested by user, stopping framework....'
6977
self.logger.warn('Stop requested by user, stopping framework....')
7078
self.driver.tearDown()
7179
self.mesos_client.stop = True
7280
self.stop = True
81+
self.heartbeat_server.stop()
7382

7483
def subscribed(self, driver):
7584
self.logger.warn('SUBSCRIBED')
@@ -78,24 +87,25 @@ def subscribed(self, driver):
7887
def status_update(self, update):
7988
# if update['status']['state'] == 'TASK_RUNNING':
8089
# self.driver.kill(update['status']['agent_id']['value'], update['status']['task_id']['value'])
81-
print(str(update))
90+
self.logger.log("Update received:\n{}".format(str(update)))
8291

8392
def offer_received(self, offers):
8493
# self.logger.info('OFFER: %s' % (str(offers)))
8594
grow_decision = self.hpc_client.get_grow_decision()
86-
87-
if(grow_decision.cores_to_grow - self.core_provisioning > 0):
88-
for offer in offers: # type: Offer
89-
self.logger.info("offer_received: {}".format(
90-
(str(offer.get_offer()))))
95+
96+
if grow_decision.cores_to_grow - self.core_provisioning > 0:
97+
for offer in offers: # type: Offer
9198
mesos_offer = offer.get_offer()
99+
self.logger.info("offer_received: {}".format(
100+
(str(mesos_offer))))
92101
if 'attributes' in mesos_offer:
93102
attributes = mesos_offer['attributes']
94103
if self.get_text(attributes, 'os') != 'windows_server':
95104
offer.decline()
96105
else:
97106
cores = self.get_scalar(attributes, 'cores')
98-
cpus = self.get_scalar(mesos_offer['resources'], 'cpus')
107+
cpus = self.get_scalar(
108+
mesos_offer['resources'], 'cpus')
99109

100110
if cores == cpus:
101111
self.accept_offer(offer)
@@ -105,44 +115,44 @@ def offer_received(self, offers):
105115
offer.decline()
106116
else:
107117
for offer in offers:
108-
offer.decline()
118+
offer.decline()
109119

110120
def accept_offer(self, offer):
111-
self.logger.info("Offer %s meets hpc's requiremnt" %
121+
self.logger.info("Offer %s meets HPC's requirement" %
112122
offer.get_offer()['id']['value'])
113123
self.run_job(offer)
114-
115-
# i = 0
116-
# for offer in offers:
117-
# if i == 0:
118-
# self.run_job(offer)
119-
# else:
120-
# offer.decline()
121-
# i+=1
122-
def get_scalar(self, dict, name):
123-
for i in dict:
124+
125+
def get_scalar(self, collection, name):
126+
for i in collection:
124127
if i['name'] == name:
125128
return i['scalar']['value']
126129
return 0.0
127130

128-
def get_text(self, dict, name):
129-
for i in dict:
131+
def get_text(self, collection, name):
132+
for i in collection:
130133
if i['name'] == name:
131134
return i['text']['value']
132135
return ""
133136

134137
def run_job(self, mesos_offer):
135138
offer = mesos_offer.get_offer()
136139
self.logger.info("Accepting offer: {}".format(str(offer)))
140+
141+
agent_id = offer['agent_id']['value']
142+
hostname = offer['hostname']
143+
task_id = uuid.uuid4().hex
144+
cpus = self.get_scalar(offer['resources'], 'cpus')
145+
137146
task = {
138147
'name': 'sample test',
139-
'task_id': {'value': uuid.uuid4().hex},
140-
'agent_id': {'value': offer['agent_id']['value']},
148+
'task_id': {'value': task_id},
149+
'agent_id': {'value': agent_id},
141150
'resources': [
142151
{
143152
'name': 'cpus',
144153
'type': 'SCALAR',
145-
'scalar': {'value': self.get_scalar(offer['resources'], 'cpus') - 0.1}
154+
# work around of MESOS-8631
155+
'scalar': {'value': cpus - 0.1}
146156
},
147157
{
148158
'name': 'mem',
@@ -152,14 +162,10 @@ def run_job(self, mesos_offer):
152162
],
153163
'command': {'value': 'powershell -EncodedCommand ' + self.hpc_setup_ps1_b64}
154164
}
155-
self.logger.debug("Sending command:/n{}".format(task['command']['value']))
165+
self.logger.debug(
166+
"Sending command:\n{}".format(task['command']['value']))
156167
mesos_offer.accept([task])
157-
168+
self.heartbeat_table.add_slaveinfo(hostname, agent_id, task, cpus)
158169

159170
if __name__ == "__main__":
160-
rest_server = restserver.RestServer(8088)
161-
server_thread = threading.Thread(target=rest_server.run)
162-
server_thread.start()
163171
test_mesos = Test()
164-
rest_server.stop()
165-
server_thread.join()

restserver.py

+49-26
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,64 @@
11
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
22
import json
3+
import threading
4+
from heartbeat_table import HeartBeatTable
5+
import logging_aux
6+
import logging
37

4-
class RestServer(object):
5-
class S(BaseHTTPRequestHandler):
6-
def _set_headers(self):
7-
self.send_response(200)
8-
self.send_header('Content-type', 'text/html')
9-
self.end_headers()
10-
11-
def do_GET(self):
12-
self._set_headers()
13-
self.wfile.write("<html><body><h1>hi!</h1></body></html>")
14-
15-
def do_HEAD(self):
16-
self._set_headers()
17-
18-
def do_POST(self):
19-
# Doesn't do anything with posted data
20-
content_length = int(self.headers['Content-Length']) # <--- Gets the size of data
21-
post_data = self.rfile.read(content_length) # <--- Gets the data itself
22-
self._set_headers()
23-
json_obj = json.loads(post_data)
24-
self.wfile.write("<html><body><h1>POST!</h1><pre>" + str(json_obj) + "</pre></body></html>")
25-
26-
def __init__(self, port = 80):
8+
class RestServer(object): # TODO: replace this implementation with twisted based implementation
9+
def __init__(self, heartbeat_table, port = 80):
10+
self.logger = logging_aux.init_logger_aux(
11+
"hpcframework.heatbeat_server", "hpcframework.heatbeat_server.log")
12+
self._heartbeat_table = heartbeat_table # type: HeartBeatTable
2713
self._server_address = ('', port)
28-
self._server_class = HTTPServer
29-
self._handler_class = self.S
14+
self._server_class = HTTPServer
15+
self._handler_class = HeartBeatHandler
3016
self._port = port
3117
self._httpd = self._server_class(self._server_address, self._handler_class)
18+
self._server_thread = threading.Thread(target=self.run)
19+
HeartBeatHandler.logger = self.logger
20+
HeartBeatHandler.heartbeat_table = self._heartbeat_table
3221

3322
def run(self):
34-
print 'Starting httpd...'
23+
self.logger.debug('Starting httpd...')
3524
self._httpd.serve_forever()
3625

3726
def stop(self):
3827
self._httpd.shutdown()
28+
self._server_thread.join()
29+
30+
def start(self):
31+
self._server_thread.start()
32+
33+
class HeartBeatHandler(BaseHTTPRequestHandler):
34+
logger = None # type: logging.Logger
35+
heartbeat_table = None # type: HeartBeatTable
36+
37+
def _set_headers(self):
38+
self.send_response(200)
39+
self.send_header('Content-type', 'text/html')
40+
self.end_headers()
41+
42+
def do_GET(self):
43+
self._set_headers()
44+
self.wfile.write("<html><body><h1>hi from thread {}!</h1></body></html>".format(threading._get_ident()))
45+
46+
def do_HEAD(self):
47+
self._set_headers()
48+
49+
def do_POST(self):
50+
# Doesn't do anything with posted data
51+
content_length = int(self.headers['Content-Length']) # <--- Gets the size of data
52+
post_data = self.rfile.read(content_length) # <--- Gets the data itself
53+
self._set_headers()
54+
json_obj = json.loads(post_data)
55+
# self.wfile.write("<html><body><h1>POST!</h1><pre>" + str(json_obj) + "</pre></body></html>")
56+
self.logger.debug("Received heartbeat object {}".format(str(json_obj)))
57+
try:
58+
self.heartbeat_table.on_slave_heartbeat(json_obj['hostname'])
59+
except Exception as ex:
60+
self.logger.exception(ex)
61+
3962

4063
# if __name__ == "__main__":
4164
# from sys import argv

setupscript.ps1

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ while (!$broughtOnline -and ($retryCount -lt 120)) {
1515
catch {
1616
Write-host "Wait for 5 secs and then retry"
1717
++$retryCount
18-
sleep 5
18+
Start-Sleep 5
1919
}
2020
}
2121

0 commit comments

Comments
 (0)