Skip to content

Commit f5fe0a6

Browse files
committed
fix: corrections on typing and attestation parameters
1 parent 0ae9354 commit f5fe0a6

File tree

10 files changed

+46
-39
lines changed

10 files changed

+46
-39
lines changed

nilai-api/gunicorn.conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
bind = ["0.0.0.0:8080"]
66

77
# Set the number of workers (2)
8-
workers = SETTINGS["gunicorn_workers"]
8+
workers = SETTINGS.gunicorn_workers
99

1010
# Set the number of threads per worker (16)
1111
threads = 1

nilai-api/src/nilai_api/attestation/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ async def get_attestation_report(
1212
"""Get the attestation report for the given nonce"""
1313

1414
try:
15-
attestation_url = f"http://{SETTINGS['attestation_host']}:{SETTINGS['attestation_port']}/attestation/report"
15+
attestation_url = f"http://{SETTINGS.attestation_host}:{SETTINGS.attestation_port}/attestation/report"
1616
async with httpx.AsyncClient() as client:
1717
response: httpx.Response = await client.get(attestation_url, params=nonce)
1818
report = AttestationReport(**response.json())
@@ -24,10 +24,10 @@ async def get_attestation_report(
2424
async def verify_attestation_report(attestation_report: AttestationReport) -> bool:
2525
"""Verify the attestation report"""
2626
try:
27-
attestation_url = f"http://{SETTINGS['attestation_host']}:{SETTINGS['attestation_port']}/attestation/verify"
27+
attestation_url = f"http://{SETTINGS.attestation_host}:{SETTINGS.attestation_port}/attestation/verify"
2828
async with httpx.AsyncClient() as client:
29-
response: httpx.Response = await client.post(
30-
attestation_url, json=attestation_report.model_dump()
29+
response: httpx.Response = await client.get(
30+
attestation_url, params=attestation_report.model_dump()
3131
)
3232
return response.json()
3333
except Exception as e:

nilai-attestation/gunicorn.conf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from nilai_common.config import SETTINGS
33

44
# Bind to address and port
5-
bind = [f"0.0.0.0:{SETTINGS['attestation_port']}"]
5+
bind = [f"0.0.0.0:{SETTINGS.attestation_port}"]
66

77
# Set the number of workers (2)
8-
workers = SETTINGS["gunicorn_workers"]
8+
workers = SETTINGS.gunicorn_workers
99

1010
# Set the number of threads per worker (16)
1111
threads = 1

nilai-attestation/src/nilai_attestation/attestation/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def get_attestation_report(nonce: Nonce | None = None) -> AttestationReport:
2424
logger.info(f"Nonce: {attestation_nonce}")
2525

2626
load_sev_library()
27+
2728
return AttestationReport(
2829
nonce=attestation_nonce,
2930
verifying_key="",

nilai-attestation/src/nilai_attestation/attestation/nvtrust/nv_attester.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def is_nvidia_gpu_available() -> bool:
4646
return False
4747

4848

49-
def nv_attest(nonce: Nonce) -> NVAttestationToken:
49+
def nv_attest(nonce: Nonce, name: str = "thisNode1") -> NVAttestationToken:
5050
"""Generate an attestation token from local evidence.
5151
5252
Args:
@@ -57,7 +57,7 @@ def nv_attest(nonce: Nonce) -> NVAttestationToken:
5757
"""
5858
# Create and configure the attestation client.
5959
client = attestation.Attestation()
60-
client.set_name("thisNode1")
60+
client.set_name(name)
6161
client.set_nonce(nonce)
6262

6363
logger.info("Checking if NVIDIA GPU is available")

nilai-attestation/src/nilai_attestation/attestation/nvtrust/nv_verifier.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@
4343
}
4444

4545

46-
def verify_attestation(attestation_report: AttestationReport) -> bool:
46+
def verify_attestation(
47+
attestation_report: AttestationReport, name: str = "thisNode1"
48+
) -> bool:
4749
"""Verify an NVIDIA attestation token against a policy.
4850
4951
Args:
@@ -57,7 +59,7 @@ def verify_attestation(attestation_report: AttestationReport) -> bool:
5759
# Create an attestation client instance for token verification.
5860
logger.info(f"Attestation report: {attestation_report}")
5961
client = attestation.Attestation()
60-
client.set_name("thisNode1")
62+
client.set_name(name)
6163
client.set_nonce(attestation_report.nonce)
6264
client.add_verifier(
6365
attestation.Devices.GPU, attestation.Environment.REMOTE, NRAS_URL, ""

nilai-attestation/src/nilai_attestation/routers/private.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Fast API and serving
22
import logging
3-
from fastapi import APIRouter
3+
from fastapi import APIRouter, Depends
44

55
# Internal libraries
66
from nilai_attestation.attestation import (
@@ -35,12 +35,14 @@ async def get_attestation(nonce: Nonce | None = None) -> AttestationReport:
3535
return get_attestation_report(nonce)
3636

3737

38-
@router.post("/attestation/verify", tags=["Attestation"])
39-
async def post_attestation(attestation_report: AttestationReport) -> bool:
38+
@router.get("/attestation/verify", tags=["Attestation"])
39+
async def get_attestation_verification(
40+
attestation_report: AttestationReport = Depends(),
41+
) -> bool:
4042
"""
41-
Verify a cryptographic attestation report.
43+
Verify a cryptographic attestation report passed as query parameters.
4244
43-
- **attestation_report**: Attestation report to verify
45+
- **attestation_report**: Attestation report to verify (fields passed as query parameters)
4446
- **Returns**: True if the attestation report is valid, False otherwise
4547
"""
4648
return verify_attestation_report(attestation_report)

nilai-models/src/nilai_models/daemon.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ async def get_metadata(num_retries=30):
2121
while True:
2222
url = None
2323
try:
24-
url = f"http://{SETTINGS['host']}:{SETTINGS['port']}/v1/models"
24+
url = f"http://{SETTINGS.host}:{SETTINGS.port}/v1/models"
2525
# Request model metadata from localhost:8000/v1/models
2626
async with httpx.AsyncClient() as client:
2727
response = await client.get(url)
@@ -37,7 +37,7 @@ async def get_metadata(num_retries=30):
3737
license="Apache 2.0", # Usage license
3838
source=f"https://huggingface.co/{model_name}", # Model source
3939
supported_features=["chat_completion"], # Capabilities
40-
tool_support=SETTINGS["tool_support"], # Tool support
40+
tool_support=SETTINGS.tool_support, # Tool support
4141
)
4242

4343
except Exception as e:
@@ -81,12 +81,12 @@ async def main():
8181
try:
8282
# Initialize discovery service
8383
discovery_service = ModelServiceDiscovery(
84-
host=SETTINGS["etcd_host"], port=SETTINGS["etcd_port"]
84+
host=SETTINGS.etcd_host, port=SETTINGS.etcd_port
8585
)
8686

8787
metadata = await get_metadata()
8888
model_endpoint = ModelEndpoint(
89-
url=f"http://{SETTINGS['host']}:{SETTINGS['port']}", metadata=metadata
89+
url=f"http://{SETTINGS.host}:{SETTINGS.port}", metadata=metadata
9090
)
9191

9292
# Setup signal handlers
Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
11
import os
2+
from pydantic import BaseModel
23

3-
# from dotenv import load_dotenv
44

5-
# load_dotenv() # Only needed locally if using a .env file
5+
class HostSettings(BaseModel):
6+
host: str = "localhost"
7+
port: int = 8000
8+
etcd_host: str = "localhost"
9+
etcd_port: int = 2379
10+
tool_support: bool = False
11+
gunicorn_workers: int = 10
12+
attestation_host: str = "localhost"
13+
attestation_port: int = 8081
614

7-
SETTINGS = {
8-
"host": os.getenv("SVC_HOST", "localhost"),
9-
"port": os.getenv("SVC_PORT", 8000),
10-
"etcd_host": os.getenv("ETCD_HOST", "localhost"),
11-
"etcd_port": os.getenv("ETCD_PORT", 2379),
12-
"tool_support": os.getenv("TOOL_SUPPORT", False),
13-
"gunicorn_workers": os.getenv("NILAI_GUNICORN_WORKERS", 10),
14-
"attestation_host": os.getenv("ATTESTATION_HOST", "localhost"),
15-
"attestation_port": os.getenv("ATTESTATION_PORT", 8080),
16-
}
17-
# if environment == "docker":
18-
# config = "docker_settings.py"
19-
# else:
20-
# config = "local_settings.py"
2115

22-
# # Import the appropriate config dynamically
23-
# from importlib import import_module
24-
# settings = import_module(config)
16+
SETTINGS: HostSettings = HostSettings(
17+
host=str(os.getenv("SVC_HOST", "localhost")),
18+
port=int(os.getenv("SVC_PORT", 8000)),
19+
etcd_host=str(os.getenv("ETCD_HOST", "localhost")),
20+
etcd_port=int(os.getenv("ETCD_PORT", 2379)),
21+
tool_support=bool(os.getenv("TOOL_SUPPORT", False)),
22+
gunicorn_workers=int(os.getenv("NILAI_GUNICORN_WORKERS", 10)),
23+
attestation_host=str(os.getenv("ATTESTATION_HOST", "localhost")),
24+
attestation_port=int(os.getenv("ATTESTATION_PORT", 8081)),
25+
)

tests/e2e/test_http.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def client():
2525
"Content-Type": "application/json",
2626
"Authorization": f"Bearer {AUTH_TOKEN}",
2727
},
28+
timeout=None,
2829
)
2930

3031

0 commit comments

Comments
 (0)