
Fix temp memory allocation issue in torch.topk operations #12810


Open · wants to merge 2 commits into main
4 changes: 2 additions & 2 deletions extension/pybindings/pybindings.cpp
@@ -358,7 +358,7 @@ class Module final {

MallocMemoryAllocator runtime_allocator_;

-  MemoryAllocator temp_allocator_{MemoryAllocator(0, nullptr)};
+  MallocMemoryAllocator temp_allocator_{};

std::vector<std::vector<uint8_t>> non_const_buffers_;

@@ -1061,7 +1061,7 @@ class ProgramMemory {

MallocMemoryAllocator runtime_allocator_;

-  MemoryAllocator temp_allocator_{MemoryAllocator(0, nullptr)};
+  MallocMemoryAllocator temp_allocator_{};

std::vector<std::vector<uint8_t>> non_const_buffers_;

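Why this resolves the failure: each loaded method gets a temp allocator that operators such as torch.topk use for scratch buffers during execution. The old `MemoryAllocator(0, nullptr)` is a zero-capacity pool with no backing buffer, so every temporary request fails and execution aborts with "Memory allocation failed"; the malloc-backed allocator serves those requests from the heap on demand. A minimal sketch of the difference follows (not part of this PR; the header paths, namespaces, and the assumption that allocate() returns nullptr on failure are mine, not confirmed by the diff):

// Sketch only: contrast the old zero-capacity pool with the malloc-backed allocator.
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>  // assumed path
#include <executorch/runtime/core/memory_allocator.h>                       // assumed path

using executorch::extension::MallocMemoryAllocator;  // assumed namespace
using executorch::runtime::MemoryAllocator;          // assumed namespace

int main() {
  // Old setup: a 0-byte pool with no backing buffer. A scratch request like the
  // one topk issues at run time is assumed to come back as nullptr.
  MemoryAllocator zero_pool(0, nullptr);
  void* a = zero_pool.allocate(1024);

  // New setup: default-constructed malloc-backed allocator (as in this PR);
  // the same request is served from the heap.
  MallocMemoryAllocator heap_pool;
  void* b = heap_pool.allocate(1024);

  return (a == nullptr && b != nullptr) ? 0 : 1;
}

Both `Module` and `ProgramMemory` hold their own `temp_allocator_`, which is why the same one-line change appears twice in the diff.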
219 changes: 219 additions & 0 deletions test/end2end/test_temp_allocator_fix.py
@@ -0,0 +1,219 @@
#!/usr/bin/env python3
"""
Test to verify the fix for temp memory allocation issue in torch.topk operations.

This test specifically checks that the MallocMemoryAllocator fix in pybindings.cpp
resolves the "Memory allocation failed" error when executing operations that
require temporary memory allocation.
"""

import torch
import tempfile
import os
from pathlib import Path
from torch.export import export
from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower, EdgeCompileConfig
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.runtime import Verification, Runtime, Program, Method


class TopKModel(torch.nn.Module):
"""Model that uses torch.topk operation which requires temp memory allocation."""

def __init__(self, k=3) -> None:
super().__init__()
self.k = k

def forward(self, x) -> torch.Tensor:
# This operation requires temporary memory allocation
top_values, top_indices = torch.topk(x, self.k)
return top_values, top_indices


class TopKModelWithOut(torch.nn.Module):
"""Model that uses torch.topk with out parameter which also requires temp memory."""

def __init__(self, k=3) -> None:
super().__init__()
self.k = k

def forward(self, x) -> torch.Tensor:
top_values = torch.ones(x.shape[0], self.k, dtype=torch.float32)
top_indices = torch.ones(x.shape[0], self.k, dtype=torch.long)
torch.topk(x.contiguous(), self.k, out=(top_values, top_indices))
return top_values, top_indices


def test_topk_without_out_parameter():
"""Test torch.topk without out parameter."""
print("Testing torch.topk without out parameter...")

model = TopKModel(k=5)
example_input = (torch.randn(3, 100),)

# Export and compile the model
with torch.no_grad():
aten_dialect = export(model, example_input)

backend_dialect = to_edge_transform_and_lower(
aten_dialect,
compile_config=EdgeCompileConfig(_check_ir_validity=False),
partitioner=[XnnpackPartitioner()],
)

executorch_dialect = backend_dialect.to_executorch()

# Save to temporary file
with tempfile.NamedTemporaryFile(suffix='.pte', delete=False) as f:
temp_path = f.name

try:
executorch_dialect.save(temp_path)

# Load and execute with ExecuTorch runtime
et_runtime = Runtime.get()
program = et_runtime.load_program(
Path(temp_path),
verification=Verification.Minimal,
)

forward = program.load_method("forward")
outputs = forward.execute(example_input)

print(f"✓ Successfully executed topk model: {example_input[0].shape} -> {outputs[0].shape}")
return True

finally:
# Clean up temporary file
if os.path.exists(temp_path):
os.unlink(temp_path)


def test_topk_with_out_parameter():
"""Test torch.topk with out parameter (original failing case)."""
print("Testing torch.topk with out parameter...")

model = TopKModelWithOut(k=3)
example_input = (torch.randn(2, 256),)

# Export and compile the model
with torch.no_grad():
aten_dialect = export(model, example_input)

backend_dialect = to_edge_transform_and_lower(
aten_dialect,
compile_config=EdgeCompileConfig(_check_ir_validity=False),
partitioner=[XnnpackPartitioner()],
)

executorch_dialect = backend_dialect.to_executorch()

# Save to temporary file
with tempfile.NamedTemporaryFile(suffix='.pte', delete=False) as f:
temp_path = f.name

try:
executorch_dialect.save(temp_path)

# Load and execute with ExecuTorch runtime
et_runtime = Runtime.get()
program = et_runtime.load_program(
Path(temp_path),
verification=Verification.Minimal,
)

forward = program.load_method("forward")
outputs = forward.execute(example_input)

print(f"✓ Successfully executed topk model with out parameter: {example_input[0].shape} -> {outputs[0].shape}")
return True

finally:
# Clean up temporary file
if os.path.exists(temp_path):
os.unlink(temp_path)


def test_larger_topk_operation():
"""Test larger topk operation that would require more temporary memory."""
print("Testing larger topk operation...")

model = TopKModel(k=50)
example_input = (torch.randn(5, 1000),)

# Export and compile the model
with torch.no_grad():
aten_dialect = export(model, example_input)

backend_dialect = to_edge_transform_and_lower(
aten_dialect,
compile_config=EdgeCompileConfig(_check_ir_validity=False),
partitioner=[XnnpackPartitioner()],
)

executorch_dialect = backend_dialect.to_executorch()

# Save to temporary file
with tempfile.NamedTemporaryFile(suffix='.pte', delete=False) as f:
temp_path = f.name

try:
executorch_dialect.save(temp_path)

# Load and execute with ExecuTorch runtime
et_runtime = Runtime.get()
program = et_runtime.load_program(
Path(temp_path),
verification=Verification.Minimal,
)

forward = program.load_method("forward")
outputs = forward.execute(example_input)

print(f"✓ Successfully executed large topk model: {example_input[0].shape} -> {outputs[0].shape}")
return True

finally:
# Clean up temporary file
if os.path.exists(temp_path):
os.unlink(temp_path)


def main():
"""Run all tests to verify the temp memory allocation fix."""
print("Testing temp memory allocation fix for torch.topk operations")
print("=" * 60)

tests = [
test_topk_without_out_parameter,
test_topk_with_out_parameter,
test_larger_topk_operation,
]

passed = 0
failed = 0

for test in tests:
try:
if test():
passed += 1
else:
failed += 1
except Exception as e:
print(f"✗ Test {test.__name__} failed with exception: {e}")
failed += 1

print("\n" + "=" * 60)
print(f"Test Results: {passed} passed, {failed} failed")

if failed == 0:
print("✓ All tests passed! The temp memory allocation fix is working correctly.")
return True
else:
print("✗ Some tests failed. The fix may not be working correctly.")
return False


if __name__ == "__main__":
    success = main()
    exit(0 if success else 1)