
Commit 621401f ("Initial commit")
committed Aug 29, 2023 · 0 parents
6 files changed: +197 -0 lines

‎Dockerfile

+27 lines

```dockerfile
# Must use a CUDA 11+ base image
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime

ARG HF_TOKEN

WORKDIR /

# Install git and wget
RUN apt-get update && apt-get install -y git wget

# Upgrade pip
RUN pip install --upgrade pip

ADD requirements.txt requirements.txt
RUN pip3 install -r requirements.txt

ENV HF_TOKEN=${HF_TOKEN}

# Add your model weight files
ADD download.py .
RUN python3 download.py

ADD . .

EXPOSE 8000

CMD python3 -u app.py
```
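Since the image exposes port 8000 and starts `app.py` directly, you can smoke-test the container locally before deploying: build it with the `HF_TOKEN` build arg, run it with port 8000 published, and send a request like the sketch below. The `/` route and the request shape are assumptions based on the default Potassium handler route and on `app.py`; adjust if your setup differs.

```python
# Local smoke test against a running container (assumes the Potassium
# handler is served at the default "/" route on port 8000).
import requests

payload = {
    "audio": "https://example.com/sample.wav",  # placeholder .wav URL
    "option": "voice_activity_detection",
}

resp = requests.post("http://localhost:8000/", json=payload)
print(resp.status_code)
print(resp.json().keys())  # expect a single "output" key with a base64 PNG
```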

‎README.md

+39 lines

![](https://www.banana.dev/lib_zOkYpJoyYVcAamDf/x2p804nk9qvjb1vg.svg?w=340 "Banana.dev")

# Banana.dev segmentation starter template

This is a segmentation starter template from [Banana.dev](https://www.banana.dev) that allows on-demand serverless GPU inference.

You can fork this repository and deploy it on Banana as is, or customize it to your own needs.

# Running this app

## Deploying on Banana.dev

1. [Fork this](https://github.com/bananaml/demo-segmentation/fork) repository to your own GitHub account.
2. Connect your GitHub account on Banana.
3. [Create a new model](https://app.banana.dev/deploy) on Banana from the forked GitHub repository.

## Running after deploying

1. Wait for the model to build after creating it.
2. Make an API request using one of the snippets provided in your Banana dashboard, adjusting the inputs to fit the segmentation model:

```python
inputs = {
    "audio": "bucket_link_to_wav_file",
    "option": "voice_activity_detection"
}
```

Set the `audio` parameter to a link to the .wav file you want to segment, hosted on S3 (or any other provider where you can store .wav files). For the `option` parameter, choose one of the following, depending on which segmentation information you want to extract from the audio file:

* voice_activity_detection
* overlapped_speech_detection
* instantaneous_speaker_counting
* speaker_change_detection

In the example above, we chose `voice_activity_detection`.

For more info, check out the [Banana.dev docs](https://docs.banana.dev/banana-docs/).
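To make step 2 concrete, here is a minimal client sketch using plain `requests` and the standard `base64` module. The endpoint URL and auth header are placeholders, not the actual Banana API format; copy the exact call format from your Banana dashboard snippet. The `{"output": ...}` response layout matches what `app.py` returns.

```python
# Minimal client sketch. ENDPOINT and the Authorization header are
# hypothetical placeholders; use the real call format from your Banana
# dashboard snippet.
import base64
import requests

ENDPOINT = "https://YOUR-MODEL-ENDPOINT.example.com"  # placeholder
API_KEY = "YOUR_API_KEY"                              # placeholder

inputs = {
    "audio": "bucket_link_to_wav_file",    # link to the .wav file to segment
    "option": "voice_activity_detection",  # one of the four options above
}

resp = requests.post(ENDPOINT, json=inputs,
                     headers={"Authorization": f"Bearer {API_KEY}"})
resp.raise_for_status()

# app.py responds with {"output": "<base64-encoded PNG>"}; decode it to a file.
with open("segmentation.png", "wb") as f:
    f.write(base64.b64decode(resp.json()["output"]))
```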

‎app.py

+89 lines

```python
from potassium import Potassium, Request, Response
from pyannote.audio import Model, Inference
import numpy as np
import os
import base64
import matplotlib.pyplot as plt
import io
import requests
from pyannote.core import notebook

MODEL = "pyannote/segmentation"
HF_TOKEN = os.getenv('HF_TOKEN')
SPEAKER_AXIS = 2
TIME_AXIS = 1

app = Potassium("segmentation")

@app.init
def init():
    """
    Initialize the application with the pretrained model.
    """
    model = Model.from_pretrained(MODEL, use_auth_token=HF_TOKEN)
    context = {
        "model": model
    }
    return context

def process_audio(model, hook, label, filename):
    """
    Run inference on the audio file with the given pre-aggregation hook,
    plot the resulting probabilities and return the plot as a base64 PNG.
    """
    inference = Inference(model, pre_aggregation_hook=hook)
    prob = inference(filename)
    prob.labels = [label]
    fig, ax = plt.subplots()
    notebook.plot_feature(prob, ax=ax)
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close(fig)
    image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return image_base64

def voice_activity_detection(model):
    # Speech probability = most active speaker at each frame
    to_vad = lambda o: np.max(o, axis=SPEAKER_AXIS, keepdims=True)
    return process_audio(model, to_vad, 'SPEECH', "/tmp/temp.wav")

def overlapped_speech_detection(model):
    # Overlap probability = second most active speaker at each frame
    to_osd = lambda o: np.partition(o, -2, axis=SPEAKER_AXIS)[:, :, -2, np.newaxis]
    return process_audio(model, to_osd, 'OVERLAP', "/tmp/temp.wav")

def instantaneous_speaker_counting(model):
    # Speaker count = sum of per-speaker probabilities at each frame
    to_cnt = lambda probability: np.sum(probability, axis=SPEAKER_AXIS, keepdims=True)
    return process_audio(model, to_cnt, 'SPEAKER_COUNT', "/tmp/temp.wav")

def speaker_change_detection(model):
    # Change probability = largest frame-to-frame jump across speakers
    to_scd = lambda probability: np.max(
        np.abs(np.diff(probability, n=1, axis=TIME_AXIS)),
        axis=SPEAKER_AXIS, keepdims=True)
    return process_audio(model, to_scd, 'SPEAKER_CHANGE', "/tmp/temp.wav")

@app.handler()
def handler(context: dict, request: Request) -> Response:
    """
    Handle the incoming request and return the response.
    """
    model = context.get("model")
    audio_input = request.json.get("audio")
    option = request.json.get("option")

    # Download the audio file to a temporary location
    response = requests.get(audio_input)
    filename = '/tmp/temp.wav'
    with open(filename, 'wb') as f:
        f.write(response.content)

    if option == "voice_activity_detection":
        image_base64 = voice_activity_detection(model)
    elif option == "overlapped_speech_detection":
        image_base64 = overlapped_speech_detection(model)
    elif option == "instantaneous_speaker_counting":
        image_base64 = instantaneous_speaker_counting(model)
    elif option == "speaker_change_detection":
        image_base64 = speaker_change_detection(model)
    else:
        return Response(json={"error": f"unknown option: {option}"}, status=400)

    return Response(json={"output": image_base64}, status=200)

if __name__ == "__main__":
    app.serve()
```
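For reference, the four `pre_aggregation_hook` lambdas above are plain NumPy reductions over the raw segmentation scores. The sketch below applies them to a dummy array just to show the axis semantics; the `(chunks, frames, speakers)` layout is an assumption inferred from the `SPEAKER_AXIS`/`TIME_AXIS` constants, not a guaranteed description of the model output.

```python
# Dummy demonstration of the four hooks; assumes a (chunks, frames, speakers)
# score layout, which is only inferred from SPEAKER_AXIS = 2 and TIME_AXIS = 1.
import numpy as np

SPEAKER_AXIS, TIME_AXIS = 2, 1
scores = np.random.rand(4, 293, 3)  # 4 chunks, 293 frames, 3 speakers

vad = np.max(scores, axis=SPEAKER_AXIS, keepdims=True)                    # is anyone speaking?
osd = np.partition(scores, -2, axis=SPEAKER_AXIS)[:, :, -2, np.newaxis]   # 2nd most active speaker
cnt = np.sum(scores, axis=SPEAKER_AXIS, keepdims=True)                    # expected speaker count
scd = np.max(np.abs(np.diff(scores, n=1, axis=TIME_AXIS)),
             axis=SPEAKER_AXIS, keepdims=True)                            # frame-to-frame change

print(vad.shape, osd.shape, cnt.shape, scd.shape)
# -> (4, 293, 1) (4, 293, 1) (4, 293, 1) (4, 292, 1)
```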

‎banana_config.json

+27 lines

```json
{
    "version": "1",
    "inputs": [
        {
            "name": "option",
            "description": "Choose a segmentation to perform on the audio file",
            "type": "options",
            "options": [
                "voice_activity_detection",
                "overlapped_speech_detection",
                "instantaneous_speaker_counting",
                "speaker_change_detection"
            ],
            "required": true,
            "default": "voice_activity_detection"
        },
        {
            "name": "audio",
            "type": "file",
            "mimes": ["audio/wav"],
            "required": true
        }
    ],
    "output": {
        "image": {
            "type": "image",
            "source": "base64",
            "path": "output"
        }
    }
}
```

‎download.py

+11 lines

```python
from pyannote.audio import Model
import os

HF_TOKEN = os.getenv('HF_TOKEN')

def download_model():
    """Download the pretrained model so its weights are cached in the image at build time."""
    model = Model.from_pretrained("pyannote/segmentation", use_auth_token=HF_TOKEN)

if __name__ == "__main__":
    download_model()
```
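This script exists only to warm the model cache while the Docker image is built (the `RUN python3 download.py` step in the Dockerfile), so that `init()` in `app.py` does not have to download weights at container start. A minimal sketch to observe that caching locally, assuming network access and a valid `HF_TOKEN`:

```python
# Sketch: the first load downloads the weights, the second should hit the
# local cache and return noticeably faster. Assumes HF_TOKEN is set.
import os
import time
from pyannote.audio import Model

HF_TOKEN = os.getenv("HF_TOKEN")

for label in ("cold", "warm"):
    start = time.time()
    Model.from_pretrained("pyannote/segmentation", use_auth_token=HF_TOKEN)
    print(f"{label} load took {time.time() - start:.1f}s")
```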

‎requirements.txt

+4 lines

```
potassium>=0.0.9
torch
transformers
https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
```
