Skip to content

Commit

Permalink
Merge branch 'main' into cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
jquagga committed Oct 15, 2024
2 parents 113d5f7 + 4ffb81c commit b48e745
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 16 deletions.
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ channels:
dependencies:
- python=3.12
- ffmpeg
- gcc
- pip
- pip:
- --index-url https://download.pytorch.org/whl/cpu
Expand Down
19 changes: 4 additions & 15 deletions ttt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import requests
import torch
from better_profanity import profanity
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
Expand All @@ -25,7 +24,7 @@
# Before we dig in, let's globally set up transformers
# We will load up the model, etc now so we only need to
# use the PIPE constant in the function.
torch.set_float32_matmul_precision("high")

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = os.environ.get("TTT_TRANSFORMERS_MODEL_ID", "openai/whisper-large-v3-turbo")
Expand All @@ -37,12 +36,6 @@
use_safetensors=True,
)
model.to(device)

# Enable static cache and compile the forward pass
model.generation_config.cache_implementation = "static"
model.generation_config.max_new_tokens = 256
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

processor = AutoProcessor.from_pretrained(model_id)
PIPE = pipeline(
"automatic-speech-recognition",
Expand Down Expand Up @@ -72,13 +65,9 @@ def transcribe_transformers(calljson, audiofile):

audiofile = str(audiofile)

# Set the return argument to english & return timestamps to support
# calls over 30 seconds.
with sdpa_kernel(SDPBackend.MATH):
result = PIPE(
audiofile,
generate_kwargs={"return_timestamps": True},
)
# Set the return argument to english
# result = PIPE(audiofile, generate_kwargs={"language": "english"})
result = PIPE(audiofile, generate_kwargs={"return_timestamps": True})
calljson["text"] = result["text"]
return calljson

Expand Down

0 comments on commit b48e745

Please sign in to comment.