-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Fix DLA clipping handling after codegen fix
- Loading branch information
1 parent
9e6fab4
commit 1ae5736
Showing
5 changed files
with
81 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
#![no_std] | ||
#![no_main] | ||
|
||
#[macro_use] | ||
extern crate alloc; | ||
use alloc::vec::Vec; | ||
use core::ffi::{c_char, CStr}; | ||
|
@@ -20,7 +21,6 @@ unsafe fn ffi_data_import( | |
input_height: usize, | ||
input_width: usize, | ||
input_order: *const c_char, | ||
input_zero: i16, | ||
kernel_data: *const i8, | ||
kernel_amount: usize, | ||
kernel_channels: usize, | ||
|
@@ -32,11 +32,6 @@ unsafe fn ffi_data_import( | |
slice::from_raw_parts(input_data, input_channels * input_height * input_width).to_vec() | ||
}; | ||
|
||
// Input zero point shift | ||
for x in input_data.iter_mut() { | ||
*x = i16::clamp((*x) as i16 + input_zero, i8::MIN as i16, i8::MAX as i16) as i8 | ||
} | ||
|
||
let input_order_string = unsafe { CStr::from_ptr(input_order).to_str().unwrap_unchecked() }; | ||
let input_tensor = unsafe { | ||
Tensor3::from_data_buffer( | ||
|
@@ -112,7 +107,6 @@ pub unsafe extern "C" fn dla_conv2d( | |
input_height, | ||
input_width, | ||
input_order, | ||
0, | ||
kernel_data, | ||
kernel_amount, | ||
kernel_channels, | ||
|
@@ -177,7 +171,6 @@ pub unsafe extern "C" fn dla_conv2d_relu( | |
input_height, | ||
input_width, | ||
input_order, | ||
0, | ||
kernel_data, | ||
kernel_amount, | ||
kernel_channels, | ||
|
@@ -248,7 +241,6 @@ pub unsafe extern "C" fn dla_conv2d_bias( | |
input_height, | ||
input_width, | ||
input_order, | ||
0, | ||
kernel_data, | ||
kernel_amount, | ||
kernel_channels, | ||
|
@@ -322,7 +314,6 @@ pub unsafe extern "C" fn dla_conv2d_bias_relu( | |
input_height, | ||
input_width, | ||
input_order, | ||
0, | ||
kernel_data, | ||
kernel_amount, | ||
kernel_channels, | ||
|
@@ -375,15 +366,11 @@ pub unsafe extern "C" fn dla_conv2d_bias_relu( | |
/// | ||
/// * `bias` - Buffer containing bias data. NOTE: Bias is actually i16 in hardware; here we use i32 for TVM compatibility | ||
#[no_mangle] | ||
pub unsafe extern "C" fn dla_tvm_qnn_conv2d( | ||
pub unsafe extern "C" fn dla_tvm_qnn_conv2d_bias( | ||
input_data: *const i8, | ||
kernel_data: *const i8, | ||
bias: *const i32, | ||
output: *mut i8, | ||
output_scale: *const f32, | ||
output_zero: *const i32, | ||
input_scale: *const f32, | ||
input_zero: *const i32, | ||
output: *mut i32, | ||
input_channels: usize, | ||
input_height: usize, | ||
input_width: usize, | ||
|
@@ -404,26 +391,13 @@ pub unsafe extern "C" fn dla_tvm_qnn_conv2d( | |
mac_clip: u32, | ||
pp_clip: u32, | ||
) { | ||
let input_scale: Vec<f32> = | ||
unsafe { slice::from_raw_parts(input_scale as *const f32, 1).to_vec() }; | ||
|
||
let input_zero: Vec<i32> = | ||
unsafe { slice::from_raw_parts(input_zero as *const i32, 1).to_vec() }; | ||
|
||
let output_scale: Vec<f32> = | ||
unsafe { slice::from_raw_parts(output_scale as *const f32, kernel_amount).to_vec() }; | ||
|
||
let output_zero: Vec<i32> = | ||
unsafe { slice::from_raw_parts(output_zero as *const i32, 1).to_vec() }; | ||
|
||
let (input_tensor, kernels_tensor) = unsafe { | ||
ffi_data_import( | ||
input_data, | ||
input_channels, | ||
input_height, | ||
input_width, | ||
input_order, | ||
(-1 * input_zero[0]) as i16, | ||
kernel_data, | ||
kernel_amount, | ||
kernel_channels, | ||
|
@@ -433,14 +407,17 @@ pub unsafe extern "C" fn dla_tvm_qnn_conv2d( | |
) | ||
}; | ||
|
||
// NOTE:(20241025 [email protected]) TVM expects 32-bit bias, but DLA only supports 16-bit bias, so we clip the incoming bias | ||
// to range suitable for DLA | ||
let bias: Vec<i16> = unsafe { | ||
slice::from_raw_parts(bias as *const i32, bias_length) | ||
.into_iter() | ||
.map(|x| (*x).clamp(i16::MIN as i32, i16::MAX as i32) as i16) | ||
.collect() | ||
}; | ||
|
||
let mut result = conv2d_bias_relu( | ||
|
||
let mut result: Tensor3<i8> = conv2d_bias( | ||
input_tensor, | ||
kernels_tensor, | ||
bias, | ||
|
@@ -449,7 +426,7 @@ pub unsafe extern "C" fn dla_tvm_qnn_conv2d( | |
right: pad_right, | ||
left: pad_left, | ||
bottom: pad_bottom, | ||
padding_value: pad_value, | ||
padding_value: 0, | ||
}), | ||
Some(Stride { | ||
x: stride_x, | ||
|
@@ -460,28 +437,20 @@ pub unsafe extern "C" fn dla_tvm_qnn_conv2d( | |
None, | ||
); | ||
|
||
let input_order_string = unsafe { CStr::from_ptr(input_order).to_str().unwrap_unchecked() }; | ||
|
||
// TVM requantization and clip | ||
// NOTE:(20240927 [email protected]) on DLA clipping behaviour with TVM. | ||
// DLA's conv2d arithmetic is done at 16 bit width, but the output of the DLA is limited to 8 bits. | ||
// To comply with TVM's expected value range our solution is to bit shift/clip the 16-bit result of | ||
// conv2d by 8 bits and shift it back in the driver. This causes some amount of data loss due to | ||
// the lost granularity of the values. The clipping amount is set by the pp_clip argument. | ||
rescale( | ||
&mut result, | ||
u32::pow(2, pp_clip) as f32, | ||
input_zero[0], | ||
output_zero[0], | ||
input_scale[0], | ||
output_scale, | ||
); | ||
|
||
let input_order_string = unsafe { CStr::from_ptr(input_order).to_str().unwrap_unchecked() }; | ||
let mut res_i32: Vec<i32> = result.to_buffer_with_order(Order3::try_from(input_order_string).unwrap_unchecked()) | ||
.iter().map(|x: &i8| (*x as f32 * u32::pow(2, pp_clip) as f32) as i32).collect(); | ||
|
||
unsafe { | ||
core::ptr::copy_nonoverlapping( | ||
result | ||
.to_buffer_with_order(Order3::try_from(input_order_string).unwrap_unchecked()) | ||
.as_mut_ptr(), | ||
res_i32.as_mut_ptr(), | ||
output, | ||
result.get_size(), | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.