Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[metal] Metal compute shader passthrough #7326

Open
wants to merge 4 commits into
base: trunk
Choose a base branch
from

Conversation

syl20bnr
Copy link

@syl20bnr syl20bnr commented Mar 12, 2025

Description
This PR adds the possibility to directly pass a metal shader and create a compute pipeline to execute the passed shader without any validation. This allows CubeCL to use its wgpu runtime to execute metal kernels directly compiled by CubeCL JIT compiler and benefit from support for advanced metal features.

Testing
Tested with a WIP metal compiler using the wgpu runtime of CubeCL.
More info about how to test this soon.

Squash or Rebase?

Ok to be squashed.

Checklist

  • Run cargo fmt.
  • Run taplo format.
  • Run cargo clippy. If applicable, add:
    • --target wasm32-unknown-unknown
  • Run cargo xtask test to run tests.
  • If this contains user-facing changes, add a CHANGELOG.md entry.

@syl20bnr syl20bnr requested a review from a team as a code owner March 12, 2025 22:42
@syl20bnr syl20bnr marked this pull request as draft March 12, 2025 22:43
@syl20bnr syl20bnr force-pushed the feat/metal-passthrough branch 3 times, most recently from edef8fa to c049099 Compare March 12, 2025 23:04
Comment on lines 240 to 254
/// Descriptor for a shader module given by Metal MSL sourc, for use with
/// [`Device::create_shader_module_msl`].
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great to get some description of the shader ABI we present to shaders and how things get translated. It doesn't need to be as strict as in wgpu-hal, but should guide people on what wgpu expects the shader to have.

@syl20bnr syl20bnr force-pushed the feat/metal-passthrough branch 2 times, most recently from fe86233 to 17b472e Compare March 12, 2025 23:21
@DJMcNab
Copy link
Contributor

DJMcNab commented Mar 13, 2025

(Linking to #3103)

@syl20bnr
Copy link
Author

@cwfitzgerald Thank you for the review.

I think we should unify this entrypoint with create_shader_module_spirv and call it create_shader_module_passthrough which takes a ShaderModulePassthrough enum with each descriptor as needed.

I pushed a simple change for this, is that what you had in mind ?

It would be great to get some description of the shader ABI we present to shaders and how things get translated. It doesn't need to be as strict as in wgpu-hal, but should guide people on what wgpu expects the shader to have.

Not sure what do you mean, do you have an example of such description ?

Also can you point me to where I could add a test for the metal passthrough if needed ?

@syl20bnr syl20bnr force-pushed the feat/metal-passthrough branch from 5a59372 to 87006a1 Compare March 25, 2025 13:30
@syl20bnr syl20bnr force-pushed the feat/metal-passthrough branch from 87006a1 to 624db2a Compare March 25, 2025 13:34
@cwfitzgerald cwfitzgerald marked this pull request as ready for review April 1, 2025 16:55
@cwfitzgerald cwfitzgerald self-assigned this Apr 2, 2025
})?;

Ok(super::ShaderModule {
naga: crate::NagaShader::default(), // naga modules is not used for passthrough
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be optional, instead of adding on the default

Comment on lines 786 to +791
naga: crate::NagaShader,
bounds_checks: wgt::ShaderRuntimeChecks,
pub library: Option<metal::Library>,
pub function: Option<metal::Function>,
pub entry_point: Option<String>,
pub num_workgroups: Option<(u32, u32, u32)>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you add a nested enum ShaderModuleSource or so, which has Naga and Passthrough variants, a decent amount of this code will simplify I think.

@@ -237,7 +237,7 @@ bitflags::bitflags! {
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want this in general, as you should need to validate a module to get a module info.

@cwfitzgerald
Copy link
Member

cwfitzgerald commented Apr 2, 2025

Not sure what do you mean, do you have an example of such description ?

Take a look at https://wiki.libsdl.org/SDL3/SDL_CreateGPUShader#remarks which talks about the order that SDL requires the shader bindings to be, depending on the backend. Ideally we would have something like this for MSL as its not always straight forward, but we can punt this to another PR if it would be too much hassle for now.

Also can you point me to where I could add a test for the metal passthrough if needed ?

Make sure you update the PR to latest, but then tests/tests/wgpu-gpu has gpu-enabled tests. You can write a test which is skipped on all platforms other that metal, such as

.skip(FailureCase::backend(Backends::all() - Backends::DX12)),
defines a test which operates on all platforms except DX12.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants