Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 197 additions & 58 deletions src/CodeGen_ARM.cpp

Large diffs are not rendered by default.

14 changes: 12 additions & 2 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1395,6 +1395,10 @@ Type CodeGen_LLVM::upgrade_type_for_storage(const Type &t) const {
}
}

// Store `vscale` in effective_vscale, replacing whatever value it held before.
void CodeGen_LLVM::set_effective_vscale(int vscale) {
effective_vscale = vscale;
}

// Visit an integer immediate: set `value` to a signed LLVM integer constant
// carrying op->value, typed as whatever llvm_type_of() returns for op->type.
void CodeGen_LLVM::visit(const IntImm *op) {
value = ConstantInt::getSigned(llvm_type_of(op->type), op->value);
}
Expand Down Expand Up @@ -4643,6 +4647,12 @@ void CodeGen_LLVM::declare_intrin_overload(const std::string &name, const Type &
}

// Three-argument convenience form: forwards to the overload that takes an
// explicit intrinsics map, passing `intrinsics` as that map, and returns
// whatever that overload returns.
Value *CodeGen_LLVM::call_overloaded_intrin(const Type &result_type, const std::string &name, const std::vector<Expr> &args) {
return call_overloaded_intrin(result_type, name, args, intrinsics);
}

Value *CodeGen_LLVM::call_overloaded_intrin(const Type &result_type, const std::string &name, const std::vector<Expr> &args,
const IntrinsicsMap &overloaded_intrinsics) {

constexpr int debug_level = 4;

debug(debug_level) << "call_overloaded_intrin: " << result_type << " " << name << "(";
Expand All @@ -4653,8 +4663,8 @@ Value *CodeGen_LLVM::call_overloaded_intrin(const Type &result_type, const std::
}
debug(debug_level) << ")\n";

auto impls_i = intrinsics.find(name);
if (impls_i == intrinsics.end()) {
const auto impls_i = overloaded_intrinsics.find(name);
if (impls_i == overloaded_intrinsics.end()) {
debug(debug_level) << "No intrinsic " << name << "\n";
return nullptr;
}
Expand Down
11 changes: 9 additions & 2 deletions src/CodeGen_LLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ class CodeGen_LLVM : public IRVisitor {
* of functions as. */
virtual Type upgrade_type_for_argument_passing(const Type &) const;

void set_effective_vscale(int vscale);

std::unique_ptr<llvm::Module> module;
llvm::Function *function = nullptr;
llvm::LLVMContext *context = nullptr;
Expand Down Expand Up @@ -474,8 +476,9 @@ class CodeGen_LLVM : public IRVisitor {
: result_type(result_type), arg_types(std::move(arg_types)), impl(impl) {
}
};
using IntrinsicsMap = std::map<std::string, std::vector<Intrinsic>>;
/** Mapping of intrinsic functions to the various overloads implementing it. */
std::map<std::string, std::vector<Intrinsic>> intrinsics;
IntrinsicsMap intrinsics;

/** Get an LLVM intrinsic declaration. If it doesn't exist, it will be created. */
llvm::Function *get_llvm_intrin(const Type &ret_type, const std::string &name, const std::vector<Type> &arg_types, bool scalars_are_vectors = false);
Expand All @@ -484,7 +487,11 @@ class CodeGen_LLVM : public IRVisitor {
llvm::Function *declare_intrin_overload(const std::string &name, const Type &ret_type, const std::string &impl_name, std::vector<Type> arg_types, bool scalars_are_vectors = false);
void declare_intrin_overload(const std::string &name, const Type &ret_type, llvm::Function *impl, std::vector<Type> arg_types);
/** Call an overloaded intrinsic function. Returns nullptr if no suitable overload is found. */
llvm::Value *call_overloaded_intrin(const Type &result_type, const std::string &name, const std::vector<Expr> &args);
virtual llvm::Value *call_overloaded_intrin(const Type &result_type, const std::string &name, const std::vector<Expr> &args);
/** Call an overloaded intrinsic function. Returns nullptr if no suitable overload is found.
* The intrinsic is looked up by name in the given overloaded_intrinsics map. */
llvm::Value *call_overloaded_intrin(const Type &result_type, const std::string &name, const std::vector<Expr> &args,
const IntrinsicsMap &overloaded_intrinsics);

/** Generate a call to a vector intrinsic or runtime inlined
* function. The arguments are sliced up into vectors of the width
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ tests(GROUPS correctness
extern_stage_on_device.cpp
extract_concat_bits.cpp
failed_unroll.cpp
fallback_vscale_sve.cpp
fast_trigonometric.cpp
fibonacci.cpp
fit_function.cpp
Expand Down
83 changes: 83 additions & 0 deletions test/correctness/fallback_vscale_sve.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#include "Halide.h"
#include <fstream>
#include <regex>

using namespace Halide;

/** Compile `f` for target `t` to LLVM assembly in "<name>.ll", then scan the
 * emitted text line by line for two things:
 *   - a "vscale_range(n,m)" attribute whose two bounds are equal; the last
 *     such value seen is taken as the actual vscale (0 if none is seen), and
 *   - any occurrence of the substring `exp_intrin`.
 * Returns true only if the actual vscale equals `exp_vscale` and the
 * intrinsic name was found. On failure a diagnostic prefixed with `name` is
 * printed; a vscale mismatch is reported in preference to a missing intrinsic. */
bool compile_and_check_vscale(Func &f,
                              const std::string &name,
                              const Target &t,
                              int exp_vscale,
                              const std::string &exp_intrin) {

    // Emit the textual LLVM IR so the function attributes can be inspected.
    const std::string ll_path = name + ".ll";
    f.compile_to_llvm_assembly(ll_path, f.infer_arguments(), t);
    Internal::assert_file_exists(ll_path);

    // Captures "n" and "m" from "vscale_range(n,m)".
    const std::regex range_pattern(R"(vscale_range\(\s*([0-9]+)\s*,\s*([0-9]+)\s*\))");

    int found_vscale = 0;
    bool saw_intrin = false;

    std::ifstream ll_stream(ll_path);
    for (std::string text; std::getline(ll_stream, text);) {
        // Only a range with identical lower and upper bounds pins vscale to
        // a single value, so other ranges are ignored.
        std::smatch groups;
        if (std::regex_search(text, groups, range_pattern) && groups[1] == groups[2]) {
            found_vscale = std::stoi(groups[1].str());
        }

        // A plain substring match is enough to detect the intrinsic.
        saw_intrin = saw_intrin || (text.find(exp_intrin) != std::string::npos);
    }

    if (found_vscale != exp_vscale) {
        printf("[%s] Found vscale_range %d, while expected %d\n", name.c_str(), found_vscale, exp_vscale);
        return false;
    }
    if (!saw_intrin) {
        printf("[%s] Cannot find expected intrin %s\n", name.c_str(), exp_intrin.c_str());
        return false;
    }
    return true;
}

// Pipeline variables used by test_vscale().
Var x("x");
Var y("y");

/** Build f(x, y) = absd(x, y), computed at root and vectorized along x by
 * `vectorization_factor`, then compile it for an arm-64 SVE2 target whose
 * vector_bits is set to `vector_bits` and check the generated LLVM assembly.
 *
 * `exp_vscale` is the vscale expected in the output. A positive value also
 * means an SVE "sabd" intrinsic is expected; otherwise the NEON "sabd"
 * intrinsic is expected instead. Returns true if both checks pass. */
bool test_vscale(int vectorization_factor, int vector_bits, int exp_vscale) {
    Func f("f");
    f(x, y) = absd(x, y);
    f.compute_root().vectorize(x, vectorization_factor);

    Target t("arm-64-linux-sve2-no_asserts-no_runtime-no_bounds_query");
    t.vector_bits = vector_bits;

    // The name doubles as the output file stem and the diagnostic prefix.
    const std::string test_name = "test_vscale_v" + std::to_string(vectorization_factor) +
                                  "_vector_bits_" + std::to_string(vector_bits);

    // sve or neon
    std::string expected_intrin;
    if (exp_vscale > 0) {
        expected_intrin = "llvm.aarch64.sve.sabd";
    } else {
        expected_intrin = "llvm.aarch64.neon.sabd";
    }

    return compile_and_check_vscale(f, test_name, t, exp_vscale, expected_intrin);
}

/** Run every vscale scenario and report overall success.
 * Returns 0 and prints "Success!" only if all scenarios pass; returns 1 otherwise. */
int main(int argc, char **argv) {

    struct Scenario {
        int vectorization_factor;
        int vector_bits;
        int exp_vscale;
    };

    const Scenario scenarios[] = {
        {4, 128, 1},  // Regular case: <vscale x 4 x ty> with vscale=1
        {3, 128, 0},  // Fallback due to odd vectorization factor
        {8, 512, 4},  // Regular case: <vscale x 2 x ty> with vscale=4
        {4, 512, 0},  // Fallback due to <vscale x 1 x ty>
    };

    bool all_passed = true;
    for (const Scenario &s : scenarios) {
        // Run every scenario even after a failure so that each mismatch is printed.
        const bool passed = test_vscale(s.vectorization_factor, s.vector_bits, s.exp_vscale);
        all_passed = all_passed && passed;
    }

    if (all_passed) {
        printf("Success!\n");
        return 0;
    }
    return 1;
}
1 change: 1 addition & 0 deletions test/warning/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ tests(GROUPS warning
require_const_false.cpp
sliding_vectors.cpp
unscheduled_update_def.cpp
unsupported_vectorization_sve.cpp
emulated_float16.cpp
)

Expand Down
23 changes: 23 additions & 0 deletions test/warning/unsupported_vectorization_sve.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "Halide.h"
#include "halide_test_dirs.h"

using namespace Halide;

/** Compile a pipeline whose vectorization factor is 3 * vscale for an SVE2
 * target with a matching vector_bits setting. The generated assembly itself
 * is not inspected; it is written to a throwaway file in the test tmp dir. */
int main(int argc, char **argv) {
    Func f;
    Var x;

    f(x) = x * 0.1f;

    constexpr int vscale = 2;
    constexpr int bits_per_vscale = 128;
    constexpr int vector_bits = bits_per_vscale * vscale;
    constexpr int vectorization_factor = vscale * 3;

    f.vectorize(x, vectorization_factor);

    const std::string target_string = "arm-64-linux-sve2-vector_bits_" + std::to_string(vector_bits);
    Target t(target_string);

    // With this factor, SVE code generation would have to emit
    // <vscale x 3 x float>. Instead, SVE is disabled and a user_warning
    // is issued.
    const std::string output_path = Internal::get_test_tmp_dir() + "unused.ll";
    f.compile_to_llvm_assembly(output_path, f.infer_arguments(), "f", t);

    return 0;
}
Loading