Doing a matrix multiplication after `transpose()` or `permute()` calls crashes on the Metal backend.
Here's a reproduction of the bug:
```rust
use candle_core::{DType, Device, Tensor};

pub fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;

    let x = Tensor::zeros(&[1, 8, 52, 1250], DType::F32, &dev)?;
    let y = Tensor::zeros(&[1, 8, 1250, 52], DType::F32, &dev)?;
    let z = x.matmul(&y)?; // works fine
    dbg!(z.shape());

    let x = Tensor::zeros(&[1, 8, 1250, 52], DType::F32, &dev)?;
    let y = Tensor::zeros(&[1, 8, 52, 1250], DType::F32, &dev)?;
    let z = x.matmul(&y)?; // works fine
    dbg!(z.shape());

    let x = Tensor::zeros(&[1, 1250, 416], DType::F32, &dev)?;
    let y = Tensor::zeros(&[1, 1250, 416], DType::F32, &dev)?;
    let x = x.reshape((1, 1250, 8, 52))?.transpose(1, 2)?;
    let y = y
        .reshape((1, 1250, 8, 52))?
        .transpose(1, 2)?
        .transpose(2, 3)?;
    let z = x.matmul(&y)?; // works fine
    dbg!(z.shape());

    let dev = Device::new_metal(0)?;

    let x = Tensor::zeros(&[1, 8, 52, 1250], DType::F32, &dev)?;
    let y = Tensor::zeros(&[1, 8, 1250, 52], DType::F32, &dev)?;
    let z = x.matmul(&y)?; // works fine
    dbg!(z.shape());

    let x = Tensor::zeros(&[1, 8, 1250, 52], DType::F32, &dev)?;
    let y = Tensor::zeros(&[1, 8, 52, 1250], DType::F32, &dev)?;
    let z = x.matmul(&y)?; // works fine
    dbg!(z.shape());

    let x = Tensor::zeros(&[1, 1250, 416], DType::F32, &dev)?;
    let y = Tensor::zeros(&[1, 1250, 416], DType::F32, &dev)?;
    let x = x.reshape((1, 1250, 8, 52))?.transpose(1, 2)?;
    let y = y
        .reshape((1, 1250, 8, 52))?
        .transpose(1, 2)?
        .transpose(2, 3)?;
    // also crashes if you do permute(0, 2, 3, 1)
    // also crashes if you create the array with shape (1, 1250, 8, 52) without reshaping
    let z = x.matmul(&y)?; // crashes
    dbg!(z.shape());

    Ok(())
}
```
The error message says "Invalid matmul arguments", which is not very clear; the underlying error in the code is `MatMulNonContiguous`. The error goes away if I call `.contiguous()` on `x` and `y` before calling `matmul`. Can I send a PR to make this error message clearer?
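To see why `.contiguous()` makes the error go away, it helps to look at the strides. The sketch below is plain std-only Rust, not candle code: `contiguous_strides`, `transpose` and `gemm_layout` are hypothetical helpers, and `gemm_layout` is a simplified model that assumes the Metal matmul kernel only accepts the last two dimensions of each operand as either a packed row-major matrix or a packed transposed one. Under that assumption, `x` from the repro (reshaped to `(1, 1250, 8, 52)` then `transpose(1, 2)`) has rows 416 elements apart instead of 52, so it is rejected, while a fresh row-major copy is accepted.

```rust
/// Row-major (C-order) strides for `shape`: the layout a fresh copy such as
/// candle's `.contiguous()` would produce.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

/// A strided transpose: swaps two dims in shape and strides, no data moves.
fn transpose(shape: &mut [usize], strides: &mut [usize], d1: usize, d2: usize) {
    shape.swap(d1, d2);
    strides.swap(d1, d2);
}

/// Hypothetical, simplified kernel check (an assumption, not candle's code):
/// the last two dims must form a packed row-major matrix (Some(false)) or a
/// packed transposed one (Some(true)); any other strides give None.
fn gemm_layout(shape: &[usize], strides: &[usize]) -> Option<bool> {
    let n = shape.len();
    assert!(n >= 2, "matmul operands need at least two dimensions");
    let (rows, cols) = (shape[n - 2], shape[n - 1]);
    let (row_stride, col_stride) = (strides[n - 2], strides[n - 1]);
    if col_stride == 1 && row_stride == cols {
        Some(false)
    } else if row_stride == 1 && col_stride == rows {
        Some(true)
    } else {
        None
    }
}

fn main() {
    // x = zeros((1, 1250, 416)).reshape((1, 1250, 8, 52)).transpose(1, 2)
    let mut shape: Vec<usize> = vec![1, 1250, 8, 52];
    let mut strides = contiguous_strides(&shape);
    transpose(&mut shape, &mut strides, 1, 2);
    // Shape is [1, 8, 1250, 52], but consecutive rows are 416 elements apart, not 52.
    println!("x view: strides {:?} -> {:?}", strides, gemm_layout(&shape, &strides));

    // Copying into a fresh row-major buffer fixes the row stride.
    let packed = contiguous_strides(&shape);
    println!("x copy: strides {:?} -> {:?}", packed, gemm_layout(&shape, &packed));

    // For contrast: transposing only the last two dims of a contiguous tensor
    // still yields a packed (transposed) matrix.
    let mut s: Vec<usize> = vec![1, 8, 1250, 52];
    let mut st = contiguous_strides(&s);
    transpose(&mut s, &mut st, 2, 3);
    println!("plain transpose: strides {:?} -> {:?}", st, gemm_layout(&s, &st));
}
```

In candle terms, the copy corresponds to `let x = x.contiguous()?;` (and likewise for `y`) before calling `x.matmul(&y)?`.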