Skip to content

Commit 6244346

Browse files
committed
improve diarization object remove jupyter
1 parent 8dabef0 commit 6244346

File tree

1 file changed

+54
-10
lines changed
  • modules/jax/testcontainers/whisper_cuda/whisper_diarization

1 file changed

+54
-10
lines changed

modules/jax/testcontainers/whisper_cuda/whisper_diarization/__init__.py

+54-10
Original file line numberDiff line numberDiff line change
@@ -35,32 +35,57 @@ def __init__(self, model_name: str = "openai/whisper-large-v2", hf_token: Option
3535
super().__init__("nvcr.io/nvidia/jax:23.08-py3", **kwargs)
3636
self.model_name = model_name
3737
self.hf_token = hf_token
38-
self.with_exposed_ports(8888) # Expose Jupyter notebook port
3938
self.with_env("NVIDIA_VISIBLE_DEVICES", "all")
4039
self.with_env("CUDA_VISIBLE_DEVICES", "all")
4140
self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support
41+
self.start_timeout = 600 # 10 minutes
42+
self.connection_retries = 5
43+
self.connection_retry_delay = 10 # seconds
4244

4345
# Install required dependencies
4446
self.with_command("sh -c '"
4547
"pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && "
4648
"pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets pyannote.audio && "
47-
"python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && "
48-
"jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''"
49+
"python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
4950
"'")
5051

5152
# NOTE(review): the URLError argument dates from the former HTTP probe against
# Jupyter; nothing in this body raises URLError any more, so the decorator's
# retry on that exception is effectively inert. Kept for backward compatibility.
@wait_container_is_ready(URLError)
def _connect(self):
    """Verify that the ML stack inside the container imports and runs.

    Executes a short Python probe through ``self.run_command`` that imports
    JAX, Whisper-JAX and Pyannote Audio, prints their versions, lists the
    available JAX devices and performs a trivial computation. The probe is
    retried up to ``self.connection_retries`` times (at least once), sleeping
    ``self.connection_retry_delay`` seconds between attempts.

    Returns:
        True once the probe output contains both expected markers.

    Raises:
        RuntimeError: if every attempt fails. The last underlying error is
            attached as ``__cause__``. ``RuntimeError`` is a subclass of
            ``Exception``, so existing ``except Exception`` callers still work.
    """
    probe = (
        "import jax; import whisper_jax; import pyannote.audio; "
        "print(f'JAX version: {jax.__version__}'); "
        "print(f'Whisper-JAX version: {whisper_jax.__version__}'); "
        "print(f'Pyannote Audio version: {pyannote.audio.__version__}'); "
        "print(f'Available devices: {jax.devices()}'); "
        "print(jax.numpy.add(1, 1))"
    )

    # Always try at least once: a retry count of zero (or less) previously
    # skipped the loop entirely and fell through to a silent ``return False``.
    attempts = max(1, self.connection_retries)
    last_error = None

    for attempt in range(1, attempts + 1):
        try:
            result = self.run_command(probe)
            # Decode once and reuse for both the marker checks and the log line.
            output = result.output.decode()
        except Exception as exc:  # boundary: any exec/decode failure is retryable
            last_error = exc
        else:
            if "JAX version" in output and "Available devices" in output:
                logging.info(
                    "JAX-Whisper-Diarization environment verified:\n%s", output
                )
                return True
            last_error = RuntimeError(
                "JAX-Whisper-Diarization environment check failed"
            )

        if attempt < attempts:
            logging.warning(
                "Connection attempt %s failed. Retrying in %s seconds...",
                attempt,
                self.connection_retry_delay,
            )
            time.sleep(self.connection_retry_delay)

    raise RuntimeError(
        "Failed to connect to JAX-Whisper-Diarization container "
        f"after {attempts} attempts: {last_error}"
    ) from last_error
5780

5881
def connect(self):
    """
    Connect to the JAX-Whisper-Diarization container and ensure it's ready.

    Delegates to ``self._connect()``, which verifies that JAX, Whisper-JAX and
    Pyannote Audio are importable inside the container and reports the
    available devices (including GPUs if applicable).

    Raises:
        RuntimeError: if ``_connect()`` returns a falsy result instead of
            raising, so a failed verification is never logged as a success.
    """
    # ``_connect`` can signal failure through its return value as well as by
    # raising; honour both instead of discarding the result.
    if not self._connect():
        raise RuntimeError(
            "JAX-Whisper-Diarization environment verification did not succeed"
        )
    logging.info("Successfully connected to JAX-Whisper-Diarization container and verified the environment")
6489

6590
def run_command(self, command: str):
6691
"""
@@ -242,8 +267,27 @@ def align(transcription, segments, group_by_speaker=True):
242267

243268
def start(self):
    """
    Launch the JAX-Whisper-Diarization container and block until it is ready.

    Calls the parent ``start()``, then waits via
    ``_wait_for_container_to_be_ready()`` before logging that the container
    is up.

    Returns:
        This container instance, so calls can be chained.
    """
    super().start()

    # Do not hand the container back until the readiness wait has finished.
    self._wait_for_container_to_be_ready()

    logging.info("JAX-Whisper-Diarization container started and ready.")
    return self
276+
277+
def _wait_for_container_to_be_ready(self):
    """Block until the container logs its readiness marker.

    Waits for the text ``"Installation completed"`` to appear in the
    container logs, giving up after ``self.start_timeout`` seconds.

    NOTE(review): the startup command visible alongside this change only
    chains ``pip install`` steps and does not appear to print this marker.
    Confirm that the command echoes it (and keeps the container alive
    afterwards); otherwise this wait can only end by timing out.
    """
    # ``wait_for_logs`` is a module-level helper in testcontainers that takes
    # the container as its first argument. Imported locally so this method
    # does not depend on the file's top-level import list.
    from testcontainers.core.waiting_utils import wait_for_logs

    # Honour the configured start timeout instead of the library default.
    wait_for_logs(self, "Installation completed", timeout=self.start_timeout)
280+
281+
def stop(self, force=True):
    """
    Shut down the JAX-Whisper-Diarization container.

    Args:
        force: Forwarded unchanged to the parent ``stop()``; defaults to True.
    """
    super().stop(force)
    # Logged only after the parent stop call has returned.
    logging.info("JAX-Whisper-Diarization container stopped.")
287+
288+
@property
def timeout(self):
    """
    Read-only view of the container start timeout.

    Returns:
        The current value of ``self.start_timeout``.
    """
    start_timeout = self.start_timeout
    return start_timeout

0 commit comments

Comments
 (0)