Skip to content

Commit 38265ac

Browse files
SS-JIA authored and facebook-github-bot committed
Add Mul op for Vulkan (pytorch#47021)
Summary: Updates mul_scalar shader to support the new Vulkan API, and adds a new op for it using the new API. Also adds an in-place version for the op. Pull Request resolved: pytorch#47021 Test Plan: Unit test included. To build & run: ``` BUILD_CUSTOM_PROTOBUF=OFF \ BUILD_TEST=ON \ USE_EIGEN_FOR_BLAS=OFF \ USE_FBGEMM=OFF \ USE_MKLDNN=OFF \ USE_NNPACK=OFF \ USE_NUMPY=OFF \ USE_OBSERVERS=OFF \ USE_PYTORCH_QNNPACK=OFF \ USE_QNNPACK=OFF \ USE_VULKAN=ON \ USE_VULKAN_API=ON \ USE_VULKAN_SHADERC_RUNTIME=ON \ USE_VULKAN_WRAPPER=OFF \ MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python3 setup.py develop --cmake && ./build/bin/vulkan_api_test ``` Reviewed By: AshkanAliabadi Differential Revision: D24624729 Pulled By: SS-JIA fbshipit-source-id: 97e76e4060307a9a24311ac51dca8812e4471249
1 parent 2b6a720 commit 38265ac

File tree

5 files changed

+198
-13
lines changed

5 files changed

+198
-13
lines changed

aten/src/ATen/native/vulkan/VulkanOps.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -613,13 +613,12 @@ void mul(VulkanTensor& output, const VulkanTensor& input, const float s) {
613613

614614
auto device = context().device();
615615
struct ConstBlock {
616-
int32_t inputSize[4];
616+
int32_t inputSize[3];
617617
float s;
618618
};
619619
ConstBlock cb{{safe_downcast<int32_t>(W),
620620
safe_downcast<int32_t>(H),
621-
safe_downcast<int32_t>(C_4),
622-
0},
621+
safe_downcast<int32_t>(C_4)},
623622
s};
624623
VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb));
625624

Original file line numberDiff line numberDiff line change
@@ -1,21 +1,27 @@
11
#version 450 core
#define PRECISION $precision

layout(std430) buffer;
layout(std430) uniform;

/* Qualifiers: layout - storage - precision - memory */

// Destination image and source texture for the element-wise scalar multiply.
layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;

// Uniform block matching the host-side struct: WHC carries the texture
// extents (width, height, channel-groups); other is the scalar factor.
layout(set = 0, binding = 2) uniform restrict Block {
  ivec3 WHC;
  float other;
} uBlock;

layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // The dispatch grid may over-cover the texture; skip out-of-range texels.
  if (all(lessThan(pos, uBlock.WHC))) {
    imageStore(
        uOutput,
        pos,
        texelFetch(uInput, pos, 0) * uBlock.other);
  }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#version 450 core
#define PRECISION $precision

layout(std430) buffer;
layout(std430) uniform;

/* Qualifiers: layout - storage - precision - memory */

// In-place variant: uOutput is both read (imageLoad) and written (imageStore),
// so it is declared without writeonly and there is no separate input sampler.
layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput;

// Uniform block matching the host-side struct: WHC carries the texture
// extents (width, height, channel-groups); other is the scalar factor.
layout(set = 0, binding = 1) uniform restrict Block {
  ivec3 WHC;
  float other;
} uBlock;

layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // The dispatch grid may over-cover the texture; skip out-of-range texels.
  if (all(lessThan(pos, uBlock.WHC))) {
    imageStore(
        uOutput,
        pos,
        imageLoad(uOutput, pos) * uBlock.other);
  }
}
+130
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
#include <ATen/native/vulkan/ops/Common.h>
2+
#include <torch/library.h>
3+
4+
namespace at {
5+
namespace native {
6+
namespace vulkan {
7+
namespace ops {
8+
namespace {
9+
10+
// Out-of-place scalar multiply on the Vulkan backend: returns self * other.
// Moves the input onto Vulkan if it is not there already, allocates a fresh
// output tensor, and dispatches the mul_scalar compute shader.
Tensor mul_scalar(
    const Tensor& self_arg,
    const Scalar other) {
  api::Context* const context = api::context();

  // Make sure the input lives on the Vulkan backend before wrapping it.
  const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan();
  const vTensor& v_self = convert(self);

  // Output mirrors the input's sizes and options.
  vTensor v_output{
    context,
    self.sizes(),
    self.options(),
  };

  api::Command::Buffer command_buffer = context->command().pool.allocate();
  command_buffer.begin();
  {
    if (v_output.has_image() && v_self.has_image()) {
      // Layout must match the shader's uniform Block: extents then scalar.
      const struct {
        uint32_t width, height, channels;
        float other;
      } block {
        v_output.extents().width,
        v_output.extents().height,
        v_output.extents().depth,
        other.to<float>(),
      };

      context->dispatch(
          command_buffer,
          {
            VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
            VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
          },
          VK_KERNEL(mul_scalar),
          v_output.extents(),
          // Write-only access bypasses synchronization but inserts appropriate
          // barriers if necessary.
          v_output.image(command_buffer, vTensor::Access::Write),
          // Read-only access is implied on const tensors and triggers an async
          // synchronization if necessary.
          v_self.image(command_buffer),
          // Object lifetime is managed by the resource pool.
          // It is OK not to keep track of the handle.
          context->resource().pool.uniform(block).object);
    }
    else {
      TORCH_CHECK(false, "Not implemented!");
    }
  }
  command_buffer.end();
  command_buffer.submit(context->gpu().queue);

  return convert(v_output);
}
66+
67+
// In-place scalar multiply on the Vulkan backend: self *= other.
// Requires self_arg to already be a Vulkan tensor; dispatches the mul_scalar_
// shader, which reads and writes the same image.
Tensor& mul_scalar_(
    Tensor& self_arg,
    const Scalar other) {
  api::Context* const context = api::context();

  // Fixed: the message previously said "In-place add" — a copy-paste left
  // over from the add op; this is the mul implementation.
  TORCH_CHECK(
      self_arg.is_vulkan(),
      "Vulkan: In-place mul is only supported on Vulkan tensors.");

  vTensor& v_self = convert(self_arg);

  api::Command::Buffer command_buffer = context->command().pool.allocate();
  command_buffer.begin();
  {
    if (v_self.has_image()) {
      // Layout must match the shader's uniform Block: extents then scalar.
      const struct {
        uint32_t width, height, channels;
        float other;
      } block {
        v_self.extents().width,
        v_self.extents().height,
        v_self.extents().depth,
        other.to<float>(),
      };

      context->dispatch(
          command_buffer,
          {
            VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
            VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
          },
          VK_KERNEL(mul_scalar_),
          v_self.extents(),
          // Read-Write access triggers an async synchronization if necessary
          // and inserts appropriate barriers if hazards are detected.
          v_self.image(command_buffer, vTensor::Access::Read | vTensor::Access::Write),
          // Object lifetime is managed by the resource pool.
          // It is OK not to keep track of the handle.
          context->resource().pool.uniform(block).object);
    }
    else {
      TORCH_CHECK(false, "Not implemented!");
    }
  }
  command_buffer.end();
  command_buffer.submit(context->gpu().queue);

  return self_arg;
}
116+
117+
#ifdef USE_VULKAN_API

// Register both variants with the dispatcher under the Vulkan backend key.
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl("mul.Scalar", TORCH_FN(mul_scalar));
  m.impl("mul_.Scalar", TORCH_FN(mul_scalar_));
}

#endif /* USE_VULKAN_API */
125+
126+
} // namespace
127+
} // namespace ops
128+
} // namespace vulkan
129+
} // namespace native
130+
} // namespace at

aten/src/ATen/test/vulkan_api_test.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,30 @@ TEST(VulkanAPITest, add_scalar_) {
7979
ASSERT_TRUE(almostEqual(a_cpu, a_vulkan.cpu()));
8080
}
8181

82+
// Out-of-place scalar multiply: compare the Vulkan result against the CPU
// reference on a random 4-D input.
TEST(VulkanAPITest, mul_scalar) {
  const auto a_cpu = at::rand({17, 213, 213, 7}, at::device(at::kCPU).dtype(at::kFloat));
  const auto a_vulkan = a_cpu.vulkan();

  const float b_scalar = 3.1415f;

  const auto c_cpu = at::mul(a_cpu, b_scalar);
  const auto c_vulkan = at::mul(a_vulkan, b_scalar);

  // Tolerance-based comparison: the Vulkan path computes in shader precision.
  ASSERT_TRUE(almostEqual(c_cpu, c_vulkan.cpu()));
}
93+
94+
// In-place scalar multiply: mutate both the CPU reference and the Vulkan
// tensor with mul_, then compare.
TEST(VulkanAPITest, mul_scalar_) {
  auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat));
  auto a_vulkan = a_cpu.vulkan();

  const float b_scalar = 3.1415f;

  a_cpu.mul_(b_scalar);
  a_vulkan.mul_(b_scalar);

  // Tolerance-based comparison: the Vulkan path computes in shader precision.
  ASSERT_TRUE(almostEqual(a_cpu, a_vulkan.cpu()));
}
105+
82106
TEST(VulkanAPITest, copy) {
83107
const auto cpu = at::rand({13, 17, 37, 19}, at::device(at::kCPU).dtype(at::kFloat));
84108
ASSERT_TRUE(exactlyEqual(cpu, cpu.vulkan().cpu()));

0 commit comments

Comments
 (0)