Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ void load_vector(const T1& data_lane,
const Xbyak_aarch64::XReg& ptr_reg,
const int64_t offset,
const bool broadcast,
jit_generator* h) {
jit_generator* h,
const size_t lane_count = 0) {
if (broadcast) {
if (offset == 0) {
h->ld1r(data_lane, ptr(ptr_reg));
Expand All @@ -334,14 +335,27 @@ void load_vector(const T1& data_lane,
h->ld1r(data_lane, ptr(h->X_DEFAULT_ADDR));
}
} else {
if (offset == 0) {
h->ld1(data_lanes, ptr(ptr_reg));
if (lane_count == 0) {
if (offset == 0) {
h->ld1(data_lanes, ptr(ptr_reg));
} else {
h->add_imm(h->X_DEFAULT_ADDR, ptr_reg, offset, h->X_TMP_0);
h->ld1(data_lanes, ptr(h->X_DEFAULT_ADDR));
}
} else {
h->add_imm(h->X_DEFAULT_ADDR, ptr_reg, offset, h->X_TMP_0);
h->ld1(data_lanes, ptr(h->X_DEFAULT_ADDR));
for (size_t lane = 0; lane < lane_count; ++lane) {
const auto lane_offset = offset + static_cast<int64_t>(lane);
if (lane_offset == 0) {
h->ld1(data_lane[static_cast<int>(lane)], ptr(ptr_reg));
} else {
h->add_imm(h->X_DEFAULT_ADDR, ptr_reg, lane_offset, h->X_TMP_0);
h->ld1(data_lane[static_cast<int>(lane)], ptr(h->X_DEFAULT_ADDR));
}
}
}
}
}

} // namespace utils

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
Expand All @@ -365,16 +379,19 @@ void jit_uni_eltwise_generic<isa>::load_vector(const TReg& data,
}
break;
}
case ov::element::i8: {
utils::load_vector(data.b, data.s, ptr_reg, ptr_offset, broadcast, this);
sshll(data.h8, data.b8, 0);
sshll(data.s4, data.h4, 0);
break;
}
case ov::element::i8:
case ov::element::u8: {
utils::load_vector(data.b, data.s, ptr_reg, ptr_offset, broadcast, this);
ushll(data.h8, data.b8, 0);
ushll(data.s4, data.h4, 0);
// Stability-first: always lane-wise for i8/u8 to avoid crossing boundaries in tails.
const size_t lane_count = cpu_isa_traits<isa>::vlen / dst_prc.size();
utils::load_vector(data.b, data.s, ptr_reg, ptr_offset, broadcast, this, lane_count);

if (src_prc == ov::element::i8) {
sshll(data.h8, data.b8, 0);
sshll(data.s4, data.h4, 0);
} else {
ushll(data.h8, data.b8, 0);
ushll(data.s4, data.h4, 0);
}
break;
}
default: {
Expand Down Expand Up @@ -532,7 +549,18 @@ void jit_uni_eltwise_generic<isa>::store_vector(const XReg& ptr,
}
case ov::element::i8:
case ov::element::u8: {
str(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset));
// Safe path: always lane-wise for i8/u8
const size_t lane_count = cpu_isa_traits<isa>::vlen / src_prc.size();
auto data_bytes = data;
for (size_t lane = 0; lane < lane_count; ++lane) {
const auto lane_offset = ptr_offset + static_cast<int32_t>(lane);
if (lane_offset == 0) {
st1(data_bytes.b[static_cast<int>(lane)], Xbyak_aarch64::ptr(ptr));
} else {
add_imm(X_DEFAULT_ADDR, ptr, lane_offset, X_TMP_0);
st1(data_bytes.b[static_cast<int>(lane)], Xbyak_aarch64::ptr(X_DEFAULT_ADDR));
}
}
break;
}
default: {
Expand Down
Loading