Skip to content

Commit 30c0a61

Browse files
author
Ubuntu
committed
remote attestation for openfl participants
Signed-off-by: Ubuntu <azureuser@ofl-dev-vm-ad-anshumi1.qnxiewjiflyubbpcwut13wv1wh.cx.internal.cloudapp.net>
1 parent 57b21ac commit 30c0a61

File tree

13 files changed

+791
-17
lines changed

13 files changed

+791
-17
lines changed

openfl-docker/gramine_app/fx.manifest.template

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,5 @@ sgx.allowed_files = [
7171
"file:{{ workspace_root }}/plan/cols.yaml",
7272
"file:{{ workspace_root }}/plan/data.yaml",
7373
"file:{{ workspace_root }}/plan/plan.yaml",
74+
"file:{{ workspace_root }}/attestation",
7475
]

openfl-workspace/workspace/plan/defaults/aggregator.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ settings :
66
last_state_path : save/last.pbuf
77
persist_checkpoint: True
88
persistent_db_path: local_state/tensor.db
9+
enable_remote_attestation : False

openfl-workspace/workspace/plan/defaults/collaborator.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ settings :
33
opt_treatment : 'CONTINUE_LOCAL'
44
use_delta_updates : True
55
db_store_rounds : 1
6+
enable_remote_attestation : False

openfl/component/aggregator/aggregator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def __init__(
9191
persist_checkpoint=True,
9292
persistent_db_path=None,
9393
secure_aggregation=False,
94+
enable_remote_attestation=False,
9495
):
9596
"""Initializes the Aggregator.
9697
@@ -146,6 +147,7 @@ def __init__(
146147
self.uuid = aggregator_uuid
147148
self.federation_uuid = federation_uuid
148149
self.connector = connector
150+
self.enable_remote_attestation = enable_remote_attestation
149151

150152
self.quit_job_sent_to = []
151153

openfl/cryptography/signer.py

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
import base64
2+
import os
3+
from datetime import datetime, timedelta
4+
5+
from cryptography import x509
6+
from cryptography.hazmat.backends import default_backend
7+
from cryptography.hazmat.primitives import hashes, serialization
8+
from cryptography.hazmat.primitives.asymmetric import ec
9+
from cryptography.hazmat.primitives.serialization import load_pem_private_key
10+
from cryptography.x509.oid import NameOID
11+
12+
13+
class ECDSASigner:
14+
"""ECDSA secp384R1 signer
15+
16+
This class implements ECDSA methods specific to a single instance of
17+
enclave, so, that it can be reused across without passing references
18+
across the breadth of the code. The relevant keys and certificates are
19+
stored in /tmp. /tmp should be appropriately configured in the enclave
20+
manifest, so, that it is wiped off when the enclave exits.
21+
22+
Raises:
23+
Exception: ValueError() for incorrect configuration
24+
25+
Returns:
26+
object: The reference of the only object
27+
"""
28+
29+
__signer_instance = None
30+
31+
@staticmethod
32+
def get_instance(privkey_path="/tmp/client_privkey.pem"):
33+
if ECDSASigner.__signer_instance is None:
34+
ECDSASigner(privkey_path)
35+
return ECDSASigner.__signer_instance
36+
37+
def __init__(self, privkey_path="/tmp/client_privkey.pem", cert_path="/tmp/"):
38+
"""Constructor for creating the ECDSA Certificate Chain
39+
40+
Args:
41+
privkey_path (string, optional): Path to the existing client private key.
42+
Defaults to None.
43+
44+
Raises:
45+
Exception: Generic, in case there's invalid configuration or file data
46+
"""
47+
if ECDSASigner.__signer_instance is not None:
48+
raise Exception("ECDSASigner: Only one instance allowed")
49+
else:
50+
ECDSASigner.__signer_instance = self
51+
52+
self._root_cert_path = os.path.join(cert_path, "openfl-security-ca-cert.pem")
53+
self._client_cert_path = os.path.join(cert_path, "openfl-enclave-cert.pem")
54+
55+
# If the private key already exists, then reuse to create the certificate
56+
# if doesn't exist.
57+
self._client_cert = None
58+
self._root_cert = None
59+
60+
if privkey_path and os.path.exists(privkey_path):
61+
with open(privkey_path, "rb") as client_priv_fh:
62+
client_privkey_pem = client_priv_fh.read()
63+
64+
# Once the private key is found in filesystem, then look for saved
65+
# certificate and the corresponding public key
66+
self._client_privkey = load_pem_private_key(client_privkey_pem, password=None)
67+
if not isinstance(self._client_privkey, ec.EllipticCurvePrivateKey):
68+
raise ValueError(f"Invalid private key format: '{privkey_path}'")
69+
70+
self._client_pubkey = self._client_privkey.public_key()
71+
72+
# Check if certificate is already present in the filesystem, if not then
73+
# serialize is not called yet
74+
if os.path.exists(self._client_cert_path):
75+
with open(self._client_cert_path, "rb") as cert_fh:
76+
self._client_cert = x509.load_pem_x509_certificate(
77+
cert_fh.read(), default_backend()
78+
)
79+
80+
# FIXME: Post upgrading cryptography module, delete this below
81+
# and change above call to load_pem_x509_certificate(s) to load the chain
82+
if not os.path.exists(self._root_cert_path):
83+
raise ValueError(
84+
"Out of tree modification detected, "
85+
"clean all the keys and certs and try again"
86+
)
87+
with open(self._root_cert_path, "rb") as cert_fh:
88+
self._root_cert = x509.load_pem_x509_certificate(
89+
cert_fh.read(), default_backend()
90+
)
91+
92+
else:
93+
self._root_privkey = ec.generate_private_key(ec.SECP384R1(), default_backend())
94+
self._root_pubkey = self._root_privkey.public_key()
95+
96+
self._client_privkey = ec.generate_private_key(ec.SECP384R1(), default_backend())
97+
98+
self._client_pubkey = self._client_privkey.public_key()
99+
100+
def __get_cert(
101+
self,
102+
subject_name,
103+
subject_pubkey,
104+
issuer_name,
105+
issuer_privkey,
106+
ca=False,
107+
mrenclave_data=None,
108+
):
109+
"""Create a certificate with optional MRENCLAVE_OID extension.
110+
111+
Args:
112+
subject_name (string): The subject name in the certificate, must match the URL.
113+
subject_pubkey (string): Subject's public key to be embedded in the certificate.
114+
issuer_name (string): The CA name.
115+
issuer_privkey (string): The CA private key to sign the subject certificate.
116+
ca (bool, optional): To set CA=true property. Defaults to False.
117+
mrenclave_data (string, optional): The mrenclave data to be added as a custom extension.
118+
119+
Returns:
120+
object: Certificate.
121+
"""
122+
# Create the certificate builder
123+
cert_builder = x509.CertificateBuilder()
124+
cert_builder = cert_builder.subject_name(
125+
x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, subject_name)])
126+
)
127+
cert_builder = cert_builder.issuer_name(
128+
x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, issuer_name)])
129+
)
130+
131+
oneday = timedelta(1, 0, 0)
132+
start_validity = datetime.today() - oneday
133+
end_validity = datetime.today() + (100 * oneday)
134+
cert_builder = cert_builder.not_valid_before(start_validity)
135+
cert_builder = cert_builder.not_valid_after(end_validity)
136+
137+
if ca:
138+
cert_builder = cert_builder.add_extension(
139+
x509.BasicConstraints(ca=True, path_length=None), critical=True
140+
)
141+
142+
cert_builder = cert_builder.serial_number(x509.random_serial_number())
143+
cert_builder = cert_builder.public_key(subject_pubkey)
144+
145+
# Add the custom MRENCLAVE_OID extension if mrenclave_data is provided
146+
if mrenclave_data:
147+
MRENCLAVE_OID = x509.ObjectIdentifier("1.3.6.1.4.1.99999.1.1") # Example OID
148+
cert_builder = cert_builder.add_extension(
149+
x509.UnrecognizedExtension(MRENCLAVE_OID, mrenclave_data.encode("utf-8")),
150+
critical=False,
151+
)
152+
153+
# Sign the certificate
154+
return cert_builder.sign(issuer_privkey, algorithm=hashes.SHA384())
155+
156+
def __create_cert_chain(self):
157+
root_cert_bytes = self._root_cert.public_bytes((serialization.Encoding.PEM))
158+
client_cert_bytes = self._client_cert.public_bytes((serialization.Encoding.PEM))
159+
160+
# Concatenate together and return
161+
cert_chain = client_cert_bytes.decode("utf-8") + root_cert_bytes.decode("utf-8")
162+
return cert_chain
163+
164+
def cert(self, common_name, mrenclave_data=None):
165+
"""Returns the self-signed certificate
166+
Args:
167+
common_name (string): to be used as subject and issuer name
168+
Returns:
169+
_type_: _description_
170+
"""
171+
172+
# Create the certificate chaining with local CA
173+
# A. Create a root CA certificate
174+
# B. Create the node certificate
175+
176+
# Return the cached value if it exists
177+
if self._client_cert:
178+
return self.__create_cert_chain()
179+
180+
ca_common_name = f"{common_name}-CA"
181+
self._root_cert = self.__get_cert(
182+
ca_common_name, self._root_pubkey, ca_common_name, self._root_privkey, ca=True
183+
)
184+
185+
self._client_cert = self.__get_cert(
186+
common_name,
187+
self._client_pubkey,
188+
ca_common_name,
189+
self._root_privkey,
190+
False,
191+
mrenclave_data=mrenclave_data,
192+
)
193+
194+
return self.__create_cert_chain()
195+
196+
def get_pubkey(self):
197+
"""returns public key as a PEM string"""
198+
199+
return self._client_pubkey.public_bytes(
200+
encoding=serialization.Encoding.PEM,
201+
format=serialization.PublicFormat.SubjectPublicKeyInfo,
202+
).decode("utf-8")
203+
204+
def sign(self, message):
205+
"""sign message string using private key.
206+
207+
Return Value: base64 encoded string
208+
"""
209+
210+
signature_bytes = self._client_privkey.sign(
211+
message.encode("utf-8"), ec.ECDSA(hashes.SHA384())
212+
)
213+
return base64.b64encode(signature_bytes).decode("utf-8")
214+
215+
def serialize_private_key(self, filename="/tmp/client_privkey.pem", save_root_cert=True):
216+
"""write the private key to a file in PEM format"""
217+
218+
client_privkey_pem = self._client_privkey.private_bytes(
219+
encoding=serialization.Encoding.PEM,
220+
format=serialization.PrivateFormat.PKCS8,
221+
encryption_algorithm=serialization.NoEncryption(),
222+
)
223+
224+
with open(filename, "wb") as fh:
225+
fh.write(client_privkey_pem)
226+
227+
# Save the CA certificate if not already exists
228+
if os.path.exists(self._root_cert_path) is False and save_root_cert:
229+
root_cert_bytes = self._root_cert.public_bytes((serialization.Encoding.PEM))
230+
with open(self._root_cert_path, "wb") as fh:
231+
fh.write(root_cert_bytes)
232+
233+
# Save the client certificate if not already exists
234+
if os.path.exists(self._client_cert_path) is False:
235+
cert_chain = self.__create_cert_chain()
236+
with open(self._client_cert_path, "wb") as fh:
237+
fh.write(cert_chain.encode("utf-8"))
238+
239+
return client_privkey_pem

openfl/federated/plan/plan.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
"""Plan module."""
66

7+
import os
78
from functools import partial
89
from hashlib import sha384
910
from importlib import import_module
@@ -625,6 +626,11 @@ def get_client_args(
625626
certificate = f"cert/client/col_{common_name}.crt"
626627
private_key = f"cert/client/col_{common_name}.key"
627628

629+
# added to test that new root certs generated by aggregator are used by clients
630+
#if os.getenv("ROOT_CERT_PATH", None) is not None:
631+
# cert_path = os.getenv("ROOT_CERT_PATH")
632+
# root_certificate = cert_path + "/cert_chain.crt"
633+
628634
client_args = self.config["network"][SETTINGS]
629635

630636
# patch certificates
@@ -643,6 +649,7 @@ def get_server(
643649
root_certificate=None,
644650
private_key=None,
645651
certificate=None,
652+
attested_identity=None,
646653
**kwargs,
647654
):
648655
"""Get gRPC or REST server of the aggregator instance.
@@ -659,7 +666,9 @@ def get_server(
659666
Returns:
660667
Aggregator Server: returns either gRPC or REST server of the aggregator instance.
661668
"""
662-
server_args = self.get_server_args(root_certificate, private_key, certificate, kwargs)
669+
server_args = self.get_server_args(
670+
root_certificate, private_key, certificate, attested_identity, kwargs
671+
)
663672

664673
server_args["aggregator"] = self.get_aggregator()
665674
network_cfg = self.config["network"][SETTINGS]
@@ -679,14 +688,31 @@ def _get_server(self, protocol, **kwargs):
679688
raise ValueError(f"Unsupported transport_protocol '{protocol}'")
680689
return server
681690

682-
def get_server_args(self, root_certificate, private_key, certificate, kwargs):
691+
def get_server_args(
692+
self, root_certificate, private_key, certificate, attested_identity, kwargs
693+
):
683694
common_name = self.config["network"][SETTINGS]["agg_addr"].lower()
684695

685-
if not root_certificate or not private_key or not certificate:
696+
if not root_certificate or not private_key or not certificate :
686697
root_certificate = "cert/cert_chain.crt"
687698
certificate = f"cert/server/agg_{common_name}.crt"
688699
private_key = f"cert/server/agg_{common_name}.key"
689700

701+
if attested_identity:
702+
cert_chain = os.path.join(
703+
os.path.dirname(attested_identity.get_root_cert_path()),
704+
"cert_chain.crt",
705+
)
706+
logger.info("building root cert: %s", cert_chain)
707+
attested_identity.build_root_cert("cert/client/")
708+
attested_identity.save_cert(f"cert/server/agg_{common_name}_self_signed.crt")
709+
attested_identity.save_private_key(f"/tmp/agg_{common_name}.key")
710+
root_certificate = cert_chain
711+
certificate = f"cert/server/agg_{common_name}_self_signed.crt"
712+
private_key = f"/tmp/agg_{common_name}.key"
713+
else:
714+
logger.info("Remote attestation is not enabled. Using default certificates.")
715+
690716
server_args = self.config["network"][SETTINGS]
691717

692718
# patch certificates

openfl/interface/aggregator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from openfl.federated import Plan
3333
from openfl.interface.cli_helper import CERT_DIR
3434
from openfl.utilities import click_types
35+
from openfl.utilities.attestation import attestation_utils as attestation_utils
3536
from openfl.utilities.path_check import is_directory_traversal
3637
from openfl.utilities.utils import getfqdn_env
3738

@@ -91,8 +92,16 @@ def start_(plan, authorized_cols, task_group):
9192
parsed_plan.config["assigner"]["settings"]["selected_task_group"] = task_group
9293
logger.info(f"Setting aggregator to assign: {task_group} task_group")
9394

95+
# check if remote attestation is enabled
96+
attested_identity = None
97+
if parsed_plan.config["aggregator"]["settings"].get("enable_remote_attestation", False):
98+
# check if the aggregator is running in a remote attestation environment
99+
attested_identity = attestation_utils.get_remote_attestation("aggregator")
100+
else:
101+
logger.info("Remote attestation is not enabled.")
102+
94103
logger.info("🧿 Starting the Aggregator Service.")
95-
server = parsed_plan.get_server()
104+
server = parsed_plan.get_server(attested_identity=attested_identity)
96105
server.serve()
97106

98107

openfl/interface/collaborator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from openfl.federated import Plan
2525
from openfl.federated.data.sources.data_sources_json_parser import DataSourcesJsonParser
2626
from openfl.interface.cli_helper import CERT_DIR
27+
from openfl.utilities.attestation import attestation_utils as attestation_utils
2728
from openfl.utilities.path_check import is_directory_traversal
2829
from openfl.utilities.utils import rmtree
2930

@@ -78,7 +79,8 @@ def start_(plan, collaborator_name, data_config):
7879

7980
# TODO: Need to restructure data loader config file loader
8081
logger.info(f"Data paths: {plan_obj.cols_data_paths}")
81-
echo(f"Data = {plan_obj.cols_data_paths}")
82+
# this check is added to avoid mock objects failing
83+
8284
logger.info("🧿 Starting a Collaborator Service.")
8385

8486
collaborator = plan_obj.get_collaborator(collaborator_name)

openfl/utilities/attestation/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)