Skip to content

Commit 1b85992

Browse files
shubhraprakash1pytorchmergebot
authored andcommitted
Optimize quantized max pool 2d (pytorch#115690)
Summary: We do not need to dequantize and quantize again for this op. With this optimization cunet-enc ops: vulkan.quantized_max_pool2d_quint8{48, 36, 2} 207532 vulkan.quantized_max_pool2d_quint8{24, 18, 4} 78832 vulkan.quantized_max_pool2d_quint8{12, 9, 8} 49296 Without optimization: vulkan.quantized_max_pool2d_quint8{48, 36, 2} 234416 vulkan.quantized_max_pool2d_quint8{24, 18, 4} 94380 vulkan.quantized_max_pool2d_quint8{12, 9, 8} 58760 Test Plan: Ensure all vulkan quantize tests pass: buck2 run --target-platforms ovr_configplatform/macos:arm64-fbsourcexplat/caffe2:pt_vulkan_quantized_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 --show-output" Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc [==========] Running 78 tests from 1 test suite. [----------] Global test environment set-up. [----------] 78 tests from VulkanAPITest ... [==========] 78 tests from 1 test suite ran. (1519 ms total) [ PASSED ] 78 tests. buck2 run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 --show-output" Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc [==========] Running 395 tests from 1 test suite. [----------] Global test environment set-up. [----------] 395 tests from VulkanAPITest ... [ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log (0 ms) [----------] 395 tests from VulkanAPITest (6515 ms total) [----------] Global test environment tear-down [==========] 395 tests from 1 test suite ran. (6515 ms total) [ PASSED ] 394 tests. [ SKIPPED ] 1 test, listed below: [ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log YOU HAVE 5 DISABLED TESTS Reviewed By: yipjustin, copyrightly Differential Revision: D50998619 Pull Request resolved: pytorch#115690 Approved by: https://github.com/SS-JIA
1 parent 6fee208 commit 1b85992

File tree

3 files changed

+31
-85
lines changed

3 files changed

+31
-85
lines changed

aten/src/ATen/native/vulkan/glsl/quantized_max_pool2d_qint8.glsl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ layout(set = 0, binding = 2) uniform PRECISION restrict Block
1414
ivec2 stride;
1515
ivec2 padding;
1616
ivec2 dilate;
17-
vec2 scale;
18-
ivec2 zero_point;
1917
} uBlock;
2018

2119
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
@@ -37,13 +35,11 @@ void main() {
3735
for (int x = start.x; x < end.x; x += uBlock.dilate.x) {
3836
if ((x >= 0 && x < uBlock.kernel.z) && (y >= 0 && y < uBlock.kernel.w)) {
3937
vec4 outtexy = texelFetch(uInput, ivec3(x, y, pos.z), 0);
40-
outtexy = uBlock.scale.x * (outtexy - uBlock.zero_point.x);
4138
outtex = max(outtexy, outtex);
4239
}
4340
}
4441
}
4542

46-
outtex = roundEven(outtex / uBlock.scale.x) + uBlock.zero_point.x;
4743
ivec4 store = ivec4(outtex);
4844
imageStore(uOutput, pos, store);
4945
}

aten/src/ATen/native/vulkan/glsl/quantized_max_pool2d_quint8.glsl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ layout(set = 0, binding = 2) uniform PRECISION restrict Block
1414
ivec2 stride;
1515
ivec2 padding;
1616
ivec2 dilate;
17-
vec2 scale;
18-
ivec2 zero_point;
1917
} uBlock;
2018

2119
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
@@ -37,13 +35,11 @@ void main() {
3735
for (int x = start.x; x < end.x; x += uBlock.dilate.x) {
3836
if ((x >= 0 && x < uBlock.kernel.z) && (y >= 0 && y < uBlock.kernel.w)) {
3937
vec4 outtexy = texelFetch(uInput, ivec3(x, y, pos.z), 0);
40-
outtexy = uBlock.scale.x * (outtexy - uBlock.zero_point.x);
4138
outtex = max(outtexy, outtex);
4239
}
4340
}
4441
}
4542

46-
outtex = roundEven(outtex / uBlock.scale.x) + uBlock.zero_point.x;
4743
uvec4 store = uvec4(outtex);
4844
imageStore(uOutput, pos, store);
4945
}

aten/src/ATen/native/vulkan/ops/Pool.cpp

Lines changed: 31 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -168,83 +168,37 @@ Tensor pool2d(
168168
}
169169

170170
api::UniformParamsBuffer params;
171-
if (v_self.is_quantized()) {
172-
const struct Block final {
173-
uvec3 extents;
174-
int32_t range;
175-
ivec4 kernel;
176-
ivec2 stride;
177-
ivec2 padding;
178-
ivec2 dilation;
179-
vec2 scale;
180-
ivec2 zero_point;
181-
} block{
182-
v_output.extents(),
183-
safe_downcast<int32_t>(
184-
kernel[Layout::Parameter::width] *
185-
kernel[Layout::Parameter::height]),
186-
{
187-
safe_downcast<int32_t>(kernel[Layout::Parameter::width]),
188-
safe_downcast<int32_t>(kernel[Layout::Parameter::height]),
189-
safe_downcast<int32_t>(self_arg.size(Layout::Activation4D::width)),
190-
safe_downcast<int32_t>(self_arg.size(Layout::Activation4D::height)),
191-
},
192-
{
193-
safe_downcast<int32_t>(stride[Layout::Parameter::width]),
194-
safe_downcast<int32_t>(stride[Layout::Parameter::height]),
195-
},
196-
{
197-
safe_downcast<int32_t>(padding[Layout::Parameter::width]),
198-
safe_downcast<int32_t>(padding[Layout::Parameter::height]),
199-
},
200-
{
201-
safe_downcast<int32_t>(dilation[Layout::Parameter::width]),
202-
safe_downcast<int32_t>(dilation[Layout::Parameter::height]),
203-
},
204-
{
205-
safe_downcast<float>(v_self.get_scale()),
206-
0.0f,
207-
},
208-
{
209-
safe_downcast<int32_t>(v_self.get_zero_point()),
210-
0u,
211-
},
212-
};
213-
params = api::UniformParamsBuffer(context, block);
214-
} else {
215-
const struct Block final {
216-
uvec3 extents;
217-
int32_t range;
218-
ivec4 kernel;
219-
ivec2 stride;
220-
ivec2 padding;
221-
ivec2 dilation;
222-
} block{
223-
v_output.extents(),
224-
safe_downcast<int32_t>(
225-
kernel[Layout::Parameter::width] *
226-
kernel[Layout::Parameter::height]),
227-
{
228-
safe_downcast<int32_t>(kernel[Layout::Parameter::width]),
229-
safe_downcast<int32_t>(kernel[Layout::Parameter::height]),
230-
safe_downcast<int32_t>(self_arg.size(Layout::Activation4D::width)),
231-
safe_downcast<int32_t>(self_arg.size(Layout::Activation4D::height)),
232-
},
233-
{
234-
safe_downcast<int32_t>(stride[Layout::Parameter::width]),
235-
safe_downcast<int32_t>(stride[Layout::Parameter::height]),
236-
},
237-
{
238-
safe_downcast<int32_t>(padding[Layout::Parameter::width]),
239-
safe_downcast<int32_t>(padding[Layout::Parameter::height]),
240-
},
241-
{
242-
safe_downcast<int32_t>(dilation[Layout::Parameter::width]),
243-
safe_downcast<int32_t>(dilation[Layout::Parameter::height]),
244-
},
245-
};
246-
params = api::UniformParamsBuffer(context, block);
247-
}
171+
const struct Block final {
172+
uvec3 extents;
173+
int32_t range;
174+
ivec4 kernel;
175+
ivec2 stride;
176+
ivec2 padding;
177+
ivec2 dilation;
178+
} block{
179+
v_output.extents(),
180+
safe_downcast<int32_t>(
181+
kernel[Layout::Parameter::width] * kernel[Layout::Parameter::height]),
182+
{
183+
safe_downcast<int32_t>(kernel[Layout::Parameter::width]),
184+
safe_downcast<int32_t>(kernel[Layout::Parameter::height]),
185+
safe_downcast<int32_t>(self_arg.size(Layout::Activation4D::width)),
186+
safe_downcast<int32_t>(self_arg.size(Layout::Activation4D::height)),
187+
},
188+
{
189+
safe_downcast<int32_t>(stride[Layout::Parameter::width]),
190+
safe_downcast<int32_t>(stride[Layout::Parameter::height]),
191+
},
192+
{
193+
safe_downcast<int32_t>(padding[Layout::Parameter::width]),
194+
safe_downcast<int32_t>(padding[Layout::Parameter::height]),
195+
},
196+
{
197+
safe_downcast<int32_t>(dilation[Layout::Parameter::width]),
198+
safe_downcast<int32_t>(dilation[Layout::Parameter::height]),
199+
},
200+
};
201+
params = api::UniformParamsBuffer(context, block);
248202

249203
api::PipelineBarrier pipeline_barrier{};
250204

0 commit comments

Comments
 (0)