Skip to content

Commit b09bc0d

Browse files
committed
Safe Replace DNSStore content and DNSZone content
1 parent e990289 commit b09bc0d

File tree

3 files changed

+160
-10
lines changed

3 files changed

+160
-10
lines changed

aiomisc/service/dns/store.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Sequence, Tuple
1+
from typing import Iterable, Mapping, Optional, Sequence, Tuple
22

33
from .records import DNSRecord, RecordType
44
from .tree import RadixTree
@@ -49,3 +49,22 @@ def get_zone_for_name(self, name: str) -> Optional[Tuple[str, ...]]:
4949
@staticmethod
5050
def get_reverse_tuple(zone_name: str) -> Tuple[str, ...]:
5151
return tuple(zone_name.strip(".").split("."))[::-1]
52+
53+
def replace(
54+
self, zones_data: Mapping[str, Iterable[DNSRecord]],
55+
) -> None:
56+
"""
57+
Atomically replace all zones with new ones this method is safe
58+
because it replaces all zones at once. zone_data is a mapping
59+
zone name and a sequence of DNSRecord objects.
60+
61+
If any of the zones or records is invalid, nothing will be replaced.
62+
63+
This method is useful for reload configuration from disk
64+
or database or etc.
65+
"""
66+
new_zones: RadixTree[DNSZone] = RadixTree()
67+
for zone_name, records in zones_data.items():
68+
zone = DNSZone(zone_name, *records)
69+
new_zones.insert(self.get_reverse_tuple(zone.name), zone)
70+
self.zones = new_zones

aiomisc/service/dns/zone.py

+32-9
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,33 @@
11
from collections import defaultdict
2-
from typing import DefaultDict, Sequence, Set, Tuple
2+
from typing import DefaultDict, Iterable, Sequence, Set, Tuple
33

44
from .records import DNSRecord, RecordType
55

66

7+
RecordsType = DefaultDict[Tuple[str, RecordType], Set[DNSRecord]]
8+
9+
710
class DNSZone:
8-
records: DefaultDict[Tuple[str, RecordType], Set[DNSRecord]]
11+
records: RecordsType
912
name: str
1013

1114
__slots__ = ("name", "records")
1215

13-
def __init__(self, name: str):
16+
def __init__(self, name: str, *records: DNSRecord) -> None:
1417
if not name.endswith("."):
1518
name += "."
1619
self.name = name
1720
self.records = defaultdict(set)
1821

22+
for record in records:
23+
self.add_record(record)
24+
1925
def add_record(self, record: DNSRecord) -> None:
20-
if not self._is_valid_record(record):
26+
if not self.check_record(record):
2127
raise ValueError(
2228
f"Record {record.name} does not belong to zone {self.name}",
2329
)
24-
key = (record.name, record.type)
25-
self.records[key].add(record)
30+
self.records[(record.name, record.type)].add(record)
2631

2732
def remove_record(self, record: DNSRecord) -> None:
2833
key = (record.name, record.type)
@@ -37,8 +42,26 @@ def get_records(
3742
) -> Sequence[DNSRecord]:
3843
if not name.endswith("."):
3944
name += "."
40-
key = (name, record_type)
41-
return tuple(self.records.get(key, ()))
45+
return tuple(self.records.get((name, record_type), ()))
4246

43-
def _is_valid_record(self, record: DNSRecord) -> bool:
47+
def check_record(self, record: DNSRecord) -> bool:
4448
return record.name.endswith(self.name)
49+
50+
def replace(self, records: Iterable[DNSRecord]) -> None:
51+
"""
52+
Atomically replace all records in specified zone with new ones.
53+
This method is safe because it replaces all records at once.
54+
55+
If any of the records does not belong to the zone, ValueError
56+
will be raised and no records will be replaced.
57+
"""
58+
new_records: RecordsType = defaultdict(set)
59+
60+
for record in records:
61+
if not self.check_record(record):
62+
raise ValueError(
63+
f"Record {record.name} does not "
64+
f"belong to zone {self.name}",
65+
)
66+
new_records[(record.name, record.type)].add(record)
67+
self.records = new_records

tests/test_dns.py

+108
Original file line numberDiff line numberDiff line change
@@ -516,3 +516,111 @@ def test_sshfp_create():
516516
assert record.data.fp_type == 1
517517
assert record.data.fingerprint == b"abcdefg"
518518
assert record.ttl == 300
519+
520+
521+
def test_zone_replace(dns_store):
522+
zone = DNSZone("example.com.")
523+
record1 = A.create(name="www.example.com.", ip="192.0.2.1")
524+
record2 = A.create(name="api.example.com.", ip="192.0.2.2")
525+
zone.add_record(record1)
526+
dns_store.add_zone(zone)
527+
528+
zone.replace([record2])
529+
530+
records = dns_store.query("www.example.com.", RecordType.A)
531+
assert len(records) == 0
532+
records = dns_store.query("api.example.com.", RecordType.A)
533+
assert len(records) == 1
534+
assert record2 in records
535+
536+
537+
def test_zone_replace_multiple_records():
538+
zone = DNSZone("example.com.")
539+
record1 = A.create(name="www.example.com.", ip="192.0.2.1")
540+
record2 = A.create(name="www.example.com.", ip="192.0.2.2")
541+
542+
zone.replace([record1, record2])
543+
records = zone.get_records("www.example.com.", RecordType.A)
544+
assert len(records) == 2
545+
assert record1 in records
546+
assert record2 in records
547+
548+
549+
def test_zone_replace_empty():
550+
zone = DNSZone("example.com.")
551+
record = A.create(name="www.example.com.", ip="192.0.2.1")
552+
zone.add_record(record)
553+
554+
zone.replace([])
555+
records = zone.get_records("www.example.com.", RecordType.A)
556+
assert len(records) == 0
557+
558+
559+
def test_zone_replace_invalid_record():
560+
zone = DNSZone("example.com.")
561+
record = A.create(name="www.other.com.", ip="192.0.2.1")
562+
563+
with pytest.raises(ValueError, match="does not belong to zone"):
564+
zone.replace([record])
565+
566+
567+
def test_store_replace_basic(dns_store):
568+
zone1 = DNSZone("example.com.")
569+
record1 = A.create(name="www.example.com.", ip="192.0.2.1")
570+
zone1.add_record(record1)
571+
dns_store.add_zone(zone1)
572+
573+
zone2 = DNSZone("test.com.")
574+
record2 = A.create(name="www.test.com.", ip="192.0.2.2")
575+
zone2.add_record(record2)
576+
dns_store.add_zone(zone2)
577+
578+
# Replace with new data
579+
new_record1 = A.create(name="api.example.com.", ip="192.0.2.3")
580+
new_record2 = A.create(name="api.test.com.", ip="192.0.2.4")
581+
582+
dns_store.replace({
583+
"example.com.": [new_record1],
584+
"test.com.": [new_record2],
585+
})
586+
587+
# Check old records are gone
588+
records = dns_store.query("www.example.com.", RecordType.A)
589+
assert len(records) == 0
590+
records = dns_store.query("www.test.com.", RecordType.A)
591+
assert len(records) == 0
592+
593+
# Check new records are present
594+
records = dns_store.query("api.example.com.", RecordType.A)
595+
assert len(records) == 1
596+
assert new_record1 in records
597+
records = dns_store.query("api.test.com.", RecordType.A)
598+
assert len(records) == 1
599+
assert new_record2 in records
600+
601+
602+
def test_store_replace_empty(dns_store):
603+
zone = DNSZone("example.com.")
604+
record = A.create(name="www.example.com.", ip="192.0.2.1")
605+
zone.add_record(record)
606+
dns_store.add_zone(zone)
607+
608+
dns_store.replace({})
609+
610+
assert dns_store.get_zone("example.com.") is None
611+
records = dns_store.query("www.example.com.", RecordType.A)
612+
assert len(records) == 0
613+
614+
615+
def test_store_replace_multiple_records_per_zone(dns_store):
616+
new_record1 = A.create(name="www.example.com.", ip="192.0.2.1")
617+
new_record2 = A.create(name="www.example.com.", ip="192.0.2.2")
618+
619+
dns_store.replace({
620+
"example.com.": [new_record1, new_record2],
621+
})
622+
623+
records = dns_store.query("www.example.com.", RecordType.A)
624+
assert len(records) == 2
625+
assert new_record1 in records
626+
assert new_record2 in records

0 commit comments

Comments
 (0)