Speculative decoding #73
base: base-sha/0156696f3575079bb18d42841220aef5b85d54ef
```diff
@@ -11,72 +11,71 @@
 import requests
 import scrubadub
 import torch
-from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+from transformers import (
+    pipeline,
+    AutoModelForCausalLM,
+    AutoModelForSpeechSeq2Seq,
+    AutoProcessor,
+)
 
-# Before we dig in, let's globally set up transformers
-# We will load up the model, etc now so we only need to
-# use the PIPE constant in the function.
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-model_id = os.environ.get("TTT_TRANSFORMERS_MODEL_ID", "openai/whisper-large-v3")
-model = AutoModelForSpeechSeq2Seq.from_pretrained(
-    model_id,
-    torch_dtype=torch_dtype,
-    low_cpu_mem_usage=True,
-    use_safetensors=True,
-)
-model.to(device)
-processor = AutoProcessor.from_pretrained(model_id)
-PIPE = pipeline(
-    "automatic-speech-recognition",
-    model=model,
-    tokenizer=processor.tokenizer,
-    feature_extractor=processor.feature_extractor,
-    max_new_tokens=128,
-    torch_dtype=torch_dtype,
-    device=device,
-)
-
-
-def transcribe_whispercpp(calljson, audiofile):
-    """Transcribes audio file using whisper.cpp.
-
-    Args:
-        calljson (dict): A dictionary containing the JSON data.
-        audiofile (Path): The path to the audio file.
-
-    Returns:
-        dict: The updated calljson dictionary with the transcript.
-
-    Explanation:
-        This function sends the audio file to whisper.cpp for transcription. It constructs a multipart/form-data
-        request with the audio file and other parameters. The response from whisper.cpp is parsed as JSON and
-        merged into the calljson dictionary. The updated calljson dictionary is then returned.
-    """
-    whisper_url = os.environ.get("TTT_WHISPERCPP_URL", "http://whisper:8080")
-
-    # Now send the files over to whisper for transcribing
-    files = {
-        "file": (None, audiofile.read_bytes()),
-        "temperature": (None, "0.0"),
-        "temperature_inc": (None, "0.2"),
-        "response_format": (None, "json"),
-    }
-
-    try:
-        response = requests.post(f"{whisper_url}/inference", files=files)
-        response.raise_for_status()
-    except requests.exceptions.RequestException as e:
-        print(f"A request error occurred while trying to post to whisper.cpp: {e}")
-        raise RuntimeError(
-            "A request error occurred while trying to post to whisper.cpp."
-        ) from e
-
-    calltext = response.json()
-
-    # And now merge that dict into calljson so [text] in calljson is the transcript
-    calljson = {**calljson, **calltext}
-    return calljson
+# If we set TTT_TRANSFORMERS_MODEL_ID, let's use that directly
+if os.environ.get("TTT_TRANSFORMERS_MODEL_ID", False):
```
**suggestion (code_clarification):** The default value in `os.environ.get("TTT_TRANSFORMERS_MODEL_ID", False)` is redundant: `os.environ.get` already returns `None` when the variable is unset, and `None` is just as falsy as `False` in this `if` check. Using the one-argument form reads more clearly.

Suggested change
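The suggested-change snippet itself did not survive extraction; presumably it amounted to dropping the redundant default, along these lines (hypothetical reconstruction):

```diff
-if os.environ.get("TTT_TRANSFORMERS_MODEL_ID", False):
+if os.environ.get("TTT_TRANSFORMERS_MODEL_ID"):
```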
```diff
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+    print(f"We are using {torch_dtype} on {device}")
+    model_id = os.environ.get("TTT_TRANSFORMERS_MODEL_ID", "openai/whisper-large-v3")
+    model = AutoModelForSpeechSeq2Seq.from_pretrained(
+        model_id,
+        torch_dtype=torch_dtype,
+        low_cpu_mem_usage=True,
+        use_safetensors=True,
+    )
+    model.to(device)
+    processor = AutoProcessor.from_pretrained(model_id)
+    PIPE = pipeline(
+        "automatic-speech-recognition",
+        model=model,
+        tokenizer=processor.tokenizer,
+        feature_extractor=processor.feature_extractor,
+        max_new_tokens=128,
+        torch_dtype=torch_dtype,
+        device=device,
+    )
+# If we don't set a model, let's use the combo of best / fastest.
+# Speculative decoding with large-v3 / distil-v3
+else:
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+    print(f"We are using {torch_dtype} on {device}")
+    assistant_model_id = "distil-whisper/distil-large-v3"
+    assistant_model = AutoModelForCausalLM.from_pretrained(
+        assistant_model_id,
+        torch_dtype=torch_dtype,
+        low_cpu_mem_usage=True,
+        use_safetensors=True,
+    )
+    assistant_model.to(device)
+    model_id = "openai/whisper-large-v3"
+    model = AutoModelForSpeechSeq2Seq.from_pretrained(
+        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+    )
+    model.to(device)
+    processor = AutoProcessor.from_pretrained(model_id)
+    PIPE = pipeline(
+        "automatic-speech-recognition",
+        model=model,
+        tokenizer=processor.tokenizer,
+        feature_extractor=processor.feature_extractor,
+        max_new_tokens=128,
+        generate_kwargs={"assistant_model": assistant_model},
+        torch_dtype=torch_dtype,
+        device=device,
+    )
 
 
 def transcribe_transformers(calljson, audiofile):
```
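For context on the new `else` branch: in speculative decoding, the small distil-large-v3 draft model proposes several tokens ahead, and whisper-large-v3 verifies them in a single forward pass, so the output matches what the large model would produce on its own while decoding runs substantially faster. The pipeline wires this up through `generate_kwargs={"assistant_model": assistant_model}`; below is a minimal sketch of the same mechanism at the `generate()` level, reusing the names defined in the branch above (the librosa loading and `"sample.wav"` are illustrative assumptions, not part of this PR):

```python
# Minimal sketch, assuming model, assistant_model, processor, device and
# torch_dtype from the else branch above. librosa and "sample.wav" are
# illustrative assumptions.
import librosa

audio, _ = librosa.load("sample.wav", sr=16000)
input_features = processor(
    audio, sampling_rate=16000, return_tensors="pt"
).input_features.to(device, dtype=torch_dtype)

# The draft model proposes candidate tokens; the main model verifies them
# and falls back to its own decoding wherever they disagree.
generated_ids = model.generate(
    input_features, assistant_model=assistant_model, max_new_tokens=128
)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```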
```diff
@@ -103,58 +102,6 @@ def transcribe_transformers(calljson, audiofile):
     return calljson
 
 
-def transcribe_deepgram(calljson, audiofile):
-    """Transcribes audio file using Deepgram API.
-
-    Args:
-        calljson (dict): A dictionary containing the JSON data.
-        audiofile (Path): The path to the audio file.
-
-    Returns:
-        dict: The updated calljson dictionary with the transcript.
-
-    Explanation:
-        This function sends the audio file to the Deepgram API for transcription. It constructs a POST request
-        with the audio file and necessary headers. The response from Deepgram is parsed as JSON, and the
-        transcript is extracted and added to the calljson dictionary. The updated calljson dictionary is then
-        returned.
-    """
-    deepgram_key = os.environ.get("TTT_DEEPGRAM_KEY")
-    headers = {
-        "Authorization": f"Token {deepgram_key}",
-        "Content-Type": "audio/wav",
-    }
-    params = {
-        "model": "nova-2-phonecall",
-        "language": "en-US",
-        "smart_format": "true",
-    }
-
-    data = audiofile.read_bytes()
-    try:
-        response = requests.post(
-            "https://api.deepgram.com/v1/listen",
-            params=params,
-            headers=headers,
-            data=data,
-        )
-        response.raise_for_status()
-    except requests.exceptions.RequestException as e:
-        print(f"A request error occurred while trying to post to Deepgram: {e}")
-        raise RuntimeError(
-            "A request error occurred while trying to post to Deepgram."
-        ) from e
-
-    json = response.json()
-
-    # We take the json returned from deepgram and pull out the "transcript"
-    # then tack it onto the calljson dict as "text" which is what whisper
-    # normally uses
-    calltext = json["results"]["channels"][0]["alternatives"][0]["transcript"]
-    calljson["text"] = calltext
-    return calljson
-
-
 def send_notifications(calljson, audiofile, destinations):
     """
     Sends notifications using the provided calljson, audiofile, and destinations.
```
```diff
@@ -387,5 +334,98 @@ def main():
         print(f"No permission to delete {audiofile}.")
 
 
+def transcribe_whispercpp(calljson, audiofile):
+    """Transcribes audio file using whisper.cpp.
+
+    Args:
+        calljson (dict): A dictionary containing the JSON data.
+        audiofile (Path): The path to the audio file.
+
+    Returns:
+        dict: The updated calljson dictionary with the transcript.
+
+    Explanation:
+        This function sends the audio file to whisper.cpp for transcription. It constructs a multipart/form-data
+        request with the audio file and other parameters. The response from whisper.cpp is parsed as JSON and
+        merged into the calljson dictionary. The updated calljson dictionary is then returned.
+    """
+    whisper_url = os.environ.get("TTT_WHISPERCPP_URL", "http://whisper:8080")
+
+    # Now send the files over to whisper for transcribing
+    files = {
+        "file": (None, audiofile.read_bytes()),
+        "temperature": (None, "0.0"),
+        "temperature_inc": (None, "0.2"),
+        "response_format": (None, "json"),
+    }
+
+    try:
+        response = requests.post(f"{whisper_url}/inference", files=files)
+        response.raise_for_status()
+    except requests.exceptions.RequestException as e:
+        print(f"A request error occurred while trying to post to whisper.cpp: {e}")
+        raise RuntimeError(
+            "A request error occurred while trying to post to whisper.cpp."
+        ) from e
+
+    calltext = response.json()
+
+    # And now merge that dict into calljson so [text] in calljson is the transcript
+    calljson = {**calljson, **calltext}
+    return calljson
```
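A hedged usage sketch of the moved function (the `calljson` fields and file name are illustrative placeholders, not the project's actual schema):

```python
# Illustrative usage only; real calljson dicts are built elsewhere in
# this file, and these field names are invented for the example.
from pathlib import Path

calljson = {"talkgroup": 101}
calljson = transcribe_whispercpp(calljson, Path("example_call.wav"))
# whisper.cpp's JSON response was merged in, so the transcript is here:
print(calljson["text"])
```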
```diff
+def transcribe_deepgram(calljson, audiofile):
+    """Transcribes audio file using Deepgram API.
+
+    Args:
+        calljson (dict): A dictionary containing the JSON data.
+        audiofile (Path): The path to the audio file.
+
+    Returns:
+        dict: The updated calljson dictionary with the transcript.
+
+    Explanation:
+        This function sends the audio file to the Deepgram API for transcription. It constructs a POST request
+        with the audio file and necessary headers. The response from Deepgram is parsed as JSON, and the
+        transcript is extracted and added to the calljson dictionary. The updated calljson dictionary is then
+        returned.
+    """
+    deepgram_key = os.environ.get("TTT_DEEPGRAM_KEY")
+    headers = {
+        "Authorization": f"Token {deepgram_key}",
+        "Content-Type": "audio/wav",
+    }
+    params = {
+        "model": "nova-2-phonecall",
+        "language": "en-US",
+        "smart_format": "true",
+    }
+
+    data = audiofile.read_bytes()
+    try:
+        response = requests.post(
+            "https://api.deepgram.com/v1/listen",
+            params=params,
+            headers=headers,
+            data=data,
+        )
+        response.raise_for_status()
+    except requests.exceptions.RequestException as e:
+        print(f"A request error occurred while trying to post to Deepgram: {e}")
+        raise RuntimeError(
+            "A request error occurred while trying to post to Deepgram."
+        ) from e
+
+    json = response.json()
+
+    # We take the json returned from deepgram and pull out the "transcript"
+    # then tack it onto the calljson dict as "text" which is what whisper
+    # normally uses
+    calltext = json["results"]["channels"][0]["alternatives"][0]["transcript"]
+    calljson["text"] = calltext
+    return calljson
 
 
 if __name__ == "__main__":
     main()
```
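The nested indexing in `transcribe_deepgram` assumes Deepgram's standard pre-recorded response shape; a trimmed illustration (transcript and confidence values are invented):

```python
# Trimmed illustration of the response shape the indexing relies on.
response_json = {
    "results": {
        "channels": [
            {
                "alternatives": [
                    {"transcript": "engine four responding", "confidence": 0.97}
                ]
            }
        ]
    }
}
assert (
    response_json["results"]["channels"][0]["alternatives"][0]["transcript"]
    == "engine four responding"
)
```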
**suggestion (code_refinement):** Consider consolidating imports for clarity and maintainability. Grouping related imports in a single block can improve readability and make the codebase easier to manage.
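One way the consolidation could look, assuming these are the file's only imports (stdlib first, then third-party, following PEP 8 grouping):

```python
# A sketch of the consolidated import block the comment suggests; it
# assumes the file imports nothing beyond what appears in this diff.
import os

import requests
import scrubadub
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
```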