Skip to content
This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Commit 202e7ec

Browse files
authored
webgl: add full texture cache (#159)
1 parent 09b3d36 commit 202e7ec

14 files changed

+140 -82 lines changed

lib/api/onnx.ts

+4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ export declare namespace Backend {
3030
* set or get the maximum batch size for matmul. 0 means to disable batching.
3131
*/
3232
matmulMaxBatchSize?: number;
33+
/**
34+
* set or get the texture cache mode
35+
*/
36+
textureCacheMode?: 'initializerOnly'|'full';
3337
}
3438
/**
3539
* set options for the WebAssembly backend

lib/backends/backend-webgl.ts

+4
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,17 @@ export class WebGLBackend implements Backend, WebGLOptions {
2222
glContext: WebGLContext;
2323
contextId?: 'webgl'|'webgl2';
2424
matmulMaxBatchSize?: number;
25+
textureCacheMode?: 'initializerOnly'|'full';
2526

2627
initialize(): boolean {
2728
try {
2829
this.glContext = createWebGLContext(this.contextId);
2930
if (typeof this.matmulMaxBatchSize !== 'number') {
3031
this.matmulMaxBatchSize = 16;
3132
}
33+
if (typeof this.textureCacheMode !== 'string') {
34+
this.textureCacheMode = 'full';
35+
}
3236
Logger.verbose('WebGLBackend', `Created WebGLContext: ${typeof this.glContext}`);
3337
return true;
3438
} catch (e) {

lib/backends/webgl/inference-handler.ts

+9-16
Original file line numberDiff line numberDiff line change
@@ -7,33 +7,27 @@ import {Tensor} from '../../tensor';
77
import {ShapeUtil} from '../../util';
88

99
import {WebGLUint8Encode} from './ops/uint8-encode';
10-
import {ProgramManager} from './program-manager';
1110
import {WebGLSessionHandler} from './session-handler';
1211
import {Encoder} from './texture-data-encoder';
13-
import {TextureHelper} from './texture-helper';
1412
import {WidthHeightPrefs} from './texture-layout-strategy';
1513
import {TextureData, TextureLayout, WebGLOperator} from './types';
1614
import {getPackedShape} from './utils';
1715

1816
export class WebGLInferenceHandler implements InferenceHandler {
19-
textureHelper: TextureHelper;
20-
programManager: ProgramManager;
2117
private textureDataCache: Map<Tensor.Id, TextureData>;
2218
constructor(public session: WebGLSessionHandler) {
23-
this.textureHelper = session.textureHelper;
24-
this.programManager = session.programManager;
2519
this.textureDataCache = new Map();
2620
}
2721

2822
run(op: WebGLOperator, inputs: Tensor[]): Tensor[] {
29-
let artifact = this.programManager.getArtifact(op);
23+
let artifact = this.session.programManager.getArtifact(op);
3024
if (!artifact) {
3125
const programInfo = op.createProgramInfo(this, inputs);
32-
artifact = this.programManager.build(programInfo);
33-
this.programManager.setArtifact(op, artifact);
26+
artifact = this.session.programManager.build(programInfo);
27+
this.session.programManager.setArtifact(op, artifact);
3428
}
3529
const runData = op.createRunData(this, artifact.programInfo, inputs);
36-
this.programManager.run(artifact, runData);
30+
this.session.programManager.run(artifact, runData);
3731
return [runData.outputTextureData.tensor];
3832
}
3933

@@ -90,7 +84,7 @@ export class WebGLInferenceHandler implements InferenceHandler {
9084
layout: TextureLayout, dataType: Tensor.DataType, data?: Tensor.NumberType, tensor?: Tensor,
9185
usage?: Encoder.Usage): TextureData {
9286
Logger.verbose('InferenceHandler', `Creating TextureData: layout:[${JSON.stringify(layout)}]`);
93-
const texture = this.textureHelper.createTextureFromLayout(dataType, layout, data, usage);
87+
const texture = this.session.textureManager.createTextureFromLayout(dataType, layout, data, usage);
9488
return this.createTextureDataFromTexture(layout, dataType, texture, tensor);
9589
}
9690

@@ -175,18 +169,17 @@ export class WebGLInferenceHandler implements InferenceHandler {
175169
}
176170

177171
dispose(): void {
178-
this.textureHelper.clearActiveTextures();
179-
this.textureDataCache.forEach(td => this.textureHelper.releaseTexture(td));
172+
this.session.textureManager.clearActiveTextures();
173+
this.textureDataCache.forEach(td => this.session.textureManager.releaseTexture(td));
180174
this.textureDataCache = new Map();
181175
}
182176

183177
readTexture(textureData: TextureData): Tensor.NumberType {
184178
if (!this.session.backend.glContext.isFloat32DownloadSupported) {
185179
const op = new WebGLUint8Encode();
186180
const uint8TD = op.runInternal(this, textureData);
187-
return this.textureHelper.readUint8TextureAsFloat(uint8TD);
181+
return this.session.textureManager.readUint8TextureAsFloat(uint8TD);
188182
}
189-
const values = this.textureHelper.readTexture(textureData, textureData.tensor.type, textureData.channels);
190-
return values;
183+
return this.session.textureManager.readTexture(textureData, textureData.tensor.type, textureData.channels);
191184
}
192185
}

lib/backends/webgl/ops/conv.ts

+19-21
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ import {WebGLContext} from '../webgl-context';
1212

1313
export class WebGLConv extends Conv {
1414
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
15-
const programManager = inferenceHandler.programManager;
15+
const programManager = inferenceHandler.session.programManager;
1616
if (!this.artifacts) {
1717
this.artifacts = [];
1818
const programInfos = this.createProgramInfos(inferenceHandler, inputs);
1919
for (let i = 0; i < programInfos.length; ++i) {
20-
const artifact = inferenceHandler.programManager.build(programInfos[i]);
20+
const artifact = inferenceHandler.session.programManager.build(programInfos[i]);
2121
this.artifacts.push(artifact);
2222
}
2323
}
@@ -70,40 +70,38 @@ export class WebGLConv extends Conv {
7070
inputTDs.push(inferenceHandler.getOrCreateTextureData(b));
7171
}
7272
const outputTD = inferenceHandler.createTextureDataFromLayout(programInfos[1].outputLayout, inputs[0].type);
73-
const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported;
7473
const runDataDotProduct = {
7574
inputTextureDatas: inputTDs,
7675
outputTextureData: outputTD,
7776
uniformData: {},
78-
preRun: blendEnabled ?
79-
(glContext: WebGLContext, artifact: Artifact) => {
80-
const gl = glContext.gl;
81-
gl.enable(gl.BLEND);
82-
glContext.checkError();
83-
gl.blendEquation(gl.FUNC_ADD);
84-
glContext.checkError();
85-
gl.blendFunc(gl.ONE, gl.ONE);
86-
glContext.checkError();
87-
} :
88-
undefined,
89-
postRun: blendEnabled ?
90-
(glContext: WebGLContext, artifact: Artifact) => {
91-
const gl = glContext.gl;
92-
gl.disable(gl.BLEND);
93-
glContext.checkError();
94-
} :
95-
undefined,
9677
draw: (glContext: WebGLContext, artifact: Artifact) => {
9778
const gl = glContext.gl;
9879
const sharedDim = artifact.programInfo.params!.sharedDim as number;
9980
const sharedDimReadSize = artifact.programInfo.params!.sharedDimReadSize as number;
10081
const sharedDimOffsetLocation = artifact.uniformLocations.find(l => l.name === 'sharedDimOffset')!.location;
82+
let blend = false;
10183
for (let k = 0; k < sharedDim; k += sharedDimReadSize) {
10284
Logger.verbose('MatMul2D', `k = ${k}, sharedDim: ${sharedDim}, readSize = ${sharedDimReadSize}`);
85+
86+
if (k === sharedDimReadSize) {
87+
blend = true;
88+
gl.enable(gl.BLEND);
89+
glContext.checkError();
90+
gl.blendEquation(gl.FUNC_ADD);
91+
glContext.checkError();
92+
gl.blendFunc(gl.ONE, gl.ONE);
93+
glContext.checkError();
94+
}
95+
10396
gl.uniform1i(sharedDimOffsetLocation, k);
10497
glContext.checkError();
10598
glContext.draw();
10699
}
100+
101+
if (blend) {
102+
gl.disable(gl.BLEND);
103+
glContext.checkError();
104+
}
107105
}
108106
};
109107
return [runtDataIm2Col, runDataDotProduct];

lib/backends/webgl/ops/softmax.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ export class WebGLSoftmax extends Softmax {
1717
this.artifacts = [];
1818
const programInfos = this.createProgramInfos(inferenceHandler, inputs);
1919
programInfos.forEach((pi, i) => {
20-
const artifact = inferenceHandler.programManager.build(pi);
20+
const artifact = inferenceHandler.session.programManager.build(pi);
2121
this.artifacts.push(artifact);
2222
});
2323
}
2424

2525
const runDatas = this.createRunDatas(inferenceHandler, this.artifacts.map(a => a.programInfo), inputs);
26-
runDatas.forEach((v, i) => inferenceHandler.programManager.run(this.artifacts[i], v));
26+
runDatas.forEach((v, i) => inferenceHandler.session.programManager.run(this.artifacts[i], v));
2727
// return only the last output
2828
return [runDatas[runDatas.length - 1].outputTextureData.tensor];
2929
}

lib/backends/webgl/ops/split.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ export class WebGLSplit extends Split {
1414
this.artifacts = [];
1515
for (let i = 0; i < count; ++i) {
1616
const programInfo = this.createProgramInfo(inferenceHandler, inputs[0], i);
17-
const artifact = inferenceHandler.programManager.build(programInfo);
17+
const artifact = inferenceHandler.session.programManager.build(programInfo);
1818
this.artifacts.push(artifact);
1919
}
2020
}
2121
const results: Tensor[] = [];
2222

2323
this.artifacts.forEach(artifact => {
2424
const rundata = this.createRunData(inferenceHandler, artifact.programInfo, inputs);
25-
inferenceHandler.programManager.run(artifact, rundata);
25+
inferenceHandler.session.programManager.run(artifact, rundata);
2626
results.push(rundata.outputTextureData.tensor);
2727
});
2828
return results;

lib/backends/webgl/ops/uint8-encode.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,15 @@ export class WebGLUint8Encode {
7272
${glsl.output} = encodeAsUint8(value);
7373
}`;
7474
const programInfo = {inputLayouts: [input], outputLayout, samplers: ['X'], shaderSource, hasMain: true};
75-
const artifact = inferenceHandler.programManager.build(programInfo);
75+
const artifact = inferenceHandler.session.programManager.build(programInfo);
7676

7777
const encoder = inferenceHandler.session.backend.glContext.getEncoder('byte', 4);
7878
const texture =
7979
inferenceHandler.session.backend.glContext.allocateTexture(outputLayout.width, outputLayout.height, encoder);
8080
const outputTextureData = inferenceHandler.createSharedTextureData(outputLayout, 'uint8', texture, {});
8181
const runData = {inputTextureDatas: [input], outputTextureData, uniformData: {}};
8282

83-
inferenceHandler.programManager.run(artifact, runData);
83+
inferenceHandler.session.programManager.run(artifact, runData);
8484
return runData.outputTextureData;
8585
}
8686
}

lib/backends/webgl/program-manager.ts

-8
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,6 @@ export class ProgramManager {
3535
}
3636
run(buildArtifact: Artifact, runData: RunData): void {
3737
this.profiler.event('backend', 'ProgramManager.run', () => {
38-
if (runData.preRun) {
39-
Logger.verbose('ProgramManager', 'PreRun');
40-
runData.preRun(this.glContext, buildArtifact);
41-
}
4238
const gl = this.glContext.gl;
4339
const program = buildArtifact.program;
4440
gl.useProgram(program);
@@ -56,10 +52,6 @@ export class ProgramManager {
5652
this.doDraw(buildArtifact, runData);
5753
gl.flush();
5854
});
59-
if (runData.postRun) {
60-
Logger.verbose('ProgramManager', 'PostRun');
61-
runData.postRun(this.glContext, buildArtifact);
62-
}
6355
});
6456
}
6557
dispose(): void {

lib/backends/webgl/session-handler.ts

+7-5
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,23 @@ import {WebGLBackend} from '../backend-webgl';
1313
import {WebGLInferenceHandler} from './inference-handler';
1414
import {WEBGL_OP_RESOLVE_RULES} from './op-resolve-rules';
1515
import {ProgramManager} from './program-manager';
16-
import {TextureHelper} from './texture-helper';
1716
import {AlwaysKeepOriginalSizeStrategy, TextureLayoutStrategy} from './texture-layout-strategy';
17+
import {TextureManager} from './texture-manager';
1818
import {TextureData} from './types';
1919

2020
export class WebGLSessionHandler implements SessionHandler {
2121
programManager: ProgramManager;
22-
textureHelper: TextureHelper;
22+
textureManager: TextureManager;
2323
layoutStrategy: TextureLayoutStrategy;
2424
textureDataCache: Map<Tensor.Id, TextureData>;
2525
initializers: Set<Tensor.Id>;
2626

2727
constructor(public readonly backend: WebGLBackend, public readonly context: Session.Context) {
2828
this.programManager = new ProgramManager(this.context.profiler, backend.glContext);
2929
this.layoutStrategy = new AlwaysKeepOriginalSizeStrategy(backend.glContext.maxTextureSize);
30-
this.textureHelper = new TextureHelper(backend.glContext, this.layoutStrategy, this.context.profiler);
30+
this.textureManager = new TextureManager(
31+
backend.glContext, this.layoutStrategy, this.context.profiler,
32+
{reuseTextures: backend.textureCacheMode === 'full'});
3133
this.textureDataCache = new Map();
3234
}
3335

@@ -50,8 +52,8 @@ export class WebGLSessionHandler implements SessionHandler {
5052
}
5153
dispose(): void {
5254
this.programManager.dispose();
53-
this.textureHelper.clearActiveTextures();
54-
this.textureDataCache.forEach(td => this.textureHelper.releaseTexture(td));
55+
this.textureManager.clearActiveTextures();
56+
this.textureDataCache.forEach(td => this.textureManager.releaseTexture(td, true));
5557
this.textureDataCache = new Map();
5658
}
5759
resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>): Operator {

0 commit comments

Comments (0)