@rdspring1
Collaborator

torch::jit::toIValue converts a py::handle to an IValue, and KernelArgumentHolder then converts that IValue to a PolymorphicValue. The PyTorch function handles many types beyond those NvFuser supports, which adds unnecessary latency. toPolymorphicValue converts a py::handle directly to a PolymorphicValue, skipping the intermediate IValue.

  • This is a prerequisite for using nanobind.

tests/python/direct/test_python_frontend.py

function                time
toPolymorphicValue      41.04s
torch::jit::toIValue    41.33s

All Python tests: 7% improvement

function                time
toPolymorphicValue      317.45s
torch::jit::toIValue    340.92s
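
The 7% figure can be reproduced from the two totals above. The following is a minimal sketch of that arithmetic, assuming the improvement is measured relative to the torch::jit::toIValue baseline (the PR does not state the formula, and `relative_improvement` is a hypothetical helper name):

```python
def relative_improvement(baseline_s: float, new_s: float) -> float:
    """Fraction of the baseline wall-clock time saved by the new code path."""
    return (baseline_s - new_s) / baseline_s


# Timings reported in the tables above, in seconds.
single_file = relative_improvement(baseline_s=41.33, new_s=41.04)
all_tests = relative_improvement(baseline_s=340.92, new_s=317.45)

print(f"test_python_frontend.py: {single_file:.1%}")  # about 0.7%
print(f"all Python tests:        {all_tests:.1%}")    # about 6.9%, i.e. roughly 7%
```

Under this assumption the single test file improves by well under 1%, while the full suite accounts for the reported 7%.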

@rdspring1 requested a review from jjsjann123 on January 7, 2026 17:47
@rdspring1 added the "Direct Bindings" label (Python extension with direct mapping to NvFuser CPP objects) on Jan 7, 2026
@rdspring1
Collaborator Author

!test

@github-actions

github-actions bot commented Jan 7, 2026

Description

  • Added toPolymorphicValue helper function for direct conversion from py::handle to PolymorphicValue

  • Replaced torch::jit::toIValue calls with direct toPolymorphicValue in from_pyiterable

  • Eliminates intermediate IValue conversion step for improved performance

  • Achieves 7% performance improvement in all python tests (317.45s vs 340.92s)

Changes walkthrough

Relevant files
Enhancement
direct_utils.cpp
Add direct py::handle to PolymorphicValue conversion         

python/python_direct/direct_utils.cpp

  • Added toPolymorphicValue helper function that converts py::handle
    directly to PolymorphicValue
  • Supports Tensor, bool, int64, double, and complex types
  • Updated from_pyiterable to use toPolymorphicValue instead of
    torch::jit::toIValue
  • Eliminates unnecessary intermediate IValue conversion for better
    performance
  • +22/-2   

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Type Coverage Gap

    The new toPolymorphicValue function only handles 5 specific types (Tensor, bool, int64_t, double, complex), while torch::jit::toIValue may support additional types. This could break existing functionality if users pass unsupported types that were previously handled by the PyTorch function.

    PolymorphicValue toPolymorphicValue(const py::handle& obj) {
      static py::object torch_Tensor = py::module_::import("torch").attr("Tensor");
      if (py::isinstance(obj, torch_Tensor)) {
        return PolymorphicValue(py::cast<at::Tensor>(obj));
      } else if (py::isinstance<py::bool_>(obj)) {
        return PolymorphicValue(py::cast<bool>(obj));
      } else if (py::isinstance<py::int_>(obj)) {
        return PolymorphicValue(py::cast<int64_t>(obj));
      } else if (py::isinstance<py::float_>(obj)) {
        return PolymorphicValue(py::cast<double>(obj));
      } else if (py::isinstance<std::complex<double>>(obj)) {
        return PolymorphicValue(py::cast<std::complex<double>>(obj));
      }
      NVF_THROW("Cannot convert provided py::handle to a PolymorphicValue.");
    }

    Error Handling Consistency

    The new function raises an error via NVF_THROW for unsupported types, whereas the original torch::jit::toIValue may fail differently. Verify that the error messages and handling remain consistent with existing expectations.

    NVF_THROW("Cannot convert provided py::handle to a PolymorphicValue.");
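
To make the coverage gap concrete, the sketch below models the dispatch in pure Python and lists which sample inputs would be rejected. It is only an illustration under stated assumptions: `to_polymorphic_value` and `rejected_inputs` are hypothetical stand-ins for the C++ helper, the Tensor branch is omitted to avoid a torch dependency, and `TypeError` is an assumed substitute for whatever exception NVF_THROW surfaces in Python.

```python
def to_polymorphic_value(obj):
    """Pure-Python stand-in for the C++ dispatch; returns a (tag, value) pair."""
    if isinstance(obj, bool):      # checked before int, as in the C++ code
        return ("bool", obj)
    if isinstance(obj, int):
        return ("int64", obj)
    if isinstance(obj, float):
        return ("double", obj)
    if isinstance(obj, complex):
        return ("complex", obj)
    # Assumption: TypeError stands in for the error raised by NVF_THROW.
    raise TypeError("Cannot convert provided object to a PolymorphicValue.")


def rejected_inputs(candidates):
    """Return the type names of the candidates the dispatch refuses."""
    rejected = []
    for candidate in candidates:
        try:
            to_polymorphic_value(candidate)
        except TypeError:
            rejected.append(type(candidate).__name__)
    return rejected


print(rejected_inputs([1, 2.5, 1j, True, "text", None, [1, 2], (3, 4)]))
# prints ['str', 'NoneType', 'list', 'tuple']
```

Strings, None, lists, and tuples are examples of inputs that a broader converter might accept but this dispatch rejects. One point worth verifying in the C++ version: `py::isinstance<T>` with a non-wrapper `T` such as `std::complex<double>` is, as far as I know, resolved against registered C++ types, so the complex branch may need a check like `PyComplex_Check(obj.ptr())` instead. That is an assumption about pybind11 behavior, not something the PR states.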

    Test failures

    • (High, 1) NVFuser multi-device CUDA system-not-ready error in LowerCollectiveCudaAndNcclTest

      Test Name GB200 (dist.) Source
      LowerCollectiveCudaAndNcclTest.Broadcast/memcpy_2MB Link
    • (Medium, 3) NVFuser internal assert on mismatched-input-type check (test_python_direct.test_mismatched_input_types)

      Test Name A100 GB200 H100 Source
      tests.python.direct.test_python_direct.test_mismatched_input_types

    @greptile-apps
    Contributor

    greptile-apps bot commented Jan 7, 2026

    Greptile Summary

    Replaced torch::jit::toIValue with a custom toPolymorphicValue function in from_pyiterable to convert Python objects directly to PolymorphicValue, eliminating the intermediate IValue conversion step.

    • Added new toPolymorphicValue helper function that handles: torch.Tensor, bool, int, float, and complex types
    • Correctly checks bool before int to handle Python's type hierarchy where bool is a subclass of int
    • Uses static initialization for torch_Tensor type object for efficient repeated checks
    • Performance improvements: 7% faster on all Python tests (317.45s vs 340.92s)
    • This change is isolated to the direct Python API (python_direct) and sets the foundation for future nanobind migration
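
The bool-before-int ordering noted in the list above matters because Python's `bool` is a subclass of `int`. A minimal sketch of the difference, using two hypothetical classifier functions that are not part of the PR:

```python
def classify_bool_first(obj):
    """Mirrors the order used in the PR: bool is tested before int."""
    if isinstance(obj, bool):
        return "bool"
    if isinstance(obj, int):
        return "int"
    return "other"


def classify_int_first(obj):
    """Wrong order: the int test also matches bool, so the bool branch is dead."""
    if isinstance(obj, int):
        return "int"
    if isinstance(obj, bool):  # never reached for True/False
        return "bool"
    return "other"


print(issubclass(bool, int))      # True
print(classify_bool_first(True))  # bool
print(classify_int_first(True))   # int
```

With the int check first, `True` would be stored as an integer rather than a boolean, which is why the order of the branches is part of the function's correctness and not just style.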

    Confidence Score: 5/5

    • This PR is safe to merge with minimal risk
    • The implementation correctly handles all supported types (Tensor, bool, int64_t, double, complex), properly checks bool before int to handle Python's type hierarchy, and maintains functional equivalence with the previous torch::jit::toIValue approach while improving performance by 7%. The change is well-tested and isolated to a single file.
    • No files require special attention

    Important Files Changed

    Filename: python/python_direct/direct_utils.cpp
    Overview: Adds toPolymorphicValue function to convert py::handle directly to PolymorphicValue, bypassing the intermediate IValue conversion for improved performance

    Sequence Diagram

    sequenceDiagram
        participant Python as Python Code
        participant Direct as from_pyiterable
        participant New as toPolymorphicValue
        participant KAH as KernelArgumentHolder
        
        Note over Python,KAH: Before (using torch::jit::toIValue)
        Python->>Direct: py::iterable with py::handle objects
        Direct->>Direct: torch::jit::toIValue(obj)
        Direct->>Direct: IValue returned
        Direct->>KAH: args.push(IValue)
        KAH->>KAH: IValueToPolymorphicValue
        KAH->>KAH: Store PolymorphicValue
        
        Note over Python,KAH: After (direct conversion)
        Python->>Direct: py::iterable with py::handle objects
        Direct->>New: toPolymorphicValue(obj)
        New->>New: Check type (Tensor/bool/int/float/complex)
        New->>New: Cast to appropriate C++ type
        New->>Direct: Return PolymorphicValue
        Direct->>KAH: args.push(PolymorphicValue)
        KAH->>KAH: Store PolymorphicValue
    