Matmul after reshape crashes on Metal #2737

Open
laptou opened this issue Jan 24, 2025 · 1 comment

Comments

@laptou
Contributor

laptou commented Jan 24, 2025

Doing a matrix multiplication after `transpose()` or `permute()` calls crashes on the Metal backend, while the same code works on the CPU backend.
Here's a reproduction of the bug:

```rust
use candle_core::{DType, Device, Tensor};

pub fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::zeros(&[1, 8, 52, 1250], DType::F32, &dev)?;
    let y = Tensor::zeros(&[1, 8, 1250, 52], DType::F32, &dev)?;
    let z = x.matmul(&y)?; // works fine
    dbg!(z.shape());

    let x = Tensor::zeros(&[1, 8, 1250, 52], DType::F32, &dev)?;
    let y = Tensor::zeros(&[1, 8, 52, 1250], DType::F32, &dev)?;
    let z = x.matmul(&y)?; // works fine
    dbg!(z.shape());

    let x = Tensor::zeros(&[1, 1250, 416], DType::F32, &dev)?;
    let y = Tensor::zeros(&[1, 1250, 416], DType::F32, &dev)?;

    let x = x.reshape((1, 1250, 8, 52))?.transpose(1, 2)?;
    let y = y
        .reshape((1, 1250, 8, 52))?
        .transpose(1, 2)?
        .transpose(2, 3)?;

    let z = x.matmul(&y)?; // works fine
    dbg!(z.shape());

    let dev = Device::new_metal(0)?;
    let x = Tensor::zeros(&[1, 8, 52, 1250], DType::F32, &dev)?;
    let y = Tensor::zeros(&[1, 8, 1250, 52], DType::F32, &dev)?;
    let z = x.matmul(&y)?; // works fine
    dbg!(z.shape());

    let x = Tensor::zeros(&[1, 8, 1250, 52], DType::F32, &dev)?;
    let y = Tensor::zeros(&[1, 8, 52, 1250], DType::F32, &dev)?;
    let z = x.matmul(&y)?; // works fine
    dbg!(z.shape());

    let x = Tensor::zeros(&[1, 1250, 416], DType::F32, &dev)?;
    let y = Tensor::zeros(&[1, 1250, 416], DType::F32, &dev)?;

    let x = x.reshape((1, 1250, 8, 52))?.transpose(1, 2)?;
    let y = y
        .reshape((1, 1250, 8, 52))?
        .transpose(1, 2)?
        .transpose(2, 3)?;
    // also crashes if you do permute(0, 2, 3, 1)
    // also crashes if you create the array with shape (1, 1250, 8, 52) without reshaping

    let z = x.matmul(&y)?; // crashes
    dbg!(z.shape());

    Ok(())
}
```
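The Metal case that fails multiplies tensors of shapes (1, 8, 1250, 52) and (1, 8, 52, 1250), the same shapes that work when created directly, so the difference is likely in the memory layout rather than the shapes. The sketch below uses plain std Rust and a hypothetical `Layout` type (not candle's internals) to trace how `reshape`, `transpose`, and `permute` only rewrite shape and stride metadata, assuming a row-major stride convention, and checks whether each result is still contiguous.

```rust
/// Hypothetical layout type for illustration (not candle's `Layout`):
/// a shape plus per-dimension element strides.
#[derive(Debug, Clone, PartialEq)]
struct Layout {
    dims: Vec<usize>,
    strides: Vec<usize>,
}

impl Layout {
    /// Row-major (C-order) layout: the last dim has stride 1.
    fn row_major(dims: &[usize]) -> Self {
        let mut strides = vec![1; dims.len()];
        for i in (0..dims.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * dims[i + 1];
        }
        Layout { dims: dims.to_vec(), strides }
    }

    /// True if elements are laid out in row-major order with no gaps.
    /// Size-1 dims are skipped because their stride never affects addressing.
    fn is_contiguous(&self) -> bool {
        let mut expected = 1;
        for (&d, &s) in self.dims.iter().zip(&self.strides).rev() {
            if d > 1 && s != expected {
                return false;
            }
            expected *= d;
        }
        true
    }

    /// On a contiguous layout, reshape only rewrites metadata.
    fn reshape(&self, dims: &[usize]) -> Option<Self> {
        let same_len = self.dims.iter().product::<usize>() == dims.iter().product::<usize>();
        (self.is_contiguous() && same_len).then(|| Layout::row_major(dims))
    }

    /// Transpose swaps two dims together with their strides; no data moves.
    fn transpose(&self, a: usize, b: usize) -> Self {
        let mut out = self.clone();
        out.dims.swap(a, b);
        out.strides.swap(a, b);
        out
    }

    /// Permute: output dim i is input dim perm[i].
    fn permute(&self, perm: &[usize]) -> Self {
        Layout {
            dims: perm.iter().map(|&p| self.dims[p]).collect(),
            strides: perm.iter().map(|&p| self.strides[p]).collect(),
        }
    }
}

fn main() {
    // x = zeros(1, 1250, 416).reshape((1, 1250, 8, 52)).transpose(1, 2)
    let x = Layout::row_major(&[1, 1250, 416])
        .reshape(&[1, 1250, 8, 52])
        .expect("reshape of a contiguous layout")
        .transpose(1, 2);
    // y = same reshape, then transpose(1, 2).transpose(2, 3)
    let y = Layout::row_major(&[1, 1250, 416])
        .reshape(&[1, 1250, 8, 52])
        .expect("reshape of a contiguous layout")
        .transpose(1, 2)
        .transpose(2, 3);
    // permute(0, 2, 3, 1) on a fresh (1, 1250, 8, 52) layout ends up identical to y.
    let y_perm = Layout::row_major(&[1, 1250, 8, 52]).permute(&[0, 2, 3, 1]);

    for (name, l) in [("x", &x), ("y", &y), ("y_perm", &y_perm)] {
        println!("{name}: dims {:?} strides {:?} contiguous = {}", l.dims, l.strides, l.is_contiguous());
    }
    assert_eq!(y, y_perm);
}
```

With these inputs `x` ends up with strides `[520000, 52, 416, 1]` and `y` (and `y_perm`) with `[520000, 52, 1, 416]`. In both, the 1250-long dimension steps by 416 elements rather than the 52 a packed matrix would use, because the 8 heads of 52 values each stay interleaved inside the original 416-wide rows. That gap is a plausible trigger for the non-contiguous matmul error, and it also explains why the `permute` and "no reshape" variants fail the same way.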
@laptou
Contributor Author

laptou commented Jan 24, 2025

The error message reads "Invalid matmul arguments", which is not very clear. In the code, the error is called `MatMulNonContiguous`. The error goes away if I call `.contiguous()` on `x` and `y` before calling `matmul`. Can I send a PR to make this error message clearer?
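The workaround in the comment above, calling `.contiguous()` before `matmul`, amounts to copying a strided view into a fresh row-major buffer. Here is a minimal std-only Rust sketch of that idea using a hypothetical `StridedTensor` type and a deliberately strict 2-D `matmul` that rejects any non-contiguous input with a `NonContiguous` error carrying the offending strides. This is an assumption-laden illustration, not candle's implementation: real GEMM kernels can often handle some strided layouts (such as a simple transpose) through transpose flags, which this sketch does not model.

```rust
/// Hypothetical strided tensor for illustration only (not candle's `Tensor`).
#[derive(Debug, Clone)]
struct StridedTensor {
    data: Vec<f32>,
    dims: Vec<usize>,
    strides: Vec<usize>,
}

/// Row-major strides for `dims` (last dim has stride 1).
fn row_major(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; dims.len()];
    for i in (0..dims.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * dims[i + 1];
    }
    strides
}

impl StridedTensor {
    fn new(data: Vec<f32>, dims: &[usize]) -> Self {
        assert_eq!(data.len(), dims.iter().product::<usize>(), "data length must match dims");
        StridedTensor { data, dims: dims.to_vec(), strides: row_major(dims) }
    }

    /// Swap two dims and their strides; the buffer keeps its original order.
    fn transpose(&self, a: usize, b: usize) -> Self {
        let mut out = self.clone();
        out.dims.swap(a, b);
        out.strides.swap(a, b);
        out
    }

    /// Row-major with no gaps (size-1 dims are ignored).
    fn is_contiguous(&self) -> bool {
        let mut expected = 1;
        for (&d, &s) in self.dims.iter().zip(&self.strides).rev() {
            if d > 1 && s != expected {
                return false;
            }
            expected *= d;
        }
        true
    }

    /// Copy the elements in logical (row-major) order into a fresh buffer.
    fn contiguous(&self) -> Self {
        let n: usize = self.dims.iter().product();
        let mut out = Vec::with_capacity(n);
        let mut idx = vec![0usize; self.dims.len()];
        for _ in 0..n {
            let offset: usize = idx.iter().zip(&self.strides).map(|(i, s)| i * s).sum();
            out.push(self.data[offset]);
            // Advance the multi-index like an odometer, last dim fastest.
            for d in (0..idx.len()).rev() {
                idx[d] += 1;
                if idx[d] < self.dims[d] {
                    break;
                }
                idx[d] = 0;
            }
        }
        StridedTensor::new(out, &self.dims)
    }
}

#[derive(Debug, PartialEq)]
enum MatMulError {
    /// Loosely mirrors candle's `MatMulNonContiguous`, with the strides attached.
    NonContiguous { lhs_strides: Vec<usize>, rhs_strides: Vec<usize> },
    ShapeMismatch,
}

/// Deliberately strict 2-D matmul: both inputs must be row-major contiguous.
fn matmul(lhs: &StridedTensor, rhs: &StridedTensor) -> Result<StridedTensor, MatMulError> {
    if lhs.dims.len() != 2 || rhs.dims.len() != 2 || lhs.dims[1] != rhs.dims[0] {
        return Err(MatMulError::ShapeMismatch);
    }
    if !lhs.is_contiguous() || !rhs.is_contiguous() {
        return Err(MatMulError::NonContiguous {
            lhs_strides: lhs.strides.clone(),
            rhs_strides: rhs.strides.clone(),
        });
    }
    let (m, k, n) = (lhs.dims[0], lhs.dims[1], rhs.dims[1]);
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                out[i * n + j] += lhs.data[i * k + p] * rhs.data[p * n + j];
            }
        }
    }
    Ok(StridedTensor::new(out, &[m, n]))
}

fn main() {
    let a = StridedTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
    let b = StridedTensor::new(vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0], &[2, 3]);
    let bt = b.transpose(0, 1); // 3x2 view with strides [1, 3]
    if let Err(e) = matmul(&a, &bt) {
        println!("without contiguous(): {e:?}");
    }
    let c = matmul(&a, &bt.contiguous()).expect("both inputs are contiguous");
    println!("with contiguous(): dims {:?} data {:?}", c.dims, c.data);
}
```

With these inputs the first call returns `NonContiguous` because `bt` has strides `[1, 3]`, and the second returns the 2x2 product `[4.0, 2.0, 10.0, 5.0]`. Attaching the offending strides to the error is one possible shape for the clearer message proposed above; it is an assumption, not what candle currently reports.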
