unsupportedOperator: "Equal" (and "if"?) #19

Open
ronyfadel opened this issue Mar 13, 2024 · 5 comments

@ronyfadel

I'm trying to run inference on Silero VAD (https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx) using MPSX, but it's failing with:

result: failure(MPSX.OnnxError.unsupportedOperator("Equal"))
@geor-kasapidi
Contributor

Please take a look at draft #14 and use it as a reference. Maybe you could help out and open a PR to MPSX :)

@ronyfadel
Author

ronyfadel commented Mar 13, 2024

Thanks for the guidance, @geor-kasapidi! I'm trying to implement If, but tbh this is all too new to me:

import MetalPerformanceShadersGraph

extension MPSGraph {
  /// https://github.com/onnx/onnx/blob/main/docs/Operators.md#If
  func `if`(
    _ node: Onnx_NodeProto,
    _ tensors: [String: MPSGraphTensor]
  ) throws -> MPSGraphTensor {
    guard let cond = tensors(node.input(0)),
          let else_branch = node.attr("else_branch"),
          let then_branch = node.attr("then_branch")
    else { throw OnnxError.invalidInput(node.name) }

    var error: Error?

    self.if(cond, then: {
      // do something with then_branch
    }, else: {
      // do something with else_branch
    }, name: nil)

    if let error {
      throw error
    } else {
      // return something
    }
  }
}

@geor-kasapidi your input is appreciated!

@geor-kasapidi
Contributor

Well, I took a look at the If operator, and it's a tricky one. If I understand the ONNX spec correctly, both branches (then and else) require creating a subgraph. This is not a one-line implementation; it needs new logic for recursive graph creation. But you can do this: for each branch, call the onnx function in MPSX with a local tensor table, and return the output tensors of those onnx calls from the closures of the MPSGraph if method:

self.onnx(node: <Onnx_NodeProto>,
          optimizedForMPS: <Bool>,
          tensorsDataType: <MPSDataType>,
          tensors: &<[String: MPSGraphTensor]>,
          constants: &<[String: Onnx_TensorProto]>)
self.if(<predicateTensor: MPSGraphTensor>, then: {
    // onnx result for true
}, else: {
    // onnx result for false
}, name: nil)

Try this yourself, and I'll try to help more if you get stuck :)
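
For illustration only, here is a rough sketch of the approach described above, written against MetalPerformanceShadersGraph alone so that it does not depend on MPSX internals. The names `BranchBuilder` and `onnxIf` are hypothetical and not part of MPSX or MPSGraph. Each branch closure stands in for "loop over the nodes of `then_branch` / `else_branch` and call the MPSX onnx function for each one". The sketch also assumes that MPSGraph invokes the branch blocks synchronously while the graph is being built, which is what allows an error to be captured and rethrown afterwards.

```swift
import MetalPerformanceShadersGraph

/// Hypothetical: builds one ONNX subgraph into the graph and returns its outputs.
/// It receives a local tensor table that it is free to extend.
typealias BranchBuilder = (inout [String: MPSGraphTensor]) throws -> [MPSGraphTensor]

// Availability values are an assumption; adjust them to your deployment target.
@available(macOS 12.0, iOS 15.0, tvOS 15.0, *)
extension MPSGraph {
    /// Hypothetical helper, not part of MPSX.
    func onnxIf(
        predicate: MPSGraphTensor,
        outerTensors: [String: MPSGraphTensor],
        thenBranch: @escaping BranchBuilder,
        elseBranch: @escaping BranchBuilder
    ) throws -> [MPSGraphTensor] {
        // The MPSGraph branch blocks cannot throw, so remember the first failure.
        var buildError: Error?

        func build(_ branch: BranchBuilder) -> [MPSGraphTensor] {
            // Local copy: tensors defined inside a branch never leak into
            // the outer table, while outer tensors stay visible to the branch.
            var localTensors = outerTensors
            do {
                return try branch(&localTensors)
            } catch {
                buildError = error
                // Placeholder result; MPSGraph may reject an empty branch.
                return []
            }
        }

        let outputs = self.if(
            predicate,
            then: { build(thenBranch) },
            else: { build(elseBranch) },
            name: nil
        )

        if let buildError { throw buildError }
        return outputs
    }
}
```

Note that an ONNX If node can have several outputs, so a real implementation would probably need to return (or register) more than the single MPSGraphTensor that the draft extension above returns.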

@geor-kasapidi
Contributor

@ronyfadel any updates here?

@ronyfadel
Author

Hey @geor-kasapidi, I've paused working on this for a little while because of competing priorities.
