@@ -149,7 +149,8 @@ struct CUDAKernelTy : public GenericKernelTy {
149149 // The maximum number of threads cannot exceed the maximum of the kernel.
150150 MaxNumThreads = std::min (MaxNumThreads, (uint32_t )MaxThreads);
151151
152- return Plugin::success ();
152+ // Retrieve the size of the arguments.
153+ return initArgsSize ();
153154 }
154155
155156 // / Launch the CUDA kernel function.
@@ -173,11 +174,32 @@ struct CUDAKernelTy : public GenericKernelTy {
173174 }
174175
175176private:
177+ // / Initialize the size of the arguments.
178+ Error initArgsSize () {
179+ CUresult Res;
180+ size_t ArgOffset, ArgSize;
181+ size_t Arg = 0 ;
182+
183+ ArgsSize = 0 ;
184+
185+ // Find the last argument to know the total size of the arguments.
186+ while ((Res = cuFuncGetParamInfo (Func, Arg++, &ArgOffset, &ArgSize)) ==
187+ CUDA_SUCCESS)
188+ ArgsSize = ArgOffset + ArgSize;
189+
190+ if (Res != CUDA_ERROR_INVALID_VALUE)
191+ return Plugin::check (Res, " error in cuFuncGetParamInfo: %s" );
192+ return Plugin::success ();
193+ }
194+
176195 // / The CUDA kernel function to execute.
177196 CUfunction Func;
178197 // / The maximum amount of dynamic shared memory per thread group. By default,
179198 // / this is set to 48 KB.
180199 mutable uint32_t MaxDynCGroupMemLimit = 49152 ;
200+
201+ // / The size of the kernel arguments.
202+ size_t ArgsSize;
181203};
182204
183205// / Class wrapping a CUDA stream reference. These are the objects handled by the
@@ -1430,16 +1452,23 @@ Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
14301452 AsyncInfoWrapperTy &AsyncInfoWrapper) const {
14311453 CUDADeviceTy &CUDADevice = static_cast <CUDADeviceTy &>(GenericDevice);
14321454
1455+ // The args size passed in LaunchParams may have tail padding, which is not
1456+ // accepted by the CUDA driver.
1457+ if (ArgsSize > LaunchParams.Size )
1458+ return Plugin::error (ErrorCode::INVALID_ARGUMENT,
1459+ " mismatch in kernel arguments" );
1460+
14331461 CUstream Stream;
14341462 if (auto Err = CUDADevice.getStream (AsyncInfoWrapper, Stream))
14351463 return Err;
14361464
14371465 uint32_t MaxDynCGroupMem =
14381466 std::max (KernelArgs.DynCGroupMem , GenericDevice.getDynamicMemorySize ());
14391467
1468+ size_t ConfigArgsSize = ArgsSize;
14401469 void *Config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, LaunchParams.Data ,
14411470 CU_LAUNCH_PARAM_BUFFER_SIZE,
1442- reinterpret_cast <void *>(&LaunchParams. Size ),
1471+ reinterpret_cast <void *>(&ConfigArgsSize ),
14431472 CU_LAUNCH_PARAM_END};
14441473
14451474 // If we are running an RPC server we want to wake up the server thread
0 commit comments