Skip to content

Commit

Permalink
Improve device detection
Browse files Browse the repository at this point in the history
  • Loading branch information
ertrzyiks committed Jan 21, 2025
1 parent b63fab2 commit be416be
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 27 deletions.
77 changes: 56 additions & 21 deletions src/components/search/search.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import { MagnifyingGlassIcon, MicrophoneIcon } from "@heroicons/react/20/solid";

import { useWhisperWorker } from "./use-whisper-worker";
import { useAudioInput } from "./use-audio-input";
import { use } from "chai";

const queryClient = new QueryClient();

Expand All @@ -36,6 +35,36 @@ function useSearch({ query }: { query: string }) {

const WHISPER_SAMPLING_RATE = 16_000;

/**
 * Detect whether WebGPU is actually usable in the current environment.
 *
 * Checking for `navigator.gpu` alone is not sufficient: some platforms expose
 * the API surface but refuse to hand out an adapter (blocklisted hardware,
 * software rendering, etc.), so we request an adapter for a definitive answer.
 * Never throws — every failure mode is reported via the `reason` field.
 */
async function isWebGPUSupported(): Promise<{
  supported: boolean;
  reason: string;
}> {
  // Guard `navigator` itself for non-browser contexts (SSR, some workers),
  // where referencing it would throw a ReferenceError.
  if (typeof navigator === "undefined" || !("gpu" in navigator)) {
    return {
      supported: false,
      reason: "WebGPU is not supported on this browser.",
    };
  }

  try {
    // `gpu` is not yet in the default TS DOM lib; keep the cast scoped to
    // this single call rather than widening `navigator` elsewhere.
    const adapter = await (navigator as any).gpu.requestAdapter();
    if (!adapter) {
      return {
        supported: false,
        reason:
          "WebGPU is not available due to platform restrictions or lack of hardware support.",
      };
    }

    // An adapter was granted, so WebGPU is genuinely usable.
    return { supported: true, reason: "WebGPU is supported and enabled." };
  } catch (error) {
    // requestAdapter may reject with a non-Error value; stringify safely
    // instead of asserting `as Error` (which yields "undefined" for strings).
    const message = error instanceof Error ? error.message : String(error);
    return {
      supported: false,
      reason: `WebGPU initialization failed: ${message}`,
    };
  }
}

interface Props {
query: string;
onChange: (value: string) => void;
Expand All @@ -54,7 +83,9 @@ function SearchForm({
results = [],
}: Props) {
const searchInputRef = useRef<HTMLInputElement>(null);
const [isWebGPUAvailable, setIsWebGPUAvailable] = useState(false);
const [isWebGPUAvailable, setIsWebGPUAvailable] = useState<boolean | null>(
null,
);
const { startRecording, blob } = useAudioInput();
const { processAudio, loadModels, status, text } = useWhisperWorker();

Expand All @@ -65,14 +96,18 @@ function SearchForm({
}, [status]);

useEffect(() => {
if (blob && status === "ready") {
if (blob && status === "ready" && isWebGPUAvailable !== null) {
const audioContext = new AudioContext({
sampleRate: WHISPER_SAMPLING_RATE,
});

processAudio(blob, audioContext);
processAudio(
blob,
audioContext,
isWebGPUAvailable ? "webgpu" : undefined,
);
}
}, [blob, status]);
}, [blob, status, isWebGPUAvailable]);

const handleChange = (input: { value: string }) => {
if (input) {
Expand All @@ -82,7 +117,9 @@ function SearchForm({
};

useEffect(() => {
setIsWebGPUAvailable(!!(navigator as any).gpu);
isWebGPUSupported().then(({ supported }) => {
setIsWebGPUAvailable(supported);
});
}, []);

useEffect(() => {
Expand All @@ -96,21 +133,19 @@ function SearchForm({
return (
<div className="top-16 w-full">
<div className="w-full flex" suppressHydrationWarning>
{isWebGPUAvailable ? (
<button
onClick={() => startRecording()}
className={[
"border-transparent",
status === "ready" ? "border-b-green-400" : "",
"border rounded-full mr-2",
].join(" ")}
>
<MicrophoneIcon
className="h-5 w-5 text-inherit"
aria-label="Use microphone to dictate search query"
/>
</button>
) : null}
<button
onClick={() => startRecording()}
className={[
"border-transparent",
status === "ready" ? "border-b-green-300" : "",
"border rounded-full mr-2",
].join(" ")}
>
<MicrophoneIcon
className="h-5 w-5 text-inherit"
aria-label="Use microphone to dictate search query"
/>
</button>

<div className="flex-grow">
<Combobox<null | { value: string }>
Expand Down
4 changes: 2 additions & 2 deletions src/components/search/use-whisper-worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ export function useWhisperWorker() {
};
}, []);

const processAudio = async (blob, audioContext) => {
const processAudio = async (blob, audioContext, device) => {
const fileReader = new FileReader();

fileReader.onloadend = async () => {
Expand All @@ -120,7 +120,7 @@ export function useWhisperWorker() {

worker.current.postMessage({
type: "generate",
data: { audio, language },
data: { audio, language, device },
});
};
fileReader.readAsArrayBuffer(blob);
Expand Down
9 changes: 5 additions & 4 deletions src/components/search/worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
WhisperForConditionalGeneration,
TextStreamer,
full,
env,
} from "@huggingface/transformers";

const MAX_NEW_TOKENS = 64;
Expand All @@ -19,7 +20,7 @@ class AutomaticSpeechRecognitionPipeline {
static processor = null;
static model = null;

static async getInstance(progress_callback = null) {
static async getInstance(progress_callback = null, { device } = {}) {
this.model_id = "onnx-community/whisper-base";

this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
Expand All @@ -36,7 +37,7 @@ class AutomaticSpeechRecognitionPipeline {
encoder_model: "fp32", // 'fp16' works too
decoder_model_merged: "q4", // or 'fp32' ('fp16' is broken)
},
device: "webgpu",
device,
progress_callback,
},
);
Expand All @@ -46,7 +47,7 @@ class AutomaticSpeechRecognitionPipeline {
}

let processing = false;
async function generate({ audio, language }) {
async function generate({ audio, language, device }) {
if (processing) return;
processing = true;

Expand All @@ -55,7 +56,7 @@ async function generate({ audio, language }) {

// Retrieve the text-generation pipeline.
const [tokenizer, processor, model] =
await AutomaticSpeechRecognitionPipeline.getInstance();
await AutomaticSpeechRecognitionPipeline.getInstance(null, { device });

let startTime;
let numTokens = 0;
Expand Down

0 comments on commit be416be

Please sign in to comment.