Skip to content

Commit ab39dd3

Browse files
OuadiElfaroukiggerganov
authored andcommitted
Updated SYCL device filtering (llama/8901)
* Updated device filter to depend on default_selector (fixes non-intel device issues) * Small related update to example/sycl Readme
1 parent b1348d3 commit ab39dd3

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

ggml/src/ggml-sycl/dpct/helper.hpp

+16-3
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,7 @@ namespace dpct
874874
inline std::string get_preferred_gpu_platform_name() {
875875
std::string result;
876876
877-
std::string filter = "level-zero";
877+
std::string filter = "";
878878
char* env = getenv("ONEAPI_DEVICE_SELECTOR");
879879
if (env) {
880880
if (std::strstr(env, "level_zero")) {
@@ -892,11 +892,24 @@ namespace dpct
892892
else {
893893
throw std::runtime_error("invalid device filter: " + std::string(env));
894894
}
895+
} else {
896+
auto default_device = sycl::device(sycl::default_selector_v);
897+
auto default_platform_name = default_device.get_platform().get_info<sycl::info::platform::name>();
898+
899+
if (std::strstr(default_platform_name.c_str(), "Level-Zero") || default_device.is_cpu()) {
900+
filter = "level-zero";
901+
}
902+
else if (std::strstr(default_platform_name.c_str(), "CUDA")) {
903+
filter = "cuda";
904+
}
905+
else if (std::strstr(default_platform_name.c_str(), "HIP")) {
906+
filter = "hip";
907+
}
895908
}
896909
897-
auto plaform_list = sycl::platform::get_platforms();
910+
auto platform_list = sycl::platform::get_platforms();
898911
899-
for (const auto& platform : plaform_list) {
912+
for (const auto& platform : platform_list) {
900913
auto devices = platform.get_devices();
901914
auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
902915
return d.is_gpu();

0 commit comments

Comments
 (0)