Skip to content
7 changes: 4 additions & 3 deletions backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}

$if MODE == "per_tensor":
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

layout(push_constant) uniform restrict Block {
float scale;
int zero_point;
int quant_min;
int quant_max;
};
Expand Down Expand Up @@ -142,7 +143,7 @@ void quantize_per_tensor() {
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);

IN_T value = t_in[in_bufi];
OUT_T qvalue = quantize_val(value, scale, zero_point);
OUT_T qvalue = quantize_val(value, t_scale[0], t_zero_point[0]);

t_out[out_bufi] = qvalue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}

$if MODE == "per_tensor":
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

layout(push_constant) uniform restrict Block {
float scale;
int zero_point;
int quant_min;
int quant_max;
};
Expand Down Expand Up @@ -146,7 +147,7 @@ void quantize_per_tensor() {

[[unroll]] for (int i = 0; i < 4; ++i) {
IN_T value = IN_T(intex[i]);
OUT_T qvalue = quantize_val(value, scale, zero_point);
OUT_T qvalue = quantize_val(value, t_scale[0], t_zero_point[0]);
outtex[i] = qvalue;
}
write_texel(t_out, pos, outtex);
Expand Down
25 changes: 9 additions & 16 deletions backends/vulkan/runtime/graph/ops/impl/Quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ void add_quantize_per_tensor_node(
add_dtype_suffix(kernel_name, graph.dtype_of(input));
add_dtype_suffix(kernel_name, graph.dtype_of(output));

float scale_val = static_cast<float>(graph.get_double(scale));
int zero_point_val = static_cast<int>(graph.get_int(zero_point));
int quant_min_val = static_cast<int>(graph.get_int(quant_min));
int quant_max_val = static_cast<int>(graph.get_int(quant_max));

Expand All @@ -102,23 +100,16 @@ void add_quantize_per_tensor_node(
graph.strides_ubo(input),
graph.sizes_ubo(output),
graph.strides_ubo(output)};
push_constants = {
PushConstantDataInfo(&scale_val, sizeof(float)),
PushConstantDataInfo(&zero_point_val, sizeof(int)),
PushConstantDataInfo(&quant_min_val, sizeof(int)),
PushConstantDataInfo(&quant_max_val, sizeof(int)),
};
} else {
param_ubos = {
graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
push_constants = {
PushConstantDataInfo(&scale_val, sizeof(float)),
PushConstantDataInfo(&zero_point_val, sizeof(int)),
PushConstantDataInfo(&quant_min_val, sizeof(int)),
PushConstantDataInfo(&quant_max_val, sizeof(int)),
};
}

push_constants = {
PushConstantDataInfo(&quant_min_val, sizeof(int)),
PushConstantDataInfo(&quant_max_val, sizeof(int)),
};

vkapi::SpecVarList spec_vars = {
graph.hashed_layout_of(output),
graph.hashed_layout_of(input),
Expand All @@ -130,7 +121,9 @@ void add_quantize_per_tensor_node(
default_pick_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{{output, vkapi::kWrite}, {input, vkapi::kRead}},
{{output, vkapi::kWrite},
{input, vkapi::kRead},
{{scale, zero_point}, vkapi::kRead}},
// Shader param buffers
param_ubos,
// Push Constants
Expand Down Expand Up @@ -489,7 +482,7 @@ void quantize_per_channel_impl(

REGISTER_OPERATORS {
VK_REGISTER_OP(
quantized_decomposed.quantize_per_tensor.default,
quantized_decomposed.quantize_per_tensor.tensor,
quantize_per_tensor_impl);
VK_REGISTER_OP(
quantized_decomposed.quantize_per_token.default, quantize_per_token_impl);
Expand Down
Loading
Loading