Skip to content

Commit 2d16dc9

Browse files
authored
[webgpu] Check if runtime support WebGPU before initial a WebGPU backend (#5218)
1 parent 8a1fd30 commit 2d16dc9

File tree

3 files changed

+51
-36
lines changed

3 files changed

+51
-36
lines changed

tfjs-backend-webgpu/src/backend_webgpu.ts

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,6 @@ export interface WebGPUTimingInfo extends TimingInfo {
7373
const CPU_HANDOFF_SIZE_THRESHOLD =
7474
env().getNumber('CPU_HANDOFF_SIZE_THRESHOLD');
7575

76-
const DEFAULT_GPUBUFFER_USAGE =
77-
GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;
78-
7976
export class WebGPUBackend extends KernelBackend {
8077
device: GPUDevice;
8178
queue: GPUQueue;
@@ -109,6 +106,9 @@ export class WebGPUBackend extends KernelBackend {
109106

110107
constructor(device: GPUDevice, glslang: Glslang, supportTimeQuery = false) {
111108
super();
109+
if (!webgpu_util.isWebGPUSupported()) {
110+
throw new Error('WebGPU is not supported on this device');
111+
}
112112
this.layoutCache = {};
113113
this.pipelineCache = {};
114114
this.device = device;
@@ -133,6 +133,11 @@ export class WebGPUBackend extends KernelBackend {
133133
return 32;
134134
}
135135

136+
defaultGpuBufferUsage(): number {
137+
return GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC |
138+
GPUBufferUsage.COPY_DST;
139+
}
140+
136141
flushDisposalQueue() {
137142
this.tensorDisposalQueue.forEach(d => {
138143
this.maybeReleaseBuffer(d);
@@ -191,7 +196,8 @@ export class WebGPUBackend extends KernelBackend {
191196
}
192197

193198
acquireBuffer(
194-
byteSize: number, usage: GPUBufferUsageFlags = DEFAULT_GPUBUFFER_USAGE) {
199+
byteSize: number,
200+
usage: GPUBufferUsageFlags = this.defaultGpuBufferUsage()) {
195201
return this.bufferManager.acquireBuffer(byteSize, usage);
196202
}
197203

@@ -248,7 +254,7 @@ export class WebGPUBackend extends KernelBackend {
248254
this.tensorMap.set(dataId, {
249255
dtype,
250256
values,
251-
bufferInfo: {byteSize, usage: DEFAULT_GPUBUFFER_USAGE},
257+
bufferInfo: {byteSize, usage: this.defaultGpuBufferUsage()},
252258
refCount: 1
253259
});
254260
return dataId;
@@ -268,7 +274,7 @@ export class WebGPUBackend extends KernelBackend {
268274
this.tensorMap.set(dataId, {
269275
dtype,
270276
values,
271-
bufferInfo: {byteSize, usage: DEFAULT_GPUBUFFER_USAGE},
277+
bufferInfo: {byteSize, usage: this.defaultGpuBufferUsage()},
272278
refCount
273279
});
274280
}

tfjs-backend-webgpu/src/index.ts

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,40 +18,42 @@
1818
import './flags_webgpu';
1919
import './register_all_kernels';
2020

21-
import {env, registerBackend} from '@tensorflow/tfjs-core';
21+
import {device_util, env, registerBackend} from '@tensorflow/tfjs-core';
2222
import glslangInit from '@webgpu/glslang/dist/web-devel/glslang.onefile';
2323

2424
import {WebGPUBackend} from './backend_webgpu';
2525
import * as webgpu from './webgpu';
26-
27-
registerBackend('webgpu', async () => {
28-
// Remove it once we figure out how to correctly read the tensor data before
29-
// the tensor is disposed in profiling mode.
30-
env().set('CHECK_COMPUTATION_FOR_ERRORS', false);
31-
32-
const glslang = await glslangInit();
33-
const gpuDescriptor: GPURequestAdapterOptions = {
34-
powerPreference: env().get('WEBGPU_USE_LOW_POWER_GPU') ? 'low-power' :
35-
'high-performance'
36-
};
37-
38-
const adapter = await navigator.gpu.requestAdapter(gpuDescriptor);
39-
let deviceDescriptor: GPUDeviceDescriptor = {};
40-
const supportTimeQuery = adapter.features.has('timestamp-query');
41-
42-
if (supportTimeQuery) {
43-
deviceDescriptor = {
44-
nonGuaranteedFeatures: ['timestamp-query' as const]
26+
import {isWebGPUSupported} from './webgpu_util';
27+
28+
if (device_util.isBrowser() && isWebGPUSupported()) {
29+
registerBackend('webgpu', async () => {
30+
// Remove it once we figure out how to correctly read the tensor data
31+
// before the tensor is disposed in profiling mode.
32+
env().set('CHECK_COMPUTATION_FOR_ERRORS', false);
33+
34+
const glslang = await glslangInit();
35+
const gpuDescriptor: GPURequestAdapterOptions = {
36+
powerPreference: env().get('WEBGPU_USE_LOW_POWER_GPU') ?
37+
'low-power' :
38+
'high-performance'
4539
};
46-
} else {
47-
console.warn(
48-
`This device doesn't support timestamp-query extension. ` +
49-
`Zero will shown for the kernel time when profiling mode is enabled. ` +
50-
`Using performance.now is not workable for webgpu since it doesn't ` +
51-
`support synchronously to read data from GPU.`);
52-
}
53-
const device: GPUDevice = await adapter.requestDevice(deviceDescriptor);
54-
return new WebGPUBackend(device, glslang, supportTimeQuery);
55-
}, 3 /*priority*/);
40+
41+
const adapter = await navigator.gpu.requestAdapter(gpuDescriptor);
42+
let deviceDescriptor: GPUDeviceDescriptor = {};
43+
const supportTimeQuery = adapter.features.has('timestamp-query');
44+
45+
if (supportTimeQuery) {
46+
deviceDescriptor = {nonGuaranteedFeatures: ['timestamp-query' as const ]};
47+
} else {
48+
console.warn(
49+
`This device doesn't support timestamp-query extension. ` +
50+
`Zero will shown for the kernel time when profiling mode is` +
51+
`enabled. Using performance.now is not workable for webgpu since` +
52+
`it doesn't support synchronously to read data from GPU.`);
53+
}
54+
const device: GPUDevice = await adapter.requestDevice(deviceDescriptor);
55+
return new WebGPUBackend(device, glslang, supportTimeQuery);
56+
}, 3 /*priority*/);
57+
}
5658

5759
export {webgpu};

tfjs-backend-webgpu/src/webgpu_util.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,10 @@ export function ArrayBufferToTypedArray(data: ArrayBuffer, dtype: DataType) {
152152
throw new Error(`Unknown dtype ${dtype}`);
153153
}
154154
}
155+
156+
export function isWebGPUSupported(): boolean {
157+
if (!navigator.gpu) {
158+
return false;
159+
}
160+
return true;
161+
}

0 commit comments

Comments
 (0)