From d830281aacc947a8f6455991fc0c148dbd7c3d1e Mon Sep 17 00:00:00 2001 From: SuperDong <16302010007@fudan.edu.cn> Date: Wed, 5 Jun 2024 14:05:36 +0800 Subject: [PATCH] fix default parameter in avg_pool --- frontend/guard_tracker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index 17274bf..5d9885a 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -269,15 +269,15 @@ def record_function(self, pargs, pkwargs = self.as_node_args_kwargs(args, kwargs) if func == torch.nn.functional.avg_pool2d: # avg_pool2d only supports integer or tuple(with two int values) as inputs - if isinstance(pargs[1], tuple): + if len(pargs) >= 2 and isinstance(pargs[1], tuple): for i in pargs[1]: if isinstance(i, torch.fx.Node): raise ValueError("cannot convert tensor in avg_pool2d") - if isinstance(args[1], tuple): + if len(args) >= 2 and isinstance(args[1], tuple): for i in args[1]: if torch.is_tensor(i): raise ValueError("cannot convert tensor in avg_pool2d") - elif torch.is_tensor(args[1]): + elif len(args) >= 2 and torch.is_tensor(args[1]): raise ValueError("cannot convert tensor in avg_pool2d") if func in fx_graph_inplace_functions: scalar = None