-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathresponder.py
304 lines (272 loc) · 12 KB
/
responder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
import asyncio
import logging
import time
import collections
import sys
import os
import socket
from functools import partial
from .resolver import STSResolver, STSFetchResult
from .constants import QUEUE_LIMIT, CHUNK, REQUEST_LIMIT
from .utils import create_custom_socket, filter_domain, is_ipaddr
from .base_cache import CacheEntry
from . import netstring
# Text encoding used to decode raw socketmap request payloads.
REQUEST_ENCODING = 'utf-8'
# Per-socketmap-name settings:
#   strict      - treat 'testing' mode policies as enforcing
#   resolver    - STSResolver instance used for lookups in this zone
#   require_sni - append "servername=hostname" to positive answers
#   tlsrpt      - append policy_type/policy_domain/mx_host_pattern/
#                 policy_string details to positive answers
ZoneEntry = collections.namedtuple(
    'ZoneEntry',
    'strict resolver require_sni tlsrpt',
)
# pylint: disable=too-many-instance-attributes
class STSSocketmapResponder:
    """Socketmap server answering STS policy lookups.

    Listens on a UNIX or TCP socket, reads netstring-framed requests of the
    form ``"<zone> <domain>"`` and answers each with a netstring-framed
    ``NOTFOUND `` or ``OK secure match=...`` reply, built from a cached
    policy or one freshly fetched by the zone's resolver.
    """

    def __init__(self, cfg, loop, cache):
        """Build the responder from configuration.

        :param cfg: configuration mapping. If ``path`` is set, a UNIX socket
            is used (optional ``mode`` is applied with ``os.chmod``);
            otherwise ``host``, ``port`` and ``reuse_port`` configure a TCP
            listener. ``shutdown_timeout``, ``cache_grace``,
            ``default_zone`` and ``zones`` are always read.
        :param loop: event loop used for resolvers and handler tasks.
        :param cache: policy cache providing ``get(domain)`` and
            ``safe_set(domain, entry, logger)`` coroutines.
        """
        self._logger = logging.getLogger("STS")
        self._loop = loop
        if cfg.get('path') is not None:
            self._unix = True
            self._path = cfg['path']
            self._sockmode = cfg.get('mode')
        else:
            self._unix = False
            self._host = cfg['host']
            self._port = cfg['port']
            self._reuse_port = cfg['reuse_port']
        self._shutdown_timeout = cfg['shutdown_timeout']
        self._grace = cfg['cache_grace']

        def make_zone(zone):
            # Each zone gets its own resolver with its own timeout.
            return ZoneEntry(zone["strict_testing"],
                             STSResolver(loop=loop, timeout=zone["timeout"]),
                             zone["require_sni"],
                             zone["tlsrpt"])

        # Construct configurations and resolvers for every socketmap name
        self._default_zone = make_zone(cfg["default_zone"])
        self._zones = {k: make_zone(zone) for k, zone in cfg["zones"].items()}
        self._cache = cache
        self._children = set()
        self._server = None

    def is_stale(self, cached):
        """Return True if ``cached`` must be refreshed.

        A record is stale when it is missing, when it is older than the
        cache grace period, or when its policy ``max_age`` has elapsed.
        """
        ts = time.time()  # pylint: disable=invalid-name
        # Nonexistent ?
        if cached is None:
            return True
        # Expired grace period ?
        if ts - cached.ts > self._grace:
            return True
        # Expired policy ?
        if cached.pol_body['max_age'] + cached.ts < ts:
            return True
        return False

    async def start(self):
        """Start listening for client connections.

        Each accepted connection is served by a :meth:`handler` task. The task
        is tracked in ``self._children`` until it finishes, so :meth:`stop`
        can wait for it.
        """
        def _spawn(reader, writer):
            def done_cb(task, fut):
                self._children.discard(task)
            task = self._loop.create_task(self.handler(reader, writer))
            task.add_done_callback(partial(done_cb, task))
            self._children.add(task)
            self._logger.debug("len(self._children) = %d", len(self._children))

        if self._unix:
            self._server = await asyncio.start_unix_server(_spawn, path=self._path)
            if self._sockmode is not None:
                os.chmod(self._path, self._sockmode)
        else:
            # Plain TCP listener by default. The reuse options below are
            # layered on top. Previously ``opts`` was only assigned inside the
            # reuse_port branches, so reuse_port=False (or an unrecognised
            # platform) crashed with UnboundLocalError.
            opts = {
                'host': self._host,
                'port': self._port,
            }
            if self._reuse_port:  # pragma: no cover
                if sys.platform in ('win32', 'cygwin'):
                    opts['reuse_address'] = True
                elif os.name == 'posix':
                    if sys.platform.startswith('freebsd'):
                        # FreeBSD: bind a pre-configured socket with
                        # SO_REUSEADDR and SO_REUSEPORT_LB (raw value 0x10000,
                        # presumably because the socket module may not expose
                        # the constant).
                        sockopts = [
                            (socket.SOL_SOCKET, socket.SO_REUSEADDR, 1),
                            (socket.SOL_SOCKET, 0x10000, 1),  # SO_REUSEPORT_LB
                        ]
                        sock = await create_custom_socket(self._host, self._port,
                                                          options=sockopts)
                        opts = {
                            'sock': sock,
                        }
                    else:
                        opts['reuse_address'] = True
                        opts['reuse_port'] = True
            self._server = await asyncio.start_server(_spawn, **opts)

    async def stop(self):
        """Stop accepting connections and wait for client handlers to finish.

        Each round waits up to ``shutdown_timeout`` seconds for the handler
        tasks tracked at that moment. On timeout they are cancelled (via the
        cancelled ``gather``). Rounds repeat until no handler tasks remain.
        """
        self._server.close()
        await self._server.wait_closed()
        while True:
            self._logger.warning("Awaiting %d client handlers to finish...",
                                 len(self._children))
            remaining = asyncio.gather(*self._children, return_exceptions=True)
            self._children.clear()
            try:
                await asyncio.wait_for(remaining, self._shutdown_timeout)
            except asyncio.TimeoutError:
                self._logger.warning("Shutdown timeout expired. "
                                     "Remaining handlers terminated.")
                try:
                    await remaining
                except asyncio.CancelledError:
                    pass
            await asyncio.sleep(1)
            if not self._children:
                break

    async def sender(self, queue, writer):
        """Write responses to ``writer`` in the order they were queued.

        :param queue: queue of futures resolving to response bytes. A
            ``None`` item requests shutdown.
        :param writer: stream writer of the client connection.

        On cancellation or an unexpected error, futures still waiting in the
        queue are cancelled. The writer is always closed on exit.
        """
        def cleanup_queue():
            while not queue.empty():
                task = queue.get_nowait()
                try:
                    # Items may include the None sentinel; its AttributeError
                    # is deliberately ignored here.
                    task.cancel()
                except Exception:  # pragma: no cover
                    pass

        try:
            while True:
                fut = await queue.get()
                # Check for shutdown
                if fut is None:
                    return
                self._logger.debug("Got new future from queue")
                data = await fut
                self._logger.debug("Future await complete: data=%s", repr(data))
                writer.write(data)
                self._logger.debug("Wrote: %s", repr(data))
                await writer.drain()
        except asyncio.CancelledError:
            cleanup_queue()
        except Exception as exc:  # pragma: no cover
            self._logger.exception("Exception in sender coro: %s", exc)
            cleanup_queue()
        finally:
            writer.close()

    # pylint: disable=too-many-locals,too-many-branches,too-many-statements
    async def process_request(self, raw_req):
        """Answer one socketmap request.

        :param raw_req: request payload ``b"<zone> <domain>"``, decoded with
            ``REQUEST_ENCODING``. An unknown zone falls back to the default
            zone configuration.
        :returns: netstring-encoded response bytes, either ``NOTFOUND `` or
            an ``OK secure match=...`` line built from the policy.
        """
        have_policy = True

        # Parse request and canonicalize domain
        req_zone, _, req_domain = raw_req.decode(REQUEST_ENCODING).partition(' ')
        domain = filter_domain(req_domain)

        # Skip lookups for parent domain policies
        # Skip lookups to non-domains
        if domain.startswith('.') or is_ipaddr(domain):
            return netstring.encode(b'NOTFOUND ')

        # Find appropriate zone config
        zone_cfg = self._zones.get(req_zone, self._default_zone)

        # Lookup for cached policy
        try:
            cached = await self._cache.get(domain)
        except asyncio.CancelledError:  # pragma: no cover pylint: disable=try-except-raise
            raise
        except Exception as exc:  # pragma: no cover
            # A failing cache is treated like a cache miss.
            self._logger.exception("Cache get failed: %s", str(exc))
            cached = None

        # DNS lookup and cache update
        if self.is_stale(cached):
            ts = time.time()  # pylint: disable=invalid-name
            self._logger.debug("Lookup PERFORMED: domain = %s", domain)
            # Check if newer policy exists or
            # retrieve policy from scratch if there is no cached one
            latest_pol_id = None if cached is None else cached.pol_id
            status, policy = await zone_cfg.resolver.resolve(domain, latest_pol_id)
            if status is STSFetchResult.NOT_CHANGED:
                # Resolver reports no change relative to latest_pol_id:
                # keep the cached policy and refresh its timestamp.
                cached = CacheEntry(ts, cached.pol_id, cached.pol_body)
                await self._cache.safe_set(domain, cached, self._logger)
            elif status is STSFetchResult.VALID:
                pol_id, pol_body = policy
                cached = CacheEntry(ts, pol_id, pol_body)
                await self._cache.safe_set(domain, cached, self._logger)
            else:
                # Any other result: keep using the cached policy only while
                # its max_age has not elapsed.
                if cached is None:
                    have_policy = False
                else:
                    # Check if cached policy is expired
                    if cached.pol_body['max_age'] + cached.ts < ts:
                        have_policy = False
        else:
            self._logger.debug("Lookup skipped: domain = %s", domain)

        if have_policy:
            mode = cached.pol_body['mode']
            # pylint: disable=no-else-return
            if mode == 'none' or (mode == 'testing' and not zone_cfg.strict):
                return netstring.encode(b'NOTFOUND ')
            else:
                assert cached.pol_body['mx'], "Empty MX list for restrictive policy!"
                # Wildcard MX patterns ("*.example.com") are sent with the
                # leading '*' stripped in the match list.
                mxlist = [mx.lstrip('*') for mx in set(cached.pol_body['mx'])]
                resp = "OK secure match=" + ":".join(mxlist)
                if zone_cfg.require_sni:
                    resp += " servername=hostname"
                if zone_cfg.tlsrpt:
                    resp += " policy_type=sts policy_domain=" + domain
                    resp += " " + " ".join("mx_host_pattern=" + mx for mx in cached.pol_body['mx'])
                    resp += " " + " ".join(
                        "{ policy_string = %s: %s }" % (k, v) if k != "mx" else
                        " ".join("{ policy_string = mx: %s }" % (mx,) for mx in v)
                        for k, v in cached.pol_body.items())
                return netstring.encode(resp.encode('utf-8'))
        else:
            return netstring.encode(b'NOTFOUND ')

    async def handler(self, reader, writer):
        """Serve one client connection.

        Reads netstring requests from ``reader``. Each request becomes a
        :meth:`process_request` task, queued so that :meth:`sender` writes
        replies in request order even though requests are processed
        concurrently. Disconnects and parse errors end the session after
        already-queued replies are flushed.
        """
        # Construct netstring parser
        stream_reader = netstring.StreamReader(REQUEST_LIMIT)

        # Construct queue for responses ordering
        queue = asyncio.Queue(QUEUE_LIMIT)

        # Create coroutine which awaits for steady responses and sends them
        sender = asyncio.ensure_future(self.sender(queue, writer), loop=self._loop)

        class EndOfStream(Exception):
            pass

        async def finalize():
            # Ask the sender to stop after draining queued replies, then wait.
            try:
                await queue.put(None)
            except asyncio.CancelledError:  # pragma: no cover
                sender.cancel()
                raise
            await sender

        try:
            while True:
                # Extract and parse request
                string_reader = stream_reader.next_string()
                request_parts = []
                while True:
                    try:
                        buf = string_reader.read()
                    except netstring.WantRead:
                        part = await reader.read(CHUNK)
                        if not part:
                            # pylint: disable=raise-missing-from
                            raise EndOfStream()
                        self._logger.debug("Read: %s", repr(part))
                        stream_reader.feed(part)
                    else:
                        if buf:
                            request_parts.append(buf)
                        else:
                            # An empty read marks the end of the current
                            # netstring: the request is complete.
                            req = b''.join(request_parts)
                            self._logger.debug("Enq request: %s", repr(req))
                            fut = asyncio.ensure_future(self.process_request(req), loop=self._loop)
                            await queue.put(fut)
                            break
        except netstring.ParseError:
            self._logger.warning("Bad netstring message received")
            await finalize()
        except (EndOfStream, ConnectionError, TimeoutError):
            self._logger.debug("Client disconnected")
            await finalize()
        except OSError as exc:  # pragma: no cover
            if exc.errno == 107:
                self._logger.debug("Client disconnected")
                await finalize()
            else:
                self._logger.exception("Unhandled exception: %s", exc)
                await finalize()
        except asyncio.CancelledError:
            sender.cancel()
            raise
        except Exception as exc:  # pragma: no cover
            self._logger.exception("Unhandled exception: %s", exc)
            await finalize()