@@ -35,32 +35,57 @@ def __init__(self, model_name: str = "openai/whisper-large-v2", hf_token: Option
35
35
super ().__init__ ("nvcr.io/nvidia/jax:23.08-py3" , ** kwargs )
36
36
self .model_name = model_name
37
37
self .hf_token = hf_token
38
- self .with_exposed_ports (8888 ) # Expose Jupyter notebook port
39
38
self .with_env ("NVIDIA_VISIBLE_DEVICES" , "all" )
40
39
self .with_env ("CUDA_VISIBLE_DEVICES" , "all" )
41
40
self .with_kwargs (runtime = "nvidia" ) # Use NVIDIA runtime for GPU support
41
+ self .start_timeout = 600 # 10 minutes
42
+ self .connection_retries = 5
43
+ self .connection_retry_delay = 10 # seconds
42
44
43
45
# Install required dependencies
44
46
self .with_command ("sh -c '"
45
47
"pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && "
46
48
"pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets pyannote.audio && "
47
- "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && "
48
- "jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''"
49
+ "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
49
50
"'" )
50
51
51
52
@wait_container_is_ready(URLError)
def _connect(self):
    """
    Verify that the JAX-Whisper-Diarization environment in the container works.

    A short Python snippet is passed to ``self.run_command``. It imports JAX,
    Whisper-JAX and Pyannote Audio, prints their versions and the JAX device
    list, and performs a trivial ``jax.numpy`` addition. The check is tried up
    to ``self.connection_retries`` times, sleeping
    ``self.connection_retry_delay`` seconds between failed attempts.

    Returns:
        True as soon as one attempt produces output containing both
        "JAX version" and "Available devices". False only when
        ``self.connection_retries`` is zero or negative, i.e. no attempt ran.

    Raises:
        Exception: when the final attempt fails. The underlying error is
            attached as ``__cause__`` so the original traceback is kept.
    """
    # NOTE(review): the snippet is Python source, not a shell command; this
    # presumably relies on run_command wrapping it in `python -c` -- confirm.
    probe = (
        "import jax; import whisper_jax; import pyannote.audio; "
        "print(f'JAX version: {jax.__version__}'); "
        "print(f'Whisper-JAX version: {whisper_jax.__version__}'); "
        "print(f'Pyannote Audio version: {pyannote.audio.__version__}'); "
        "print(f'Available devices: {jax.devices()}'); "
        "print(jax.numpy.add(1, 1))"
    )
    final_attempt = self.connection_retries - 1

    for attempt in range(self.connection_retries):
        try:
            result = self.run_command(probe)
            # Decode once; the text is used for both the check and the log.
            output = result.output.decode()
            if "JAX version" not in output or "Available devices" not in output:
                raise Exception("JAX-Whisper-Diarization environment check failed")
        except Exception as e:
            if attempt >= final_attempt:
                # Out of retries: surface the failure and keep its cause.
                raise Exception(
                    "Failed to connect to JAX-Whisper-Diarization container "
                    f"after {self.connection_retries} attempts: {e}"
                ) from e
            logging.warning(
                "Connection attempt %d failed. Retrying in %s seconds...",
                attempt + 1,
                self.connection_retry_delay,
            )
            time.sleep(self.connection_retry_delay)
        else:
            logging.info("JAX-Whisper-Diarization environment verified:\n%s", output)
            return True

    # Only reachable when the retry budget allowed no attempt at all.
    return False
57
80
58
81
def connect(self):
    """
    Connect to the JAX-Whisper-Diarization container and ensure it's ready.

    Delegates to ``self._connect()`` and logs a success message once that
    call has completed without reporting failure.

    Raises:
        RuntimeError: if ``self._connect()`` explicitly returns ``False``,
            meaning the environment was never verified. Any exception raised
            by ``self._connect()`` itself propagates unchanged.
    """
    # Compare against False explicitly so an implementation of _connect that
    # returns None on success is still treated as a successful connection.
    if self._connect() is False:
        raise RuntimeError(
            "JAX-Whisper-Diarization container environment could not be verified"
        )
    logging.info(
        "Successfully connected to JAX-Whisper-Diarization container and verified the environment"
    )
64
89
65
90
def run_command (self , command : str ):
66
91
"""
@@ -242,8 +267,27 @@ def align(transcription, segments, group_by_speaker=True):
242
267
243
268
def start(self):
    """
    Launch the JAX-Whisper-Diarization container and block until it is ready.

    The parent class starts the container, then the readiness helper
    ``_wait_for_container_to_be_ready`` is invoked before success is logged.

    Returns:
        This container instance, so the call can be chained.
    """
    super().start()

    # Do not report success until the readiness wait has returned.
    self._wait_for_container_to_be_ready()

    ready_message = "JAX-Whisper-Diarization container started and ready."
    logging.info(ready_message)
    return self
276
+
277
def _wait_for_container_to_be_ready(self):
    """
    Block until the container log contains the readiness marker.

    Waits for the text "Installation completed" to appear in the container
    logs, giving up after ``self.start_timeout`` seconds.

    Raises:
        TimeoutError: presumably raised by ``wait_for_logs`` when the marker
            does not appear in time -- confirm against the installed
            testcontainers version.
    """
    # testcontainers provides log waiting as a module-level helper that takes
    # the container as its first argument. Imported locally so this method
    # does not depend on the file's top-level import block.
    from testcontainers.core.waiting_utils import wait_for_logs

    # NOTE(review): the container command assembled in __init__ only runs
    # `pip install` steps and does not echo "Installation completed"; confirm
    # that something actually emits this marker, otherwise this wait can only
    # end by timing out.
    wait_for_logs(self, "Installation completed", timeout=self.start_timeout)
280
+
281
def stop(self, force=True, delete_volume=True):
    """
    Stop the JAX-Whisper-Diarization container.

    Args:
        force: forwarded to the parent ``stop``.
        delete_volume: forwarded to the parent ``stop``. Accepted here so the
            override does not narrow the parent's signature.
            NOTE(review): assumes the base class is testcontainers'
            ``DockerContainer``, whose ``stop`` takes this argument -- confirm.
    """
    super().stop(force=force, delete_volume=delete_volume)
    logging.info("JAX-Whisper-Diarization container stopped.")
287
+
288
@property
def timeout(self) -> int:
    """
    Read-only view of the container start timeout.

    Returns:
        The current value of ``self.start_timeout``.
    """
    start_timeout = self.start_timeout
    return start_timeout
0 commit comments