Skip to content

Commit 263e0fa

Browse files
committed
[metal] Metal compute shader passthrough
1 parent a13f0a0 commit 263e0fa

File tree

16 files changed

+251
-8
lines changed

16 files changed

+251
-8
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ By @wumpf in [#7144](https://github.com/gfx-rs/wgpu/pull/7144)
186186
187187
- Support getting vertices of the hit triangle when raytracing. By @Vecvec in [#7183](https://github.com/gfx-rs/wgpu/pull/7183) .
188188
189+
- Add Metal compute shader passthrough. Use `create_shader_module_msl` on device. By @syl20bnr in [#7326](https://github.com/gfx-rs/wgpu/pull/7326).
190+
189191
#### Naga
190192
191193
- Add support for unsigned types when calling textureLoad with the level parameter. By @ygdrasil-io in [#7058](https://github.com/gfx-rs/wgpu/pull/7058).

examples/standalone/03_custom_backend/src/custom.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,13 @@ impl DeviceInterface for CustomDevice {
133133
unimplemented!()
134134
}
135135

136+
unsafe fn create_shader_module_msl(
137+
&self,
138+
_desc: &wgpu::ShaderModuleDescriptorMsl<'_>,
139+
) -> DispatchShaderModule {
140+
unimplemented!()
141+
}
142+
136143
fn create_bind_group_layout(
137144
&self,
138145
_desc: &wgpu::BindGroupLayoutDescriptor<'_>,

naga/src/valid/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ bitflags::bitflags! {
235235
}
236236
}
237237

238-
#[derive(Debug, Clone)]
238+
#[derive(Debug, Clone, Default)]
239239
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
240240
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
241241
pub struct ModuleInfo {

wgpu-core/src/device/global.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -987,6 +987,46 @@ impl Global {
987987
(id, Some(error))
988988
}
989989

990+
// Unsafe-ness of internal calls has little to do with unsafe-ness of this.
991+
#[allow(unused_unsafe)]
992+
/// # Safety
993+
///
994+
/// This function passes MSL source code to the backend as-is and can potentially result in a
995+
/// driver crash.
996+
pub unsafe fn device_create_shader_module_msl(
997+
&self,
998+
device_id: DeviceId,
999+
desc: &pipeline::ShaderModuleDescriptor,
1000+
source: Cow<str>,
1001+
entry_point: &str,
1002+
num_workgroups: (u32, u32, u32),
1003+
id_in: Option<id::ShaderModuleId>,
1004+
) -> (
1005+
id::ShaderModuleId,
1006+
Option<pipeline::CreateShaderModuleError>,
1007+
) {
1008+
profiling::scope!("Device::create_shader_module");
1009+
1010+
let hub = &self.hub;
1011+
let fid = hub.shader_modules.prepare(id_in);
1012+
1013+
let error = 'error: {
1014+
let device = self.hub.devices.get(device_id);
1015+
let shader = match unsafe {
1016+
device.create_shader_module_msl(desc, &source, entry_point, num_workgroups)
1017+
} {
1018+
Ok(shader) => shader,
1019+
Err(e) => break 'error e,
1020+
};
1021+
let id = fid.assign(Fallible::Valid(shader));
1022+
api_log!("Device::create_shader_module_spirv -> {id:?}");
1023+
return (id, None);
1024+
};
1025+
1026+
let id = fid.assign(Fallible::Invalid(Arc::new(desc.label.to_string())));
1027+
(id, Some(error))
1028+
}
1029+
9901030
pub fn shader_module_drop(&self, shader_module_id: id::ShaderModuleId) {
9911031
profiling::scope!("ShaderModule::drop");
9921032
api_log!("ShaderModule::drop {shader_module_id:?}");

wgpu-core/src/device/resource.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1729,6 +1729,45 @@ impl Device {
17291729
Ok(module)
17301730
}
17311731

1732+
#[allow(unused_unsafe)]
1733+
pub(crate) unsafe fn create_shader_module_msl<'a>(
1734+
self: &Arc<Self>,
1735+
desc: &pipeline::ShaderModuleDescriptor<'a>,
1736+
source: &'a Cow<str>,
1737+
entry_point: &str,
1738+
num_workgroups: (u32, u32, u32),
1739+
) -> Result<Arc<pipeline::ShaderModule>, pipeline::CreateShaderModuleError> {
1740+
let hal_shader = hal::ShaderInput::Msl {
1741+
shader: source.to_string(),
1742+
entry_point: entry_point.to_string(),
1743+
num_workgroups,
1744+
};
1745+
let hal_desc = hal::ShaderModuleDescriptor {
1746+
label: desc.label.to_hal(self.instance_flags),
1747+
runtime_checks: desc.runtime_checks,
1748+
};
1749+
let raw =
1750+
unsafe { self.raw().create_shader_module(&hal_desc, hal_shader) }.map_err(|error| {
1751+
match error {
1752+
hal::ShaderError::Device(err) => {
1753+
pipeline::CreateShaderModuleError::Device(self.handle_hal_error(err))
1754+
}
1755+
hal::ShaderError::Compilation(msg) => {
1756+
log::error!("Shader compilation error: {}", msg);
1757+
pipeline::CreateShaderModuleError::Generation
1758+
}
1759+
}
1760+
})?;
1761+
let module = pipeline::ShaderModule {
1762+
raw: ManuallyDrop::new(raw),
1763+
device: self.clone(),
1764+
interface: None,
1765+
label: desc.label.to_string(),
1766+
};
1767+
1768+
Ok(Arc::new(module))
1769+
}
1770+
17321771
pub(crate) fn create_command_encoder(
17331772
self: &Arc<Self>,
17341773
label: &crate::Label,

wgpu-hal/src/dx12/device.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,6 +1688,9 @@ impl crate::Device for super::Device {
16881688
}),
16891689
crate::ShaderInput::SpirV(_) => {
16901690
panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend")
1691+
},
1692+
crate::ShaderInput::Msl { .. } => {
1693+
panic!("MLS_SHADER_PASSTHROUGH is not enabled for this backend")
16911694
}
16921695
}
16931696
}

wgpu-hal/src/gles/device.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,6 +1327,9 @@ impl crate::Device for super::Device {
13271327
crate::ShaderInput::SpirV(_) => {
13281328
panic!("`Features::SPIRV_SHADER_PASSTHROUGH` is not enabled")
13291329
}
1330+
crate::ShaderInput::Msl { .. } => {
1331+
panic!("`Features::MSL_SHADER_PASSTHROUGH` is not enabled")
1332+
}
13301333
crate::ShaderInput::Naga(naga) => naga,
13311334
},
13321335
label: desc.label.map(|str| str.to_string()),

wgpu-hal/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2054,6 +2054,7 @@ pub struct CommandEncoderDescriptor<'a, Q: DynQueue + ?Sized> {
20542054
}
20552055

20562056
/// Naga shader module.
2057+
#[derive(Default)]
20572058
pub struct NagaShader {
20582059
/// Shader module IR.
20592060
pub module: Cow<'static, naga::Module>,
@@ -2075,6 +2076,11 @@ impl fmt::Debug for NagaShader {
20752076
#[allow(clippy::large_enum_variant)]
20762077
pub enum ShaderInput<'a> {
20772078
Naga(NagaShader),
2079+
Msl {
2080+
shader: String,
2081+
entry_point: String,
2082+
num_workgroups: (u32, u32, u32),
2083+
},
20782084
SpirV(&'a [u32]),
20792085
}
20802086

wgpu-hal/src/metal/device.rs

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -987,7 +987,38 @@ impl crate::Device for super::Device {
987987
crate::ShaderInput::Naga(naga) => Ok(super::ShaderModule {
988988
naga,
989989
bounds_checks: desc.runtime_checks,
990+
library: None,
991+
function: None,
992+
entry_point: None,
993+
num_workgroups: None,
990994
}),
995+
crate::ShaderInput::Msl {
996+
shader: source,
997+
entry_point,
998+
num_workgroups,
999+
} => {
1000+
let options = metal::CompileOptions::new();
1001+
// Obtain the locked device from shared
1002+
let device = self.shared.device.lock();
1003+
let library = device
1004+
.new_library_with_source(&source, &options)
1005+
.map_err(|e| crate::ShaderError::Compilation(format!("MSL: {:?}", e)))?;
1006+
let function = library.get_function(&entry_point, None).map_err(|_| {
1007+
crate::ShaderError::Compilation(format!(
1008+
"Entry point '{}' not found",
1009+
entry_point
1010+
))
1011+
})?;
1012+
1013+
Ok(super::ShaderModule {
1014+
naga: crate::NagaShader::default(), // naga modules is not used for passthrough
1015+
library: Some(library),
1016+
function: Some(function),
1017+
entry_point: Some(entry_point),
1018+
num_workgroups: Some(num_workgroups),
1019+
bounds_checks: desc.runtime_checks,
1020+
})
1021+
}
9911022
crate::ShaderInput::SpirV(_) => {
9921023
panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend")
9931024
}
@@ -1295,13 +1326,34 @@ impl crate::Device for super::Device {
12951326
objc::rc::autoreleasepool(|| {
12961327
let descriptor = metal::ComputePipelineDescriptor::new();
12971328

1298-
let cs = self.load_shader(
1299-
&desc.stage,
1300-
&[],
1301-
desc.layout,
1302-
metal::MTLPrimitiveTopologyClass::Unspecified,
1303-
naga::ShaderStage::Compute,
1304-
)?;
1329+
let module = desc.stage.module;
1330+
let cs = if module.function.is_some()
1331+
&& module.library.is_some()
1332+
&& module.num_workgroups.is_some()
1333+
{
1334+
let wg_size = module.num_workgroups.unwrap();
1335+
CompiledShader {
1336+
library: module.library.clone().unwrap(),
1337+
function: module.function.clone().unwrap(),
1338+
wg_size: metal::MTLSize::new(
1339+
wg_size.0 as u64,
1340+
wg_size.1 as u64,
1341+
wg_size.2 as u64,
1342+
),
1343+
wg_memory_sizes: vec![],
1344+
sized_bindings: vec![],
1345+
immutable_buffer_mask: 0,
1346+
}
1347+
} else {
1348+
self.load_shader(
1349+
&desc.stage,
1350+
&[],
1351+
desc.layout,
1352+
metal::MTLPrimitiveTopologyClass::Unspecified,
1353+
naga::ShaderStage::Compute,
1354+
)?
1355+
};
1356+
13051357
descriptor.set_compute_function(Some(&cs.function));
13061358

13071359
if self.shared.private_caps.supports_mutability {

wgpu-hal/src/metal/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,10 @@ unsafe impl Sync for BindGroup {}
784784
pub struct ShaderModule {
785785
naga: crate::NagaShader,
786786
bounds_checks: wgt::ShaderRuntimeChecks,
787+
pub library: Option<metal::Library>,
788+
pub function: Option<metal::Function>,
789+
pub entry_point: Option<String>,
790+
pub num_workgroups: Option<(u32, u32, u32)>,
787791
}
788792

789793
impl crate::DynShaderModule for ShaderModule {}

wgpu-hal/src/vulkan/device.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1862,6 +1862,9 @@ impl crate::Device for super::Device {
18621862
.map_err(|e| crate::ShaderError::Compilation(format!("{e}")))?,
18631863
)
18641864
}
1865+
crate::ShaderInput::Msl { .. } => {
1866+
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
1867+
}
18651868
crate::ShaderInput::SpirV(spv) => Cow::Borrowed(spv),
18661869
};
18671870

wgpu/src/api/device.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,21 @@ impl Device {
155155
ShaderModule { inner: module }
156156
}
157157

158+
/// Creates a shader module from Metal MSL shader directly.
159+
///
160+
/// # Safety
161+
///
162+
/// This function passes the source to the backend as-is and can potentially result in a
163+
/// driver crash or bogus behaviour. No attempt is made to ensure that source code is valid.
164+
#[must_use]
165+
pub unsafe fn create_shader_module_msl(
166+
&self,
167+
desc: &ShaderModuleDescriptorMsl<'_>,
168+
) -> ShaderModule {
169+
let module = unsafe { self.inner.create_shader_module_msl(desc) };
170+
ShaderModule { inner: module }
171+
}
172+
158173
/// Creates an empty [`CommandEncoder`].
159174
#[must_use]
160175
pub fn create_command_encoder(&self, desc: &CommandEncoderDescriptor<'_>) -> CommandEncoder {

wgpu/src/api/shader_module.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,21 @@ pub struct ShaderModuleDescriptorSpirV<'a> {
236236
pub source: Cow<'a, [u32]>,
237237
}
238238
static_assertions::assert_impl_all!(ShaderModuleDescriptorSpirV<'_>: Send, Sync);
239+
240+
/// Descriptor for a shader module given by Metal MSL source, for use with
241+
/// [`Device::create_shader_module_msl`].
242+
///
243+
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
244+
/// only WGSL source code strings are accepted.
245+
#[derive(Debug)]
246+
pub struct ShaderModuleDescriptorMsl<'a> {
247+
/// Entrypoint.
248+
pub entry_point: String,
249+
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
250+
pub label: Label<'a>,
251+
/// Number of workgroups in each dimension x, y and z.
252+
pub num_workgroups: (u32, u32, u32),
253+
/// Shader MSL source.
254+
pub source: Cow<'a, str>,
255+
}
256+
static_assertions::assert_impl_all!(ShaderModuleDescriptorMsl<'_>: Send, Sync);

wgpu/src/backend/webgpu.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1838,6 +1838,13 @@ impl dispatch::DeviceInterface for WebDevice {
18381838
unreachable!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend")
18391839
}
18401840

1841+
unsafe fn create_shader_module_msl(
1842+
&self,
1843+
_desc: &crate::ShaderModuleDescriptorMsl<'_>,
1844+
) -> dispatch::DispatchShaderModule {
1845+
unreachable!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
1846+
}
1847+
18411848
fn create_bind_group_layout(
18421849
&self,
18431850
desc: &crate::BindGroupLayoutDescriptor<'_>,

wgpu/src/backend/wgpu_core.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,6 +1084,46 @@ impl dispatch::DeviceInterface for CoreDevice {
10841084
.into()
10851085
}
10861086

1087+
unsafe fn create_shader_module_msl(
1088+
&self,
1089+
desc: &crate::ShaderModuleDescriptorMsl<'_>,
1090+
) -> dispatch::DispatchShaderModule {
1091+
let descriptor = wgc::pipeline::ShaderModuleDescriptor {
1092+
label: desc.label.map(Borrowed),
1093+
// Doesn't matter the value since msl passthrough shaders aren't mutated to include
1094+
// runtime checks
1095+
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
1096+
};
1097+
let (id, error) = unsafe {
1098+
self.context.0.device_create_shader_module_msl(
1099+
self.id,
1100+
&descriptor,
1101+
Borrowed(&desc.source),
1102+
&desc.entry_point,
1103+
desc.num_workgroups,
1104+
None,
1105+
)
1106+
};
1107+
let compilation_info = match error {
1108+
Some(cause) => {
1109+
self.context.handle_error(
1110+
&self.error_sink,
1111+
cause.clone(),
1112+
desc.label,
1113+
"Device::create_shader_module_msl",
1114+
);
1115+
CompilationInfo::from(cause)
1116+
}
1117+
None => CompilationInfo { messages: vec![] },
1118+
};
1119+
CoreShaderModule {
1120+
context: self.context.clone(),
1121+
id,
1122+
compilation_info,
1123+
}
1124+
.into()
1125+
}
1126+
10871127
fn create_bind_group_layout(
10881128
&self,
10891129
desc: &crate::BindGroupLayoutDescriptor<'_>,

wgpu/src/dispatch.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ pub trait DeviceInterface: CommonTraits {
115115
&self,
116116
desc: &crate::ShaderModuleDescriptorSpirV<'_>,
117117
) -> DispatchShaderModule;
118+
unsafe fn create_shader_module_msl(
119+
&self,
120+
desc: &crate::ShaderModuleDescriptorMsl<'_>,
121+
) -> DispatchShaderModule;
118122
fn create_bind_group_layout(
119123
&self,
120124
desc: &crate::BindGroupLayoutDescriptor<'_>,

0 commit comments

Comments
 (0)