🐛 Describe the bug
Hi,
I am trying to create a sample PyTorch model and convert it into an ExecuTorch program (*.pte), exporting it with torch.export.export().
My model takes two inputs: a tensor `x` whose dimensions can vary, and a scalar `scale` (a 0-dim tensor, so it has no dimensions). The model simply scales `x` by `scale`.
The script is as follows:
```python
import torch
import torch.nn as nn
from torch.export import export
from torch.export import Dim
from executorch.exir import to_edge_transform_and_lower

class MyModel(nn.Module):
    def forward(self, x: torch.Tensor, scale: torch.Tensor):
        # scale is a scalar tensor (int wrapped as tensor)
        return x * scale

model = MyModel()

# Define dynamic shapes
# x has shape [h, w]: h is a multiple of 16, w is a multiple of 4
example_x = torch.randn(4, 10)
example_scale = torch.tensor(3)
_w = Dim('w', min=125, max=4000)
w = _w * 4
_h = Dim('h', min=1, max=8)
h = _h * 16
dynamic_shapes = {
    "x": {
        0: h,
        1: w
    }
}

# Export the model with dynamic shapes
exported = export(
    model,
    (example_x, example_scale), dynamic_shapes = dynamic_shapes
)
executorch_program = to_edge_transform_and_lower(
    exported
)
executorch_program = executorch_program.to_executorch()
out_path = "dynamic.pte"
with open(out_path, "wb") as file:
    file.write(executorch_program.buffer)
```
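A side note on the shapes in this script: `h = _h * 16` with `_h` in [1, 8] only admits 16, 32, ..., 128, and `w = _w * 4` with `_w` in [125, 4000] only admits multiples of 4 from 500 to 16000. The example input `torch.randn(4, 10)` fits neither, so export would likely reject it on the constraints even with the `dynamic_shapes` keys corrected. A small pure-Python check (the helpers `derived_range` and `fits_derived` are hypothetical, written only for illustration):

```python
# Hypothetical helpers: reason about a derived dim of the form stride * root,
# where root is an integer in [lo, hi] (mirrors h = _h * 16 and w = _w * 4).
def derived_range(stride: int, lo: int, hi: int) -> tuple[int, int]:
    """Smallest and largest size that stride * root can take."""
    return stride * lo, stride * hi


def fits_derived(size: int, stride: int, lo: int, hi: int) -> bool:
    """True if size == stride * root for some integer root in [lo, hi]."""
    return size % stride == 0 and lo <= size // stride <= hi


# (stride, min, max) taken from the script above
H = (16, 1, 8)
W = (4, 125, 4000)

print("h range:", derived_range(*H))
print("w range:", derived_range(*W))

example_shape = (4, 10)  # shape of torch.randn(4, 10)
print(fits_derived(example_shape[0], *H), fits_derived(example_shape[1], *W))

candidate = (32, 512)  # 32 = 16 * 2, 512 = 4 * 128
print(fits_derived(candidate[0], *H), fits_derived(candidate[1], *W))
```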
But export fails with the following error:
```
torch._dynamo.exc.UserError: When "dynamic_shapes" is specified as a dict, its top-level keys must be the arg names ['x', 'scale'] of
inputs, but here they are ['x']. Alternatively, you could also ignore arg names entirely and specify
dynamic_shapes as a list/tuple matching
inputs. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation
```
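The rule in this message can be illustrated without torch: every argument of `forward` (other than `self`) needs a key in a dict-valued `dynamic_shapes`, including static ones. The sketch below is an illustration of that rule only, not torch.export's actual validation code; `missing_dynamic_shape_keys` is a hypothetical helper, and strings stand in for the `Dim` objects:

```python
import inspect


def forward(x, scale):
    """Stand-in with the same arguments as MyModel.forward (self omitted)."""
    return x * scale


def missing_dynamic_shape_keys(fn, dynamic_shapes: dict) -> list[str]:
    """Return the argument names of fn that have no entry in dynamic_shapes.

    Illustration only: mirrors the rule stated in the error message
    (top-level keys must be the argument names), not torch's implementation.
    """
    arg_names = list(inspect.signature(fn).parameters)
    return [name for name in arg_names if name not in dynamic_shapes]


# The dict from the script only has "x", so "scale" is reported as missing
print(missing_dynamic_shape_keys(forward, {"x": {0: "h", 1: "w"}}))
# Adding "scale": None (no dynamic dims) covers every argument
print(missing_dynamic_shape_keys(forward, {"x": {0: "h", 1: "w"}, "scale": None}))
```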
I keep running into this kind of error whenever a model has at least one dynamic input alongside a static one.
Can someone please help me resolve this, or suggest a workaround?
Versions
Collecting environment information...
PyTorch version: N/A
Is debug build: N/A
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39
Python version: 3.12.3 (main, Feb 4 2025, 14:48:35) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.6.87.1-microsoft-standard-WSL2-x86_64-with-glibc2.39
Is CUDA available: N/A
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: N/A
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 18
On-line CPU(s) list: 0-17
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) Ultra 5 125H
CPU family: 6
Model: 170
Thread(s) per core: 2
Core(s) per socket: 9
Socket(s): 1
Stepping: 4
BogoMIPS: 5990.39
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni vnmi umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 432 KiB (9 instances)
L1i cache: 576 KiB (9 instances)
L2 cache: 18 MiB (9 instances)
L3 cache: 18 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-17
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
cc @JacobSzwejbka @angelayi @mergennachin @iseeyuan @lucylq @helunwencser @tarun292 @kimishpatel @jackzhxng