Skip to content

Commit 4bd80a4

Browse files
committed
feat: improved user id grabbing from root nuc
1 parent 022c413 commit 4bd80a4

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

nilai-api/src/nilai_api/credit.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from nilai_api.config import CONFIG
1414

15+
from nuc.envelope import NucTokenEnvelope
16+
1517
logger = logging.getLogger(__name__)
1618

1719

@@ -107,6 +109,24 @@ async def wrapper(request: Request) -> str:
107109
return wrapper
108110

109111

112+
def from_nuc_bearer_root_token() -> Callable[[Request], Awaitable[str]]:
113+
"""Extract user ID from a NUC root token"""
114+
115+
async def extractor(request: Request) -> str:
116+
auth_header: str | None = request.headers.get("Authorization", None)
117+
if not auth_header or not auth_header.startswith("Bearer "):
118+
raise ValueError("No Bearer token found")
119+
120+
# Remove the Bearer prefix
121+
token_str: str = auth_header.replace("Bearer ", "")
122+
# Parse the token
123+
token = NucTokenEnvelope.parse(token_str)
124+
# Returns the issuer of the root token from an invocation token
125+
return str(token.proofs[-1].token.issuer)
126+
127+
return extractor
128+
129+
110130
def llm_cost_calculator(llm_cost_dict: LLMCostDict):
111131
async def calculator(request: Request, response_data: dict) -> float:
112132
model_name = getattr(request, "model", "default")
@@ -125,7 +145,7 @@ async def calculator(request: Request, response_data: dict) -> float:
125145

126146

127147
LLMMeter = create_metering_dependency(
128-
user_id_extractor=user_id_extractor(),
148+
user_id_extractor=from_nuc_bearer_root_token(),
129149
estimated_cost=2.0,
130150
cost_calculator=llm_cost_calculator(MyCostDictionary),
131151
)

0 commit comments

Comments
 (0)