Skip to content

Commit b2182ee

Browse files
committed
add InstalledBackend struct
1 parent c8efe49 commit b2182ee

File tree

2 files changed

+44
-12
lines changed

2 files changed

+44
-12
lines changed

crates/cargo-gpu/src/build.rs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
#![allow(clippy::unwrap_used, reason = "this is basically a test")]
33
//! `cargo gpu build`, analogous to `cargo build`
44
5+
use crate::install::Install;
56
use crate::linkage::Linkage;
67
use crate::lockfile::LockfileMismatchHandler;
7-
use crate::{install::Install, target_spec_dir};
88
use anyhow::Context as _;
99
use spirv_builder::{CompileResult, ModuleResult, SpirvBuilder};
1010
use std::io::Write as _;
@@ -59,22 +59,17 @@ pub struct Build {
5959
impl Build {
6060
/// Entrypoint
6161
pub fn run(&mut self) -> anyhow::Result<()> {
62-
let (rustc_codegen_spirv_location, toolchain_channel) = self.install.run()?;
62+
let installed_backend = self.install.run()?;
6363

6464
let _lockfile_mismatch_handler = LockfileMismatchHandler::new(
6565
&self.install.shader_crate,
66-
&toolchain_channel,
66+
&installed_backend.toolchain_channel,
6767
self.install.force_overwrite_lockfiles_v4_to_v3,
6868
)?;
6969

7070
let builder = &mut self.build.spirv_builder;
71-
builder.rustc_codegen_spirv_location = Some(rustc_codegen_spirv_location);
72-
builder.toolchain_overwrite = Some(toolchain_channel);
7371
builder.path_to_crate = Some(self.install.shader_crate.clone());
74-
builder.path_to_target_spec = Some(target_spec_dir()?.join(format!(
75-
"{}.json",
76-
builder.target.as_ref().context("expect target to be set")?
77-
)));
72+
installed_backend.configure_spirv_builder(builder)?;
7873

7974
// Ensure the shader output dir exists
8075
log::debug!(

crates/cargo-gpu/src/install.rs

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::spirv_source::{
66
use crate::{cache_dir, spirv_source::SpirvSource, target_spec_dir};
77
use anyhow::Context as _;
88
use log::trace;
9-
use spirv_builder::TARGET_SPECS;
9+
use spirv_builder::{SpirvBuilder, TARGET_SPECS};
1010
use std::io::Write as _;
1111
use std::path::{Path, PathBuf};
1212

@@ -77,6 +77,40 @@ pub struct Install {
7777
pub force_overwrite_lockfiles_v4_to_v3: bool,
7878
}
7979

80+
/// Represents a functional backend installation, whether it was cached or just installed.
81+
#[derive(Clone, Debug)]
82+
pub struct InstalledBackend {
83+
/// path to the rustc_codegen_spirv dylib
84+
pub rustc_codegen_spirv_location: PathBuf,
85+
/// toolchain channel name
86+
pub toolchain_channel: String,
87+
}
88+
89+
impl InstalledBackend {
90+
/// Creates a new `SpirvBuilder` configured to use this installed backend.
91+
pub fn to_spirv_builder(
92+
&self,
93+
path_to_crate: impl AsRef<Path>,
94+
target: impl Into<String>,
95+
) -> SpirvBuilder {
96+
let mut builder = SpirvBuilder::new(path_to_crate, target);
97+
self.configure_spirv_builder(&mut builder)
98+
.expect("unreachable");
99+
builder
100+
}
101+
102+
/// Configures the supplied [`SpirvBuilder`]. `SpirvBuilder.target` must be set and must not change after calling this function.
103+
pub fn configure_spirv_builder(&self, builder: &mut SpirvBuilder) -> anyhow::Result<()> {
104+
builder.rustc_codegen_spirv_location = Some(self.rustc_codegen_spirv_location.clone());
105+
builder.toolchain_overwrite = Some(self.toolchain_channel.clone());
106+
builder.path_to_target_spec = Some(target_spec_dir()?.join(format!(
107+
"{}.json",
108+
builder.target.as_ref().context("expect target to be set")?
109+
)));
110+
Ok(())
111+
}
112+
}
113+
80114
impl Default for Install {
81115
#[inline]
82116
fn default() -> Self {
@@ -164,7 +198,7 @@ package = "rustc_codegen_spirv"
164198

165199
/// Install the binary pair and return the `(dylib_path, toolchain_channel)`.
166200
#[expect(clippy::too_many_lines, reason = "it's fine")]
167-
pub fn run(&mut self) -> anyhow::Result<(PathBuf, String)> {
201+
pub fn run(&mut self) -> anyhow::Result<InstalledBackend> {
168202
// Ensure the cache dir exists
169203
let cache_dir = cache_dir()?;
170204
log::info!("cache directory is '{}'", cache_dir.display());
@@ -282,6 +316,9 @@ package = "rustc_codegen_spirv"
282316
.context("writing target spec files")?;
283317
}
284318

285-
Ok((dest_dylib_path, toolchain_channel))
319+
Ok(InstalledBackend {
320+
rustc_codegen_spirv_location: dest_dylib_path,
321+
toolchain_channel,
322+
})
286323
}
287324
}

0 commit comments

Comments
 (0)