Speculative decoding #73

Open
wants to merge 1 commit into base-sha/0156696f3575079bb18d42841220aef5b85d54ef
266 changes: 153 additions & 113 deletions ttt.py
@@ -11,72 +11,71 @@
import requests
import scrubadub
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

from transformers import (
pipeline,
AutoModelForCausalLM,
AutoModelForSpeechSeq2Seq,
AutoProcessor,
)
Comment on lines +15 to +20

suggestion (code_refinement): Consider consolidating imports for clarity and maintainability.

Grouping related imports in a single block can improve readability and make the codebase easier to manage.

Suggested change
from transformers import (
pipeline,
AutoModelForCausalLM,
AutoModelForSpeechSeq2Seq,
AutoProcessor,
)
from transformers import (
AutoModelForCausalLM,
AutoModelForSpeechSeq2Seq,
AutoProcessor,
pipeline,
)
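
As an aside (not part of the review thread), an import sorter would enforce this ordering automatically. A minimal sketch, assuming isort is installed:

import isort  # assumption: isort is available in the environment

# Programmatic equivalent of running `isort ttt.py` from the shell; isort
# alphabetizes the names inside the parenthesized from-import block.
isort.file("ttt.py")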

# Before we dig in, let's globally set up transformers.
# We load the model and processor now so that functions
# only need to use the PIPE constant.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = os.environ.get("TTT_TRANSFORMERS_MODEL_ID", "openai/whisper-large-v3")
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
PIPE = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)


def transcribe_whispercpp(calljson, audiofile):
"""Transcribes audio file using whisper.cpp.

Args:
calljson (dict): A dictionary containing the JSON data.
audiofile (Path): The path to the audio file.

Returns:
dict: The updated calljson dictionary with the transcript.

Explanation:
This function sends the audio file to whisper.cpp for transcription. It constructs a multipart/form-data
request with the audio file and other parameters. The response from whisper.cpp is parsed as JSON and
merged into the calljson dictionary. The updated calljson dictionary is then returned.
"""
whisper_url = os.environ.get("TTT_WHISPERCPP_URL", "http://whisper:8080")

# Now send the files over to whisper for transcribing
files = {
"file": (None, audiofile.read_bytes()),
"temperature": (None, "0.0"),
"temperature_inc": (None, "0.2"),
"response_format": (None, "json"),
}

try:
response = requests.post(f"{whisper_url}/inference", files=files)
response.raise_for_status()
except requests.exceptions.RequestException as e:
print(f"A request error occurred while trying to post to whisper.cpp: {e}")
raise RuntimeError(
"A request error occurred while trying to post to whisper.cpp."
) from e

calltext = response.json()

# Merge that dict into calljson so calljson["text"] holds the transcript
calljson = {**calljson, **calltext}
return calljson
# If we set TTT_TRANSFORMERS_MODEL_ID, let's use that directly
if os.environ.get("TTT_TRANSFORMERS_MODEL_ID", False):

suggestion (code_clarification): The default value in os.environ.get should be consistent with expected data types.

Using False as a default for an environment variable that is expected to be a string can be misleading. Consider using None or a default string value.

Suggested change
if os.environ.get("TTT_TRANSFORMERS_MODEL_ID", False):
if os.environ.get("TTT_TRANSFORMERS_MODEL_ID", None):
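
Worth noting: os.environ.get already returns None when the key is absent, and None is falsy, so the default argument could be dropped entirely. A minimal sketch (not part of the diff):

import os

# os.environ.get returns None for a missing key, so the truthiness check
# works without an explicit default:
if os.environ.get("TTT_TRANSFORMERS_MODEL_ID"):
    ...  # model-specific setup goes here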

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
print(f"We are using {torch_dtype} on {device}")
model_id = os.environ.get("TTT_TRANSFORMERS_MODEL_ID", "openai/whisper-large-v3")
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
PIPE = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
# If no model is set, default to the best/fastest combination:
# speculative decoding with large-v3 as the main model and distil-large-v3 as the assistant
else:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
print(f"We are using {torch_dtype} on {device}")
assistant_model_id = "distil-whisper/distil-large-v3"
assistant_model = AutoModelForCausalLM.from_pretrained(
assistant_model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
)
assistant_model.to(device)
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
PIPE = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
generate_kwargs={"assistant_model": assistant_model},
torch_dtype=torch_dtype,
device=device,
)
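
For context, a minimal usage sketch of the pipeline configured above; the file name is a placeholder, not part of the PR:

# Hypothetical usage: the ASR pipeline accepts an audio file path and returns
# a dict whose "text" key holds the transcript. With
# generate_kwargs={"assistant_model": assistant_model}, generation runs
# speculative decoding: distil-large-v3 drafts candidate tokens cheaply and
# large-v3 verifies them, so output matches large-v3 while decoding faster.
result = PIPE("call.wav")  # placeholder path
print(result["text"])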


def transcribe_transformers(calljson, audiofile):
@@ -103,58 +102,6 @@ def transcribe_transformers(calljson, audiofile):
return calljson


def transcribe_deepgram(calljson, audiofile):
"""Transcribes audio file using Deepgram API.

Args:
calljson (dict): A dictionary containing the JSON data.
audiofile (Path): The path to the audio file.

Returns:
dict: The updated calljson dictionary with the transcript.

Explanation:
This function sends the audio file to the Deepgram API for transcription. It constructs a POST request
with the audio file and necessary headers. The response from Deepgram is parsed as JSON, and the
transcript is extracted and added to the calljson dictionary. The updated calljson dictionary is then
returned.
"""
deepgram_key = os.environ.get("TTT_DEEPGRAM_KEY")
headers = {
"Authorization": f"Token {deepgram_key}",
"Content-Type": "audio/wav",
}
params = {
"model": "nova-2-phonecall",
"language": "en-US",
"smart_format": "true",
}

data = audiofile.read_bytes()
try:
response = requests.post(
"https://api.deepgram.com/v1/listen",
params=params,
headers=headers,
data=data,
)
response.raise_for_status()
except requests.exceptions.RequestException as e:
print(f"A request error occurred while trying to post to Deepgram: {e}")
raise RuntimeError(
"A request error occurred while trying to post to Deepgram."
) from e

json = response.json()

# Pull the "transcript" out of the Deepgram JSON and attach it to
# calljson as "text", the key whisper normally uses
calltext = json["results"]["channels"][0]["alternatives"][0]["transcript"]
calljson["text"] = calltext
return calljson


def send_notifications(calljson, audiofile, destinations):
"""
Sends notifications using the provided calljson, audiofile, and destinations.
@@ -387,5 +334,98 @@ def main():
print(f"No permission to delete {audiofile}.")


def transcribe_whispercpp(calljson, audiofile):
"""Transcribes audio file using whisper.cpp.

Args:
calljson (dict): A dictionary containing the JSON data.
audiofile (Path): The path to the audio file.

Returns:
dict: The updated calljson dictionary with the transcript.

Explanation:
This function sends the audio file to whisper.cpp for transcription. It constructs a multipart/form-data
request with the audio file and other parameters. The response from whisper.cpp is parsed as JSON and
merged into the calljson dictionary. The updated calljson dictionary is then returned.
"""
whisper_url = os.environ.get("TTT_WHISPERCPP_URL", "http://whisper:8080")

# Now send the files over to whisper for transcribing
files = {
"file": (None, audiofile.read_bytes()),
"temperature": (None, "0.0"),
"temperature_inc": (None, "0.2"),
"response_format": (None, "json"),
}

try:
response = requests.post(f"{whisper_url}/inference", files=files)
response.raise_for_status()
except requests.exceptions.RequestException as e:
print(f"A request error occurred while trying to post to whisper.cpp: {e}")
raise RuntimeError(
"A request error occurred while trying to post to whisper.cpp."
) from e

calltext = response.json()

# Merge that dict into calljson so calljson["text"] holds the transcript
calljson = {**calljson, **calltext}
return calljson
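
A hypothetical usage sketch for illustration; the metadata key and file name are placeholders, not from this PR:

from pathlib import Path

# whisper.cpp returns {"text": "..."} for response_format=json, which the
# merge above exposes as call["text"].
call = {"talkgroup": "example"}  # placeholder metadata
call = transcribe_whispercpp(call, Path("call.wav"))  # placeholder path
print(call["text"])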


def transcribe_deepgram(calljson, audiofile):
"""Transcribes audio file using Deepgram API.

Args:
calljson (dict): A dictionary containing the JSON data.
audiofile (Path): The path to the audio file.

Returns:
dict: The updated calljson dictionary with the transcript.

Explanation:
This function sends the audio file to the Deepgram API for transcription. It constructs a POST request
with the audio file and necessary headers. The response from Deepgram is parsed as JSON, and the
transcript is extracted and added to the calljson dictionary. The updated calljson dictionary is then
returned.
"""
deepgram_key = os.environ.get("TTT_DEEPGRAM_KEY")
headers = {
"Authorization": f"Token {deepgram_key}",
"Content-Type": "audio/wav",
}
params = {
"model": "nova-2-phonecall",
"language": "en-US",
"smart_format": "true",
}

data = audiofile.read_bytes()
try:
response = requests.post(
"https://api.deepgram.com/v1/listen",
params=params,
headers=headers,
data=data,
)
response.raise_for_status()
except requests.exceptions.RequestException as e:
print(f"A request error occurred while trying to post to Deepgram: {e}")
raise RuntimeError(
"A request error occurred while trying to post to Deepgram."
) from e

json = response.json()

# Pull the "transcript" out of the Deepgram JSON and attach it to
# calljson as "text", the key whisper normally uses
calltext = json["results"]["channels"][0]["alternatives"][0]["transcript"]
calljson["text"] = calltext
return calljson
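
For reference, an abridged sketch of the response shape the indexing above assumes; values are illustrative:

# Abridged, illustrative Deepgram response:
# {
#   "results": {
#     "channels": [
#       {"alternatives": [{"transcript": "example transcript"}]}
#     ]
#   }
# }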


if __name__ == "__main__":
main()