Skip to content

Commit

Permalink
Typing IREE::HAL::DeviceTargetAttr executable targets. (iree-org#16588)
Browse files Browse the repository at this point in the history
This allows the device configuration to be independent from the
executable target attributes which is useful for devices that may share
the same configuration but different executable targets.
  • Loading branch information
benvanik authored Feb 28, 2024
1 parent a0febbe commit 4b1a4e2
Show file tree
Hide file tree
Showing 39 changed files with 251 additions and 298 deletions.
19 changes: 8 additions & 11 deletions compiler/plugins/target/CUDA/CUDATarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,12 +384,17 @@ class CUDATargetBackend final : public TargetBackend {
Builder b(context);
SmallVector<NamedAttribute> configItems;

configItems.emplace_back(b.getStringAttr("executable_targets"),
getExecutableTargets(context));
// TODO: device configuration attrs.

auto configAttr = b.getDictionaryAttr(configItems);

// If we had multiple target environments we would generate one target attr
// per environment, with each setting its own environment attribute.
SmallVector<IREE::HAL::ExecutableTargetAttr> targetAttrs;
targetAttrs.push_back(getExecutableTarget(context));

return IREE::HAL::DeviceTargetAttr::get(
context, b.getStringAttr(deviceID()), configAttr);
context, b.getStringAttr(deviceID()), configAttr, targetAttrs);
}

void buildConfigurationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
Expand Down Expand Up @@ -640,14 +645,6 @@ class CUDATargetBackend final : public TargetBackend {
}

private:
ArrayAttr getExecutableTargets(MLIRContext *context) const {
SmallVector<Attribute> targetAttrs;
// If we had multiple target environments we would generate one target attr
// per environment, with each setting its own environment attribute.
targetAttrs.push_back(getExecutableTarget(context));
return ArrayAttr::get(context, targetAttrs);
}

IREE::HAL::ExecutableTargetAttr
getExecutableTarget(MLIRContext *context) const {
Builder b(context);
Expand Down
8 changes: 3 additions & 5 deletions compiler/plugins/target/CUDA/test/smoketest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@

module attributes {
hal.device.targets = [
#hal.device.target<"cuda", {
executable_targets = [
#hal.executable.target<"cuda", "cuda-nvptx-fb">
]
}>
#hal.device.target<"cuda", [
#hal.executable.target<"cuda", "cuda-nvptx-fb">
]>
]
} {

Expand Down
21 changes: 8 additions & 13 deletions compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,16 @@ class MetalSPIRVTargetBackend : public TargetBackend {
Builder b(context);
SmallVector<NamedAttribute> configItems;

configItems.emplace_back(b.getStringAttr("executable_targets"),
getExecutableTargets(context));

auto configAttr = b.getDictionaryAttr(configItems);

// If we had multiple target environments we would generate one target attr
// per environment, with each setting its own environment attribute.
SmallVector<IREE::HAL::ExecutableTargetAttr> targetAttrs;
targetAttrs.push_back(
getExecutableTarget(context, getMetalTargetEnv(context)));

return IREE::HAL::DeviceTargetAttr::get(
context, b.getStringAttr(deviceID()), configAttr);
context, b.getStringAttr(deviceID()), configAttr, targetAttrs);
}

void buildConfigurationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
Expand Down Expand Up @@ -282,15 +286,6 @@ class MetalSPIRVTargetBackend : public TargetBackend {
}

private:
ArrayAttr getExecutableTargets(MLIRContext *context) const {
SmallVector<Attribute> targetAttrs;
// If we had multiple target environments we would generate one target attr
// per environment, with each setting its own environment attribute.
targetAttrs.push_back(
getExecutableTarget(context, getMetalTargetEnv(context)));
return ArrayAttr::get(context, targetAttrs);
}

IREE::HAL::ExecutableTargetAttr
getExecutableTarget(MLIRContext *context,
spirv::TargetEnvAttr targetEnv) const {
Expand Down
12 changes: 5 additions & 7 deletions compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@

module attributes {
hal.device.targets = [
#hal.device.target<"metal", {
executable_targets = [
#hal.executable.target<"metal-spirv", "metal-msl-fb", {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
}>
]
}>
#hal.device.target<"metal", [
#hal.executable.target<"metal-spirv", "metal-msl-fb", {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
}>
]>
]
} {

Expand Down
19 changes: 7 additions & 12 deletions compiler/plugins/target/ROCM/ROCMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,15 @@ class ROCMTargetBackend final : public TargetBackend {
// synchronous mode.
configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr());

configItems.emplace_back(b.getStringAttr("executable_targets"),
getExecutableTargets(context));

auto configAttr = b.getDictionaryAttr(configItems);

// If we had multiple target environments we would generate one target attr
// per environment, with each setting its own environment attribute.
SmallVector<IREE::HAL::ExecutableTargetAttr> targetAttrs;
targetAttrs.push_back(getExecutableTarget(context));

return IREE::HAL::DeviceTargetAttr::get(
context, b.getStringAttr(deviceID()), configAttr);
context, b.getStringAttr(deviceID()), configAttr, targetAttrs);
}
// Performs optimizations on |module| (including LTO-style whole-program
// ones). Inspired by code section in
Expand Down Expand Up @@ -457,14 +460,6 @@ class ROCMTargetBackend final : public TargetBackend {
}

private:
ArrayAttr getExecutableTargets(MLIRContext *context) const {
SmallVector<Attribute> targetAttrs;
// If we had multiple target environments we would generate one target attr
// per environment, with each setting its own environment attribute.
targetAttrs.push_back(getExecutableTarget(context));
return ArrayAttr::get(context, targetAttrs);
}

IREE::HAL::ExecutableTargetAttr
getExecutableTarget(MLIRContext *context) const {
Builder b(context);
Expand Down
8 changes: 3 additions & 5 deletions compiler/plugins/target/ROCM/test/smoketest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@

module attributes {
hal.device.targets = [
#hal.device.target<"rocm", {
executable_targets = [
#hal.executable.target<"rocm", "rocm-hsaco-fb">
]
}>
#hal.device.target<"rocm", [
#hal.executable.target<"rocm", "rocm-hsaco-fb">
]>
]
} {

Expand Down
40 changes: 16 additions & 24 deletions compiler/plugins/target/VMVX/VMVXTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,16 @@ class VMVXTargetBackend final : public TargetBackend {
Builder b(context);
SmallVector<NamedAttribute> configItems;

configItems.emplace_back(b.getStringAttr("executable_targets"),
getExecutableTargets(context));

auto configAttr = b.getDictionaryAttr(configItems);

// If we had multiple target environments we would generate one target attr
// per environment, with each setting its own environment attribute.
SmallVector<IREE::HAL::ExecutableTargetAttr> targetAttrs;
targetAttrs.push_back(getVMVXExecutableTarget(
options.enableMicrokernels, context, "vmvx", "vmvx-bytecode-fb"));

return IREE::HAL::DeviceTargetAttr::get(
context, b.getStringAttr(deviceID()), configAttr);
context, b.getStringAttr(deviceID()), configAttr, targetAttrs);
}

std::optional<IREE::HAL::DeviceTargetAttr>
Expand Down Expand Up @@ -164,14 +168,6 @@ class VMVXTargetBackend final : public TargetBackend {
}

private:
ArrayAttr getExecutableTargets(MLIRContext *context) const {
SmallVector<Attribute> targetAttrs;
// This is where we would multiversion.
targetAttrs.push_back(getVMVXExecutableTarget(
options.enableMicrokernels, context, "vmvx", "vmvx-bytecode-fb"));
return ArrayAttr::get(context, targetAttrs);
}

const VMVXOptions &options;
};

Expand All @@ -191,12 +187,16 @@ class VMVXInlineTargetBackend final : public TargetBackend {
Builder b(context);
SmallVector<NamedAttribute> configItems;

configItems.emplace_back(b.getStringAttr("executable_targets"),
getExecutableTargets(context));

auto configAttr = b.getDictionaryAttr(configItems);

// If we had multiple target environments we would generate one target attr
// per environment, with each setting its own environment attribute.
SmallVector<IREE::HAL::ExecutableTargetAttr> targetAttrs;
targetAttrs.push_back(getVMVXExecutableTarget(
options.enableMicrokernels, context, "vmvx-inline", "vmvx-ir"));

return IREE::HAL::DeviceTargetAttr::get(
context, b.getStringAttr(deviceID()), configAttr);
context, b.getStringAttr(deviceID()), configAttr, targetAttrs);
}

void buildConfigurationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
Expand All @@ -210,14 +210,6 @@ class VMVXInlineTargetBackend final : public TargetBackend {
}

private:
ArrayAttr getExecutableTargets(MLIRContext *context) const {
SmallVector<Attribute> targetAttrs;
// This is where we would multiversion.
targetAttrs.push_back(getVMVXExecutableTarget(
options.enableMicrokernels, context, "vmvx-inline", "vmvx-ir"));
return ArrayAttr::get(context, targetAttrs);
}

const VMVXOptions &options;
};

Expand Down
8 changes: 3 additions & 5 deletions compiler/plugins/target/VMVX/test/smoketest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@

module attributes {
hal.device.targets = [
#hal.device.target<"vmvx", {
executable_targets = [
#hal.executable.target<"vmvx", "vmvx-bytecode-fb">
]
}>
#hal.device.target<"vmvx", [
#hal.executable.target<"vmvx", "vmvx-bytecode-fb">
]>
]
} {

Expand Down
23 changes: 9 additions & 14 deletions compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,16 @@ class WebGPUSPIRVTargetBackend : public TargetBackend {
Builder b(context);
SmallVector<NamedAttribute> configItems;

configItems.emplace_back(b.getStringAttr("executable_targets"),
getExecutableTargets(context));

auto configAttr = b.getDictionaryAttr(configItems);
return IREE::HAL::DeviceTargetAttr::get(
context, b.getStringAttr(deviceID()), configAttr);

// If we had multiple target environments we would generate one target attr
// per environment, with each setting its own environment attribute.
SmallVector<IREE::HAL::ExecutableTargetAttr> targetAttrs;
targetAttrs.push_back(
getExecutableTarget(context, getWebGPUTargetEnv(context)));

return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("webgpu"),
configAttr, targetAttrs);
}

void buildConfigurationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
Expand Down Expand Up @@ -255,15 +259,6 @@ class WebGPUSPIRVTargetBackend : public TargetBackend {
}

private:
ArrayAttr getExecutableTargets(MLIRContext *context) const {
SmallVector<Attribute> targetAttrs;
// If we had multiple target environments we would generate one target attr
// per environment, with each setting its own environment attribute.
targetAttrs.push_back(
getExecutableTarget(context, getWebGPUTargetEnv(context)));
return ArrayAttr::get(context, targetAttrs);
}

IREE::HAL::ExecutableTargetAttr
getExecutableTarget(MLIRContext *context,
spirv::TargetEnvAttr targetEnv) const {
Expand Down
12 changes: 5 additions & 7 deletions compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
#map = affine_map<(d0) -> (d0)>
module attributes {
hal.device.targets = [
#hal.device.target<"webgpu", {
executable_targets = [
#hal.executable.target<"webgpu-spirv", "webgpu-wgsl-fb", {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
}>
]
}>
#hal.device.target<"webgpu", [
#hal.executable.target<"webgpu-spirv", "webgpu-wgsl-fb", {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
}>
]>
]
} {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule | FileCheck %s

#device_target_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-none-elf"}>]}>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [#hal.descriptor_set.layout<0, bindings = [#hal.descriptor_set.binding<0, storage_buffer>, #hal.descriptor_set.binding<1, storage_buffer>, #hal.descriptor_set.binding<2, storage_buffer>]>]>
#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-none-elf"}>
#device_target_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [#hal.descriptor_set.layout<0, bindings = [#hal.descriptor_set.binding<0, storage_buffer>, #hal.descriptor_set.binding<1, storage_buffer>, #hal.descriptor_set.binding<2, storage_buffer>]>]>

hal.executable private @pad_matmul_static_dispatch_0 {
hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64_) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,10 @@ util.func public @cmdExecute(%arg0: !stream.resource<transient>, %arg1: index, %

#executable_target_aarch64 = #hal.executable.target<"llvm-cpu", "embedded-elf-aarch64">
#executable_target_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64">
#device_target_cpu = #hal.device.target<"llvm-cpu", {
executable_targets = [#executable_target_aarch64, #executable_target_x86_64]
}>
#device_target_cpu = #hal.device.target<"llvm-cpu", [
#executable_target_aarch64,
#executable_target_x86_64
]>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<4, storage_buffer>
Expand Down
Loading

0 comments on commit 4b1a4e2

Please sign in to comment.