Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge reduce #2673

Merged
merged 14 commits (source and target branch names were lost during page extraction)
Jan 13, 2025
480 changes: 248 additions & 232 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
# NOTE: `cubecl` and `cubecl-common` must be pinned to the same revision.
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "707093234f11b78fb6630b98fea5d13870f94282" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "707093234f11b78fb6630b98fea5d13870f94282" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features =
"cubecl",
"repr",
] }
# "reduce" enables cubecl's built-in reduction kernels used by burn-jit.
cubecl = { workspace = true, features = ["linalg", "reduce"] }

bytemuck = { workspace = true }
derive-new = { workspace = true }
Expand Down
142 changes: 80 additions & 62 deletions crates/burn-jit/src/kernel/reduce/base.rs
Original file line number Diff line number Diff line change
@@ -1,83 +1,101 @@
use cubecl::prelude::Numeric;

#[cfg(feature = "autotune")]
use crate::kernel::reduce::reduce_dim_autotune;
use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime};

use super::{
naive::{base::ReduceDimNaiveFamily, kernel::reduce_dim_naive},
shared::{base::ReduceDimShared, kernel::reduce_dim_shared},
subcube::{base::ReduceDimSubcube, kernel::reduce_dim_subcube},
};
use super::autotune_reduce;

pub use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum};

#[allow(dead_code)]
pub(crate) trait ReduceDimAlgorithm<EI: Numeric, EO: Numeric>:
core::fmt::Debug + ReduceDimNaiveFamily + ReduceDimShared<EI, EO> + ReduceDimSubcube<EI, EO>
{
/// Reduce the entire `input` tensor to a single value using the instruction `Rd` and the
/// given [Strategy](ReduceStrategy).
///
/// Returns an error if `strategy` is `Specific(strategy)` and the specified strategy is
/// not supported by the `client`, or if the output shape is invalid.
///
/// On success, the output is a tensor with decreasing strides where the shape of the
/// reduced dim is set to 1.
pub fn reduce<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::reduce::Reduce>(
    mut input: JitTensor<Run>,
    strategy: ReduceStrategy,
) -> Result<JitTensor<Run>, cubecl::reduce::ReduceError> {
    // View the tensor as a flat rank-1 buffer, then reduce its only axis (axis 0).
    let flat_shape = input.shape.flatten();
    input.shape = flat_shape;
    input.strides = vec![1];
    reduce_dim::<Run, In, Out, Rd>(input, 0, strategy)
}

/// Creates an empty output tensor with reduce output shape
pub fn init_reduce_output<R: JitRuntime, EI: JitElement, EO: JitElement>(
input: &JitTensor<R>,
reduce_dim: usize,
) -> JitTensor<R> {
let mut shape_out = input.shape.clone();
shape_out.dims[reduce_dim] = 1;
/// Reduce the given axis (`dim`) of the `input` tensor using the instruction `Rd` and the
/// given [Strategy](ReduceStrategy).
///
/// Returns an error if `strategy` is `Specific(strategy)` and the specified strategy is not
/// supported by the `client`. Also returns an error if `dim` is larger than the `input`
/// rank or if the shape of the output is invalid.
///
/// On success, the output is a tensor with decreasing strides where the shape of the
/// reduced dim is set to 1 while all other dims match the input.
pub fn reduce_dim<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::reduce::Reduce>(
    input: JitTensor<Run>,
    dim: usize,
    strategy: ReduceStrategy,
) -> Result<JitTensor<Run>, cubecl::reduce::ReduceError> {
    let client = input.client.clone();

    // Allocate the output buffer up front; `None` here means `dim` is out of range.
    let output = init_reduce_output::<Run, In, Out>(&input, dim).ok_or(
        cubecl::reduce::ReduceError::InvalidAxis {
            axis: dim,
            rank: input.shape.num_dims(),
        },
    )?;

    let result = match strategy {
        #[cfg(feature = "autotune")]
        ReduceStrategy::Autotune => {
            autotune_reduce::<Run, In, Out, Rd>(&client, input, output.clone(), dim)
        }
        ReduceStrategy::Unspecified => cubecl::reduce::reduce::<Run, In, Out, Rd>(
            &client,
            input.as_handle_ref(),
            output.as_handle_ref(),
            dim,
            None,
        ),
        ReduceStrategy::Specific(specific) => cubecl::reduce::reduce::<Run, In, Out, Rd>(
            &client,
            input.as_handle_ref(),
            output.as_handle_ref(),
            dim,
            Some(specific),
        ),
    };

    // The kernels write into `output` in place; surface it on success.
    result.map(|_| output)
}

empty_device::<R, EO>(input.client.clone(), input.device.clone(), shape_out)
/// Create an empty output tensor with the proper shape and decreasing strides for
/// reducing axis `dim` of `input`.
///
/// Returns `None` when `dim` is out-of-bound for the rank of `input`.
pub fn init_reduce_output<Run: JitRuntime, In: JitElement, Out: JitElement>(
    input: &JitTensor<Run>,
    dim: usize,
) -> Option<JitTensor<Run>> {
    if dim >= input.shape.num_dims() {
        return None;
    }
    // Same shape as the input, except the reduced axis collapses to 1.
    let mut shape_out = input.shape.clone();
    shape_out.dims[dim] = 1;
    Some(empty_device::<Run, Out>(
        input.client.clone(),
        input.device.clone(),
        shape_out,
    ))
}

/// Select a strategy to perform a reduction.
#[derive(Copy, Clone, Debug)]
#[allow(missing_docs)]
pub enum ReduceStrategy {
    /// Use the naive reduction kernel (`reduce_dim_naive`).
    Naive,
    /// Use shared memory as an accumulator (`reduce_dim_shared`).
    SharedMemory,
    /// Use subcube functions (`reduce_dim_subcube`).
    Subcube,
    /// Use a best-effort strategy based on the hardware capacity.
    /// This differs from Autotune as it doesn't try to compare many strategies to select the best.
    Unspecified,
    /// Fix the exact strategy for the reduction.
    Specific(cubecl::reduce::ReduceStrategy),
    /// Use autotune to find the best strategy given the hardware and the inputs.
    #[cfg(feature = "autotune")]
    Autotune,
}

impl Default for ReduceStrategy {
fn default() -> Self {
// if autotune is enabled, default to autotune
#[cfg(feature = "autotune")]
return ReduceStrategy::Autotune;
return Self::Autotune;

#[cfg(not(feature = "autotune"))]
ReduceStrategy::Naive
return Self::Unspecified;
}
}

/// Generate a zero-sized marker type `$ops` plus a public entry point `$name` that
/// dispatches a dimension-wise reduction to the kernel matching the chosen
/// [`ReduceStrategy`].
// NOTE(review): this match covers only Naive/SharedMemory/Subcube/Autotune — it appears
// to predate the Unspecified/Specific variants; confirm against the final merged file.
macro_rules! reduce_operation {
    ($name:ident, $ops:ident) => {
        // Marker type identifying this reduction instruction for the dispatch traits.
        #[derive(Debug)]
        pub(crate) struct $ops;

        impl<EI: Numeric, EO: Numeric> ReduceDimAlgorithm<EI, EO> for $ops {}

        /// Executes the reduce operation with the given strategy.
        pub fn $name<R: JitRuntime, EI: JitElement, EO: JitElement>(
            tensor: JitTensor<R>,
            dim: usize,
            strategy: ReduceStrategy,
        ) -> Result<JitTensor<R>, String> {
            match strategy {
                ReduceStrategy::Naive => reduce_dim_naive::<$ops, R, EI, EO>(tensor, dim),
                ReduceStrategy::SharedMemory => reduce_dim_shared::<$ops, R, EI, EO>(tensor, dim),
                ReduceStrategy::Subcube => reduce_dim_subcube::<$ops, R, EI, EO>(tensor, dim),
                // Autotune picks a winner itself, so it cannot fail with a strategy error.
                #[cfg(feature = "autotune")]
                ReduceStrategy::Autotune => Ok(reduce_dim_autotune::<$ops, R, EI, EO>(tensor, dim)),
            }
        }
    };
}

// Autotunable reduce operation variants: each invocation generates a marker struct
// and a public `*_dim`/`arg*` entry point dispatching on `ReduceStrategy`.
reduce_operation!(sum_dim, SumDim);
reduce_operation!(mean_dim, MeanDim);
reduce_operation!(prod_dim, ProdDim);
reduce_operation!(argmin, Argmin);
reduce_operation!(argmax, Argmax);
7 changes: 0 additions & 7 deletions crates/burn-jit/src/kernel/reduce/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
mod base;
mod naive;
mod prod;
mod shared;
mod subcube;
mod sum;
mod tune;

pub use base::*;
pub use prod::*;
pub use sum::*;
pub use tune::*;
36 changes: 0 additions & 36 deletions crates/burn-jit/src/kernel/reduce/naive/argmax.rs

This file was deleted.

36 changes: 0 additions & 36 deletions crates/burn-jit/src/kernel/reduce/naive/argmin.rs

This file was deleted.

25 changes: 0 additions & 25 deletions crates/burn-jit/src/kernel/reduce/naive/base.rs

This file was deleted.

71 changes: 0 additions & 71 deletions crates/burn-jit/src/kernel/reduce/naive/kernel.rs

This file was deleted.

27 changes: 0 additions & 27 deletions crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs

This file was deleted.

Loading
Loading