Skip to content

Ability to set external instance + devices #11393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/vk_api/Runtime.h>

#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
Expand Down Expand Up @@ -528,7 +530,9 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
return Error::MemoryAllocationFailed;
}

new (compute_graph) ComputeGraph(get_graph_config(compile_specs));
GraphConfig graph_config = get_graph_config(compile_specs);
graph_config.external_adapter = vkapi::set_and_get_external_adapter();
new (compute_graph) ComputeGraph(graph_config);

Error err = compileModel(processed->data(), compute_graph);

Expand Down
6 changes: 3 additions & 3 deletions backends/vulkan/runtime/api/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
namespace vkcompute {
namespace api {

Context::Context(size_t adapter_i, const ContextConfig& config)
Context::Context(vkapi::Adapter* adapter, const ContextConfig& config)
: config_(config),
// Important handles
adapter_p_(vkapi::runtime()->get_adapter_p(adapter_i)),
adapter_p_(adapter),
device_(adapter_p_->device_handle()),
queue_(adapter_p_->request_queue()),
// Resource pools
Expand Down Expand Up @@ -256,7 +256,7 @@ Context* context() {
query_pool_config,
};

return new Context(vkapi::runtime()->default_adapter_i(), config);
return new Context(vkapi::runtime()->get_adapter_p(), config);
} catch (...) {
}

Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/api/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct ContextConfig final {

class Context final {
public:
explicit Context(size_t adapter_i, const ContextConfig&);
explicit Context(vkapi::Adapter*, const ContextConfig&);

Context(const Context&) = delete;
Context& operator=(const Context&) = delete;
Expand Down
3 changes: 2 additions & 1 deletion backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ ComputeGraph::ComputeGraph(GraphConfig config)
prepack_descriptor_counts_{},
execute_descriptor_counts_{},
context_{new api::Context(
vkapi::runtime()->default_adapter_i(),
config.external_adapter ? config.external_adapter
: vkapi::runtime()->get_adapter_p(),
config_.context_config)},
shared_objects_{},
values_{},
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/GraphConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ GraphConfig::GraphConfig() {
local_wg_size_override = {};

expect_dynamic_shapes = false;

external_adapter = nullptr;
}

void GraphConfig::set_storage_type_override(utils::StorageType storage_type) {
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/GraphConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ struct GraphConfig final {
// Whether or not the ComputeGraph should expect input shapes to be dynamic
bool expect_dynamic_shapes;

vkapi::Adapter* external_adapter;

// Generate a default graph config with pre-configured settings
explicit GraphConfig();

Expand Down
155 changes: 106 additions & 49 deletions backends/vulkan/runtime/vk_api/Adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,12 @@ namespace vkapi {

namespace {

VkDevice create_logical_device(
void find_compute_queues(
const PhysicalDevice& physical_device,
const uint32_t num_queues_to_create,
std::vector<Adapter::Queue>& queues,
std::vector<uint32_t>& queue_usage) {
// Find compute queues up to the requested number of queues

std::vector<VkDeviceQueueCreateInfo> queue_create_infos;
std::vector<VkDeviceQueueCreateInfo>& queue_create_infos,
std::vector<std::pair<uint32_t, uint32_t>>& queues_to_get) {
queue_create_infos.reserve(num_queues_to_create);

std::vector<std::pair<uint32_t, uint32_t>> queues_to_get;
queues_to_get.reserve(num_queues_to_create);

uint32_t remaining_queues = num_queues_to_create;
Expand Down Expand Up @@ -60,12 +55,44 @@ VkDevice create_logical_device(
break;
}
}
}

// Retrieves a VkQueue handle from `logical_device` for every
// (queue family index, queue index) pair in `queues_to_get`, appending one
// Adapter::Queue entry to `queues` and one zeroed usage counter to
// `queue_usage` per pair. Existing contents of both output vectors are kept.
//
// The queue flags recorded for each entry are read from the matching queue
// family of `physical_device`; `.at()` throws std::out_of_range if a family
// index is not present there.
void populate_queue_info(
    const PhysicalDevice& physical_device,
    VkDevice logical_device,
    const std::vector<std::pair<uint32_t, uint32_t>>& queues_to_get,
    std::vector<Adapter::Queue>& queues,
    std::vector<uint32_t>& queue_usage) {
  const auto requested_count = queues_to_get.size();
  queues.reserve(requested_count);
  queue_usage.reserve(requested_count);

  for (const auto& family_and_index : queues_to_get) {
    const uint32_t family_index = family_and_index.first;
    const uint32_t index_in_family = family_and_index.second;

    // Look the family up first so that an invalid family index fails before
    // the Vulkan call is made.
    const VkQueueFlags family_flags =
        physical_device.queue_families.at(family_index).queueFlags;

    VkQueue fetched_queue = VK_NULL_HANDLE;
    vkGetDeviceQueue(
        logical_device, family_index, index_in_family, &fetched_queue);

    queues.push_back(
        {family_index, index_in_family, family_flags, fetched_queue});
    // Every queue starts with a usage value of zero.
    queue_usage.push_back(0);
  }
}

VkDevice create_logical_device(
const PhysicalDevice& physical_device,
const uint32_t num_queues_to_create,
std::vector<Adapter::Queue>& queues,
std::vector<uint32_t>& queue_usage) {
// Find compute queues up to the requested number of queues

std::vector<VkDeviceQueueCreateInfo> queue_create_infos;
std::vector<std::pair<uint32_t, uint32_t>> queues_to_get;
find_compute_queues(
physical_device, num_queues_to_create, queue_create_infos, queues_to_get);

// Create the VkDevice
std::vector<const char*> requested_device_extensions{
#ifdef VK_KHR_portability_subset
VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME,
Expand Down Expand Up @@ -143,19 +170,42 @@ VkDevice create_logical_device(
volkLoadDevice(handle);
#endif /* USE_VULKAN_VOLK */

// Obtain handles for the created queues and initialize queue usage heuristic
populate_queue_info(
physical_device, handle, queues_to_get, queues, queue_usage);

for (const std::pair<uint32_t, uint32_t>& queue_idx : queues_to_get) {
VkQueue queue_handle = VK_NULL_HANDLE;
VkQueueFlags flags =
physical_device.queue_families.at(queue_idx.first).queueFlags;
vkGetDeviceQueue(handle, queue_idx.first, queue_idx.second, &queue_handle);
queues.push_back({queue_idx.first, queue_idx.second, flags, queue_handle});
// Initial usage value
queue_usage.push_back(0);
return handle;
}

// Probes whether `device` can create a 3D image with linear tiling.
//
// According to the Vulkan spec, linear tiling may not be supported for 3D
// images, so this attempts to create a minimal 1x1x1 RGBA32F image that is
// usable as both a sampled and a storage image. If creation succeeds the
// probe image is destroyed again immediately.
//
// Returns true when vkCreateImage reported VK_SUCCESS, false otherwise.
//
// NOTE(review): the spec-defined way to query support for a
// format/type/tiling/usage combination is
// vkGetPhysicalDeviceImageFormatProperties; this helper only receives a
// VkDevice, so it probes by creation instead. Confirm that is acceptable on
// all target drivers.
bool test_linear_tiling_3d_image_support(VkDevice device) {
  VkExtent3D image_extents{1u, 1u, 1u};
  const VkImageCreateInfo image_create_info{
      VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // sType
      nullptr, // pNext
      0u, // flags
      VK_IMAGE_TYPE_3D, // imageType
      VK_FORMAT_R32G32B32A32_SFLOAT, // format
      image_extents, // extents
      1u, // mipLevels
      1u, // arrayLayers
      VK_SAMPLE_COUNT_1_BIT, // samples
      VK_IMAGE_TILING_LINEAR, // tiling
      VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_STORAGE_BIT, // usage
      VK_SHARING_MODE_EXCLUSIVE, // sharingMode
      0u, // queueFamilyIndexCount
      nullptr, // pQueueFamilyIndices
      VK_IMAGE_LAYOUT_UNDEFINED, // initialLayout
  };
  VkImage image = VK_NULL_HANDLE;
  const VkResult res =
      vkCreateImage(device, &image_create_info, nullptr, &image);

  // The image only exists to test creation; release it right away.
  if (res == VK_SUCCESS) {
    vkDestroyImage(device, image, nullptr);
  }

  return res == VK_SUCCESS;
}

} // namespace
Expand Down Expand Up @@ -186,37 +236,44 @@ Adapter::Adapter(
compute_pipeline_cache_(device_.handle, cache_data_path),
sampler_cache_(device_.handle),
vma_(instance_, physical_device_.handle, device_.handle),
linear_tiling_3d_enabled_{true} {
// Test creating a 3D image with linear tiling to see if it is supported.
// According to the Vulkan spec, linear tiling may not be supported for 3D
// images.
VkExtent3D image_extents{1u, 1u, 1u};
const VkImageCreateInfo image_create_info{
VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // sType
nullptr, // pNext
0u, // flags
VK_IMAGE_TYPE_3D, // imageType
VK_FORMAT_R32G32B32A32_SFLOAT, // format
image_extents, // extents
1u, // mipLevels
1u, // arrayLayers
VK_SAMPLE_COUNT_1_BIT, // samples
VK_IMAGE_TILING_LINEAR, // tiling
VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_STORAGE_BIT, // usage
VK_SHARING_MODE_EXCLUSIVE, // sharingMode
0u, // queueFamilyIndexCount
nullptr, // pQueueFamilyIndices
VK_IMAGE_LAYOUT_UNDEFINED, // initialLayout
};
VkImage image = VK_NULL_HANDLE;
VkResult res =
vkCreateImage(device_.handle, &image_create_info, nullptr, &image);
if (res != VK_SUCCESS) {
linear_tiling_3d_enabled_ = false;
} else {
vkDestroyImage(device_.handle, image, nullptr);
linear_tiling_3d_enabled_{
test_linear_tiling_3d_image_support(device_.handle)},
owns_device_{true} {}

// Constructs an Adapter around Vulkan handles that were created by the caller
// rather than by this class.
//
// instance / physical_device / logical_device: externally created handles
//   that the adapter wraps; every cache and the allocator below are built on
//   top of `logical_device`.
// num_queues: number of compute queues to look up on the device.
// cache_data_path: forwarded to the compute pipeline cache.
//
// owns_device_ is set to false here; the destructor uses that flag to clear
// device_.handle instead of leaving it populated.
//
// NOTE(review): the queues are chosen by find_compute_queues from the
// physical device's queue families and then fetched from `logical_device`.
// Presumably the external VkDevice must have been created with those same
// queue families/indices enabled - confirm with callers.
Adapter::Adapter(
VkInstance instance,
VkPhysicalDevice physical_device,
VkDevice logical_device,
const uint32_t num_queues,
const std::string& cache_data_path)
: queue_usage_mutex_{},
physical_device_(physical_device),
queues_{},
queue_usage_{},
queue_mutexes_{},
instance_(instance),
device_(logical_device),
shader_layout_cache_(device_.handle),
shader_cache_(device_.handle),
pipeline_layout_cache_(device_.handle),
compute_pipeline_cache_(device_.handle, cache_data_path),
sampler_cache_(device_.handle),
vma_(instance_, physical_device_.handle, device_.handle),
linear_tiling_3d_enabled_{
test_linear_tiling_3d_image_support(device_.handle)},
owns_device_{false} {
// find_compute_queues fills both vectors, but only queues_to_get is consumed
// below: no VkDevice is created here, so the create infos go unused.
std::vector<VkDeviceQueueCreateInfo> queue_create_infos;
std::vector<std::pair<uint32_t, uint32_t>> queues_to_get;
find_compute_queues(
physical_device_, num_queues, queue_create_infos, queues_to_get);
// Fetch the queue handles from the external device and zero their usage
// counters.
populate_queue_info(
physical_device_, device_.handle, queues_to_get, queues_, queue_usage_);
}

// Clears the stored logical device handle when the adapter wraps a device it
// did not create (owns_device_ == false); otherwise does nothing here.
//
// NOTE(review): presumably this keeps device_'s own cleanup from destroying
// a VkDevice that belongs to the client - confirm against the device handle
// wrapper's destructor.
Adapter::~Adapter() {
  if (owns_device_) {
    return;
  }
  device_.handle = VK_NULL_HANDLE;
}

Adapter::Queue Adapter::request_queue() {
Expand Down
10 changes: 9 additions & 1 deletion backends/vulkan/runtime/vk_api/Adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,20 @@ class Adapter final {
const uint32_t num_queues,
const std::string& cache_data_path);

explicit Adapter(
VkInstance instance,
VkPhysicalDevice physical_device,
VkDevice logical_device,
const uint32_t num_queues,
const std::string& cache_data_path);

Adapter(const Adapter&) = delete;
Adapter& operator=(const Adapter&) = delete;

Adapter(Adapter&&) = delete;
Adapter& operator=(Adapter&&) = delete;

~Adapter() = default;
~Adapter();

struct Queue {
uint32_t family_index;
Expand Down Expand Up @@ -94,6 +101,7 @@ class Adapter final {
Allocator vma_;
// Miscellaneous
bool linear_tiling_3d_enabled_;
bool owns_device_;

public:
// Physical Device metadata
Expand Down
36 changes: 36 additions & 0 deletions backends/vulkan/runtime/vk_api/Runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
#include <iostream>
#include <sstream>

#ifdef USE_VOLK_HEADER_ONLY
// For volk.h, define this before including volk.h in exactly one CPP file.
#define VOLK_IMPLEMENTATION
#include <volk.h>
#endif /* USE_VOLK_HEADER_ONLY */

namespace vkcompute {
namespace vkapi {

Expand Down Expand Up @@ -409,5 +415,35 @@ Runtime* runtime() {
return p_runtime.get();
}

std::unique_ptr<Adapter> init_external_adapter(
const VkInstance instance,
const VkPhysicalDevice physical_device,
const VkDevice logical_device,
const uint32_t num_queues,
const std::string& cache_data_path) {
if (instance == VK_NULL_HANDLE || physical_device == VK_NULL_HANDLE ||
logical_device == VK_NULL_HANDLE) {
return std::unique_ptr<Adapter>(nullptr);
}

return std::make_unique<Adapter>(
instance, physical_device, logical_device, num_queues, cache_data_path);
}

// Returns the process-wide external Adapter, creating it on the first call.
//
// The adapter lives in a function-local static, so only the handles passed
// on the FIRST call are used; arguments to later calls are ignored. If any
// handle was VK_NULL_HANDLE on that first call, no adapter is created and
// every call returns nullptr.
//
// The adapter is created with a single queue and with the path returned by
// set_and_get_pipeline_cache_data_path("").
Adapter* set_and_get_external_adapter(
    const VkInstance instance,
    const VkPhysicalDevice physical_device,
    const VkDevice logical_device) {
  constexpr uint32_t kExternalAdapterQueueCount = 1;

  static const std::unique_ptr<Adapter> external_adapter_holder =
      init_external_adapter(
          instance,
          physical_device,
          logical_device,
          kExternalAdapterQueueCount,
          set_and_get_pipeline_cache_data_path(""));

  return external_adapter_holder.get();
}

} // namespace vkapi
} // namespace vkcompute
6 changes: 6 additions & 0 deletions backends/vulkan/runtime/vk_api/Runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,11 @@ std::string& set_and_get_pipeline_cache_data_path(const std::string& file_path);
// a static local variable.
Runtime* runtime();

// Used to share instance + devices between client code and ETVK.
//
// Returns a pointer to an Adapter built on the given externally created
// handles, or nullptr if no such adapter exists. The adapter is created at
// most once, from the handles passed on the first call; handles passed on
// later calls are ignored. If any handle is VK_NULL_HANDLE on that first call
// (including a call that relies on the defaults below), no adapter is created
// and every call returns nullptr. Clients that want to share their device
// must therefore call this with valid handles before anything else calls it.
Adapter* set_and_get_external_adapter(
const VkInstance instance = VK_NULL_HANDLE,
const VkPhysicalDevice physical_device = VK_NULL_HANDLE,
const VkDevice logical_device = VK_NULL_HANDLE);

} // namespace vkapi
} // namespace vkcompute
5 changes: 5 additions & 0 deletions backends/vulkan/runtime/vk_api/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
#include <cstddef>
#include <cstdint>

// X11 headers via volk define Bool, so we need to undef it
#if defined(__linux__)
#undef Bool
#endif

#ifdef USE_VULKAN_FP16_INFERENCE
#define VK_FORMAT_FLOAT4 VK_FORMAT_R16G16B16A16_SFLOAT
#else
Expand Down
Loading
Loading