Skip to content

Commit

Permalink
fix tensor_to_int in avg_pool2d
Browse files Browse the repository at this point in the history
  • Loading branch information
superDong1998 committed Jun 3, 2024
1 parent a6418be commit 9e6008d
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions frontend/guard_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,18 @@ def record_function(self,
func = torch._C._set_grad_enabled
kwargs = {}
pargs, pkwargs = self.as_node_args_kwargs(args, kwargs)
if func == torch.nn.functional.avg_pool2d:
# avg_pool2d only supports integer or tuple(with two int values) as inputs
if isinstance(pargs[1], tuple):
for i in pargs[1]:
if isinstance(i, torch.fx.Node):
raise ValueError("cannot convert tensor in avg_pool2d")
if isinstance(args[1], tuple):
for i in args[1]:
if torch.is_tensor(i):
raise ValueError("cannot convert tensor in avg_pool2d")
elif torch.is_tensor(args[1]):
raise ValueError("cannot convert tensor in avg_pool2d")
if func in fx_graph_inplace_functions:
scalar = None
node = None
Expand Down

0 comments on commit 9e6008d

Please sign in to comment.