Skip to content

[metal] Metal compute shader passthrough #7326

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,32 @@ layout(location = 0, index = 1) out vec4 output1;

By @wumpf in [#7144](https://github.com/gfx-rs/wgpu/pull/7144)

#### Unify interface for SPIR-V shader passthrough

Replace the device's `create_shader_module_spirv` function with a generic `create_shader_module_passthrough` function
that takes a `ShaderModuleDescriptorPassthrough` enum as its parameter.

Update your calls to `create_shader_module_spirv` and use `create_shader_module_passthrough` instead:

```diff
- device.create_shader_module_spirv(
- wgpu::ShaderModuleDescriptorSpirV {
- label: Some(&name),
- source: Cow::Borrowed(&source),
- }
- )
+ device.create_shader_module_passthrough(
+ wgpu::ShaderModuleDescriptorPassthrough::SpirV(
+ wgpu::ShaderModuleDescriptorSpirV {
+ label: Some(&name),
+ source: Cow::Borrowed(&source),
+ },
+ ),
+ )
```

By @syl20bnr in [#7326](https://github.com/gfx-rs/wgpu/pull/7326).

### New Features

- Added mesh shader support to `wgpu_hal`. By @SupaMaggie70Incorporated in [#7089](https://github.com/gfx-rs/wgpu/pull/7089)
Expand All @@ -203,6 +229,9 @@ By @wumpf in [#7144](https://github.com/gfx-rs/wgpu/pull/7144)

- Support getting vertices of the hit triangle when raytracing. By @Vecvec in [#7183](https://github.com/gfx-rs/wgpu/pull/7183).

- Add Metal compute shader passthrough. Use `create_shader_module_passthrough` on device. By @syl20bnr in [#7326](https://github.com/gfx-rs/wgpu/pull/7326).

- Add `Features::MSL_SHADER_PASSTHROUGH` run-time feature, which allows providing pass-through MSL (Metal Shading Language) shaders. By @syl20bnr in [#7326](https://github.com/gfx-rs/wgpu/pull/7326).

#### Naga

Expand Down
3 changes: 3 additions & 0 deletions deno_webgpu/webidl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,8 @@ pub enum GPUFeatureName {
VertexWritableStorage,
#[webidl(rename = "clear-texture")]
ClearTexture,
#[webidl(rename = "msl-shader-passthrough")]
MslShaderPassthrough,
#[webidl(rename = "spirv-shader-passthrough")]
SpirvShaderPassthrough,
#[webidl(rename = "multiview")]
Expand Down Expand Up @@ -477,6 +479,7 @@ pub fn feature_names_to_features(names: Vec<GPUFeatureName>) -> wgpu_types::Feat
GPUFeatureName::ConservativeRasterization => Features::CONSERVATIVE_RASTERIZATION,
GPUFeatureName::VertexWritableStorage => Features::VERTEX_WRITABLE_STORAGE,
GPUFeatureName::ClearTexture => Features::CLEAR_TEXTURE,
GPUFeatureName::MslShaderPassthrough => Features::MSL_SHADER_PASSTHROUGH,
GPUFeatureName::SpirvShaderPassthrough => Features::SPIRV_SHADER_PASSTHROUGH,
GPUFeatureName::Multiview => Features::MULTIVIEW,
GPUFeatureName::VertexAttribute64Bit => Features::VERTEX_ATTRIBUTE_64BIT,
Expand Down
4 changes: 2 additions & 2 deletions examples/standalone/custom_backend/src/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ impl DeviceInterface for CustomDevice {
DispatchShaderModule::custom(CustomShaderModule(self.0.clone()))
}

unsafe fn create_shader_module_spirv(
unsafe fn create_shader_module_passthrough(
&self,
_desc: &wgpu::ShaderModuleDescriptorSpirV<'_>,
_desc: &wgpu::ShaderModuleDescriptorPassthrough<'_>,
) -> DispatchShaderModule {
unimplemented!()
}
Expand Down
2 changes: 1 addition & 1 deletion naga/src/valid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ bitflags::bitflags! {
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
Expand Down
14 changes: 8 additions & 6 deletions tests/tests/wgpu-gpu/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,12 +511,14 @@ static DEVICE_DESTROY_THEN_MORE: GpuTestConfiguration = GpuTestConfiguration::ne
fail(
&ctx.device,
|| unsafe {
let _ = ctx
.device
.create_shader_module_spirv(&wgpu::ShaderModuleDescriptorSpirV {
label: None,
source: std::borrow::Cow::Borrowed(&[]),
});
let _ = ctx.device.create_shader_module_passthrough(
wgpu::ShaderModuleDescriptorPassthrough::SpirV(
wgpu::ShaderModuleDescriptorSpirV {
label: None,
source: std::borrow::Cow::Borrowed(&[]),
},
),
);
},
Some("device with '' label is invalid"),
);
Expand Down
33 changes: 23 additions & 10 deletions wgpu-core/src/device/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -938,23 +938,21 @@ impl Global {
(id, Some(error))
}

// Unsafe-ness of internal calls has little to do with unsafe-ness of this.
#[allow(unused_unsafe)]
/// # Safety
///
/// This function passes SPIR-V binary to the backend as-is and can potentially result in a
/// This function passes source code or binary to the backend as-is and can potentially result in a
/// driver crash.
pub unsafe fn device_create_shader_module_spirv(
pub unsafe fn device_create_shader_module_passthrough(
&self,
device_id: DeviceId,
desc: &pipeline::ShaderModuleDescriptor,
source: Cow<[u32]>,
desc: &pipeline::ShaderModuleDescriptorPassthrough<'_>,
id_in: Option<id::ShaderModuleId>,
) -> (
id::ShaderModuleId,
Option<pipeline::CreateShaderModuleError>,
) {
profiling::scope!("Device::create_shader_module");
profiling::scope!("Device::create_shader_module_passthrough");

let hub = &self.hub;
let fid = hub.shader_modules.prepare(id_in);
Expand All @@ -964,15 +962,30 @@ impl Global {

#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
let data = trace.make_binary("spv", bytemuck::cast_slice(&source));
let data = trace.make_binary(desc.trace_binary_ext(), desc.trace_data());
trace.add(trace::Action::CreateShaderModule {
id: fid.id(),
desc: desc.clone(),
desc: match desc {
pipeline::ShaderModuleDescriptorPassthrough::SpirV(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
pipeline::ShaderModuleDescriptorPassthrough::Msl(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
},
data,
});
};

let shader = match unsafe { device.create_shader_module_spirv(desc, &source) } {
let result = unsafe { device.create_shader_module_passthrough(desc) };

let shader = match result {
Ok(shader) => shader,
Err(e) => break 'error e,
};
Expand All @@ -981,7 +994,7 @@ impl Global {
return (id, None);
};

let id = fid.assign(Fallible::Invalid(Arc::new(desc.label.to_string())));
let id = fid.assign(Fallible::Invalid(Arc::new(desc.label().to_string())));
(id, Some(error))
}

Expand Down
32 changes: 21 additions & 11 deletions wgpu-core/src/device/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1750,19 +1750,31 @@ impl Device {
}

#[allow(unused_unsafe)]
pub(crate) unsafe fn create_shader_module_spirv<'a>(
pub(crate) unsafe fn create_shader_module_passthrough<'a>(
self: &Arc<Self>,
desc: &pipeline::ShaderModuleDescriptor<'a>,
source: &'a [u32],
descriptor: &pipeline::ShaderModuleDescriptorPassthrough<'a>,
) -> Result<Arc<pipeline::ShaderModule>, pipeline::CreateShaderModuleError> {
self.check_is_valid()?;
let hal_shader = match descriptor {
pipeline::ShaderModuleDescriptorPassthrough::SpirV(inner) => {
self.require_features(wgt::Features::SPIRV_SHADER_PASSTHROUGH)?;
hal::ShaderInput::SpirV(&inner.source)
}
pipeline::ShaderModuleDescriptorPassthrough::Msl(inner) => {
self.require_features(wgt::Features::MSL_SHADER_PASSTHROUGH)?;
hal::ShaderInput::Msl {
shader: inner.source.to_string(),
entry_point: inner.entry_point.to_string(),
num_workgroups: inner.num_workgroups,
}
}
};

self.require_features(wgt::Features::SPIRV_SHADER_PASSTHROUGH)?;
let hal_desc = hal::ShaderModuleDescriptor {
label: desc.label.to_hal(self.instance_flags),
runtime_checks: desc.runtime_checks,
label: descriptor.label().to_hal(self.instance_flags),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
};
let hal_shader = hal::ShaderInput::SpirV(source);

let raw = match unsafe { self.raw().create_shader_module(&hal_desc, hal_shader) } {
Ok(raw) => raw,
Err(error) => {
Expand All @@ -1782,12 +1794,10 @@ impl Device {
raw: ManuallyDrop::new(raw),
device: self.clone(),
interface: None,
label: desc.label.to_string(),
label: descriptor.label().to_string(),
};

let module = Arc::new(module);

Ok(module)
Ok(Arc::new(module))
}

pub(crate) fn create_command_encoder(
Expand Down
3 changes: 3 additions & 0 deletions wgpu-core/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ pub struct ShaderModuleDescriptor<'a> {
pub runtime_checks: wgt::ShaderRuntimeChecks,
}

pub type ShaderModuleDescriptorPassthrough<'a> =
wgt::CreateShaderModuleDescriptorPassthrough<'a, Label<'a>>;

#[derive(Debug)]
pub struct ShaderModule {
pub(crate) raw: ManuallyDrop<Box<dyn hal::DynShaderModule>>,
Expand Down
3 changes: 3 additions & 0 deletions wgpu-hal/src/dx12/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1641,6 +1641,9 @@ impl crate::Device for super::Device {
crate::ShaderInput::SpirV(_) => {
panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend")
}
crate::ShaderInput::Msl { .. } => {
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
}
}
}
unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) {
Expand Down
3 changes: 3 additions & 0 deletions wgpu-hal/src/gles/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,9 @@ impl crate::Device for super::Device {
crate::ShaderInput::SpirV(_) => {
panic!("`Features::SPIRV_SHADER_PASSTHROUGH` is not enabled")
}
crate::ShaderInput::Msl { .. } => {
panic!("`Features::MSL_SHADER_PASSTHROUGH` is not enabled")
}
crate::ShaderInput::Naga(naga) => naga,
},
label: desc.label.map(|str| str.to_string()),
Expand Down
6 changes: 6 additions & 0 deletions wgpu-hal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2069,6 +2069,7 @@ pub struct CommandEncoderDescriptor<'a, Q: DynQueue + ?Sized> {
}

/// Naga shader module.
#[derive(Default)]
pub struct NagaShader {
/// Shader module IR.
pub module: Cow<'static, naga::Module>,
Expand All @@ -2090,6 +2091,11 @@ impl fmt::Debug for NagaShader {
#[allow(clippy::large_enum_variant)]
pub enum ShaderInput<'a> {
Naga(NagaShader),
Msl {
shader: String,
entry_point: String,
num_workgroups: (u32, u32, u32),
},
SpirV(&'a [u32]),
}

Expand Down
1 change: 1 addition & 0 deletions wgpu-hal/src/metal/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,7 @@ impl super::PrivateCapabilities {
use wgt::Features as F;

let mut features = F::empty()
| F::MSL_SHADER_PASSTHROUGH
| F::MAPPABLE_PRIMARY_BUFFERS
| F::VERTEX_WRITABLE_STORAGE
| F::TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES
Expand Down
74 changes: 62 additions & 12 deletions wgpu-hal/src/metal/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ use std::{thread, time};

use parking_lot::Mutex;

use super::conv;
use super::{conv, PassthroughShader};
use crate::auxil::map_naga_stage;
use crate::metal::ShaderModuleSource;
use crate::TlasInstance;

use metal::foreign_types::ForeignType;
Expand Down Expand Up @@ -122,11 +123,15 @@ impl super::Device {
primitive_class: metal::MTLPrimitiveTopologyClass,
naga_stage: naga::ShaderStage,
) -> Result<CompiledShader, crate::PipelineError> {
let naga_shader = if let ShaderModuleSource::Naga(naga) = &stage.module.source {
naga
} else {
panic!("load_shader required a naga shader");
};
let stage_bit = map_naga_stage(naga_stage);

let (module, module_info) = naga::back::pipeline_constants::process_overrides(
&stage.module.naga.module,
&stage.module.naga.info,
&naga_shader.module,
&naga_shader.info,
stage.constants,
)
.map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("MSL: {:?}", e)))?;
Expand Down Expand Up @@ -989,9 +994,37 @@ impl crate::Device for super::Device {

match shader {
crate::ShaderInput::Naga(naga) => Ok(super::ShaderModule {
naga,
source: ShaderModuleSource::Naga(naga),
bounds_checks: desc.runtime_checks,
}),
crate::ShaderInput::Msl {
shader: source,
entry_point,
num_workgroups,
} => {
let options = metal::CompileOptions::new();
// Obtain the locked device from shared
let device = self.shared.device.lock();
let library = device
.new_library_with_source(&source, &options)
.map_err(|e| crate::ShaderError::Compilation(format!("MSL: {:?}", e)))?;
let function = library.get_function(&entry_point, None).map_err(|_| {
crate::ShaderError::Compilation(format!(
"Entry point '{}' not found",
entry_point
))
})?;

Ok(super::ShaderModule {
source: ShaderModuleSource::Passthrough(PassthroughShader {
library,
function,
entry_point,
num_workgroups,
}),
bounds_checks: desc.runtime_checks,
})
}
crate::ShaderInput::SpirV(_) => {
panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend")
}
Expand Down Expand Up @@ -1299,13 +1332,30 @@ impl crate::Device for super::Device {
objc::rc::autoreleasepool(|| {
let descriptor = metal::ComputePipelineDescriptor::new();

let cs = self.load_shader(
&desc.stage,
&[],
desc.layout,
metal::MTLPrimitiveTopologyClass::Unspecified,
naga::ShaderStage::Compute,
)?;
let module = desc.stage.module;
let cs = if let ShaderModuleSource::Passthrough(desc) = &module.source {
CompiledShader {
library: desc.library.clone(),
function: desc.function.clone(),
wg_size: metal::MTLSize::new(
desc.num_workgroups.0 as u64,
desc.num_workgroups.1 as u64,
desc.num_workgroups.2 as u64,
),
wg_memory_sizes: vec![],
sized_bindings: vec![],
immutable_buffer_mask: 0,
}
} else {
self.load_shader(
&desc.stage,
&[],
desc.layout,
metal::MTLPrimitiveTopologyClass::Unspecified,
naga::ShaderStage::Compute,
)?
};

descriptor.set_compute_function(Some(&cs.function));

if self.shared.private_caps.supports_mutability {
Expand Down
Loading