11# nilai/models/model.py
22import asyncio
3- import signal
43import logging
4+ import signal
5+
56import httpx
67
7- from nilai_common import ( # Model service discovery and host settings
8- SETTINGS ,
8+ from nilai_common import (
99 MODEL_SETTINGS ,
10- ModelServiceDiscovery ,
10+ SETTINGS ,
1111 ModelEndpoint ,
1212 ModelMetadata ,
13+ ModelServiceDiscovery ,
1314)
1415
1516logger = logging .getLogger (__name__ )
1617
1718
1819async def get_metadata ():
19- """Fetch model metadata from model
20- service and return as ModelMetadata object"""
20+ """Fetch model metadata from model service and return as ModelMetadata object."""
2121 current_retries = 0
2222 while True :
2323 url = None
2424 try :
2525 url = f"http://{ SETTINGS .host } :{ SETTINGS .port } /v1/models"
26- # Request model metadata from localhost:8000/v1/models
2726 async with httpx .AsyncClient () as client :
2827 response = await client .get (url )
2928 response .raise_for_status ()
3029 response_data = response .json ()
3130 model_name = response_data ["data" ][0 ]["id" ]
3231 return ModelMetadata (
33- id = model_name , # Unique identifier
34- name = model_name , # Human-readable name
35- version = "1.0" , # Model version
32+ id = model_name ,
33+ name = model_name ,
34+ version = "1.0" ,
3635 description = "" ,
37- author = "" , # Model creators
38- license = "Apache 2.0" , # Usage license
39- source = f"https://huggingface.co/{ model_name } " , # Model source
40- supported_features = ["chat_completion" ], # Capabilities
41- tool_support = SETTINGS .tool_support , # Tool support
42- multimodal_support = SETTINGS .multimodal_support , # Multimodal support
36+ author = "" ,
37+ license = "Apache 2.0" ,
38+ source = f"https://huggingface.co/{ model_name } " ,
39+ supported_features = ["chat_completion" ],
40+ tool_support = SETTINGS .tool_support ,
41+ multimodal_support = SETTINGS .multimodal_support ,
4342 )
4443
4544 except Exception as e :
@@ -49,16 +48,16 @@ async def get_metadata():
4948 logger .warning (f"Failed to fetch model metadata from { url } : { e } " )
5049 current_retries += 1
5150 if (
52- MODEL_SETTINGS .num_retries
53- != - 1 # If num_retries == -1 then we do infinite number of retries
51+ MODEL_SETTINGS .num_retries != - 1
5452 and current_retries >= MODEL_SETTINGS .num_retries
5553 ):
5654 raise e
5755 await asyncio .sleep (MODEL_SETTINGS .timeout )
5856
5957
6058async def run_service (discovery_service , model_endpoint ):
61- """Runs the model service and keeps it alive"""
59+ """Register model with discovery service and keep it alive."""
60+ lease = None
6261 try :
6362 logger .info (f"Registering model: { model_endpoint .metadata .id } " )
6463 lease = await discovery_service .register_model (model_endpoint , prefix = "/models" )
@@ -73,50 +72,62 @@ async def run_service(discovery_service, model_endpoint):
7372 logger .error (f"Service error: { e } " )
7473 raise
7574 finally :
76- try :
77- await discovery_service .unregister_model (model_endpoint .metadata .id )
78- logger .info (f"Model unregistered: { model_endpoint .metadata .id } " )
79- except Exception as e :
80- logger .error (f"Error unregistering model: { e } " )
75+ if lease :
76+ try :
77+ await discovery_service .unregister_model (model_endpoint .metadata .id )
78+ logger .info (f"Model unregistered: { model_endpoint .metadata .id } " )
79+ except Exception as e :
80+ logger .error (f"Error unregistering model: { e } " )
8181
8282
async def main():
    """Run the model daemon until the service ends or a stop signal arrives.

    Steps:
      1. Configure root logging at INFO level.
      2. Build the discovery client from ``SETTINGS.etcd_host`` / ``etcd_port``.
      3. Fetch the model metadata and wrap it in a ``ModelEndpoint``.
      4. Start ``run_service`` as a background task.
      5. Wait for whichever comes first: SIGTERM/SIGINT, or the service task
         finishing on its own.

    On a stop signal the service task is cancelled and awaited so that its own
    cleanup can run. If the service task finishes by itself, any exception it
    raised is re-raised to the caller.

    Raises:
        Exception: anything raised by ``get_metadata`` or by the service task.
    """
    logging.basicConfig(level=logging.INFO)

    # Initialize discovery service
    discovery_service = ModelServiceDiscovery(
        host=SETTINGS.etcd_host, port=SETTINGS.etcd_port
    )

    # Fetch metadata and create endpoint
    metadata = await get_metadata()
    model_endpoint = ModelEndpoint(
        url=f"http://{SETTINGS.host}:{SETTINGS.port}", metadata=metadata
    )

    # Create service task
    service_task = asyncio.create_task(run_service(discovery_service, model_endpoint))

    # Setup signal handling: the handlers only flip an event, the actual
    # shutdown work happens below in this coroutine.
    stop_event = asyncio.Event()
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        try:
            loop.add_signal_handler(sig, stop_event.set)
        except NotImplementedError:
            # Windows doesn't support add_signal_handler
            pass

    # Wait for either shutdown signal or service completion
    wait_task = asyncio.create_task(stop_event.wait())
    try:
        done, _ = await asyncio.wait(
            {wait_task, service_task}, return_when=asyncio.FIRST_COMPLETED
        )

        if wait_task in done:
            logger.info("Stop signal received; shutting down daemon")
            service_task.cancel()
            try:
                await service_task
            except asyncio.CancelledError:
                # Expected: we requested the cancellation ourselves.
                pass
        else:
            # Service completed (possibly with error); awaiting re-raises it.
            await service_task
    finally:
        # Make sure no helper task outlives main(), whichever path we left
        # through (normal return, service error, or main() being cancelled).
        for task in (wait_task, service_task):
            if not task.done():
                task.cancel()
118130
119131
if __name__ == "__main__":
    # Script entry point: start the event loop and run the daemon to completion.
    asyncio.run(main())
0 commit comments