
Conversation

sbryngelson (Member) commented Jan 9, 2026

PR Type

Enhancement, Tests, Bug fix


Description

  • Robust MG convergence with multiple tolerance criteria: Implemented absolute, RHS-relative, and initial-residual-relative convergence checks with an L2 norm option and L∞ safety cap for more reliable Poisson solver termination (a sketch of the combined check follows this list)

  • CUDA Graph acceleration for V-cycle: Added CudaSmootherGraph and CudaVCycleGraph classes to capture smoother kernels and full V-cycles as CUDA Graphs, reducing kernel launch overhead with up to 4.9× speedup

  • Fused residual and norm computation: Implemented compute_residual_and_norms() for efficient single-pass GPU computation of residuals and norms

  • Enhanced GPU memory management: Updated 60+ OpenMP target pragmas from map(present:...) to map(present, alloc:...) across all GPU kernels, asserting that arrays are already device-resident without triggering the implicit tofrom transfers the bare present clause defaults to

  • Null-space handling for singular problems: Added has_nullspace() and fix_nullspace() methods for proper handling of periodic Poisson problems

  • Fixed and adaptive cycle modes: Implemented configurable fixed-cycle and adaptive-cycle modes with optional convergence checking after specified iterations

  • Comprehensive testing and benchmarking: Added 8 new test/benchmark files covering V-cycle graph stress tests, MG tuning parameter optimization, physics equivalence verification, and solver performance comparisons

  • Configuration enhancements: Extended PoissonConfig with new convergence criteria, smoother tuning parameters (nu1, nu2, chebyshev_degree), and CUDA Graph control flags

  • Profiling infrastructure: Added comprehensive NVTX profiling suite and Nsight Systems automation script for performance analysis across multiple solver configurations

  • Documentation: Included profiling results demonstrating CUDA Graph effectiveness and solver selection guidance
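
For concreteness, a minimal sketch of how the combined convergence check can be expressed, assuming the solver tracks RHS norms (b_l2, b_inf), the initial residual r0, and the current residual norms; the function and parameter names are illustrative, not the PR's exact code:

bool converged(double r_l2, double r_inf, double b_l2, double b_inf,
               double r0, double tol_abs, double tol_rhs, double tol_rel,
               bool use_l2_norm, double linf_safety_factor) {
    const double r = use_l2_norm ? r_l2 : r_inf;  // norm for the main test
    const double b = use_l2_norm ? b_l2 : b_inf;
    const bool abs_ok = (tol_abs > 0.0) && (r <= tol_abs);        // absolute
    const bool rhs_ok = (b > 0.0)  && (r <= tol_rhs * b);         // ||r||/||b||
    const bool rel_ok = (r0 > 0.0) && (r <= tol_rel * r0);        // ||r||/||r0||
    // L-inf safety cap: L2 convergence must not hide isolated bad cells.
    const bool linf_ok = (r_inf <= tol_rhs * linf_safety_factor * b_inf);
    return (abs_ok || rhs_ok || rel_ok) && linf_ok;
}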


Diagram Walkthrough

flowchart LR
  A["Poisson Solver Config"] -->|"tol_abs, tol_rhs, tol_rel"| B["Robust Convergence Checker"]
  C["V-cycle Operations"] -->|"Capture as Graph"| D["CudaVCycleGraph"]
  D -->|"Reduce Launch Overhead"| E["4.9× Speedup"]
  F["Residual + Norms"] -->|"Fused Computation"| G["Single GPU Pass"]
  H["OpenMP Kernels"] -->|"map present,alloc"| I["Proper GPU Memory"]
  B -->|"Check Convergence"| J["Adaptive/Fixed Cycles"]
  G -->|"Feed"| B

File Walkthrough

Relevant files
Enhancement
9 files
poisson_solver_multigrid.cpp
Robust MG convergence with CUDA Graphs and fused norm computation

src/poisson_solver_multigrid.cpp

  • Added robust convergence checking with L2 norm, L∞ safety cap, and
    multiple tolerance criteria (absolute, RHS-relative, initial-residual
    relative)
  • Implemented fused compute_residual_and_norms() function for efficient
    single-pass residual + norm computation on GPU
  • Added CUDA Graph support for smoother kernels and full V-cycle graphs
    to reduce kernel launch overhead
  • Enhanced GPU synchronization handling with explicit documentation of
    data dependencies and stream management
  • Refactored vcycle() to accept degree parameter and support multiple
    smoothing passes per level
  • Added has_nullspace() and fix_nullspace() methods for proper handling
    of singular Poisson problems
  • Implemented fixed-cycle and adaptive-cycle modes with optional
    convergence checking
  • Updated all OpenMP target pragmas from map(present:...) to
    map(present, alloc:...) to avoid implicit tofrom transfers on
    device-resident data
+844/-113
mg_cuda_kernels.cpp
CUDA kernel implementation for multigrid with graph capture

src/mg_cuda_kernels.cpp

  • New file implementing CUDA kernels for multigrid smoother operations
    (Chebyshev, Jacobi, boundary conditions)
  • Provides 3D kernels for residual computation, restriction (27-point
    stencil), and prolongation (trilinear interpolation)
  • Implements CudaSmootherGraph class for capturing smoother kernel
    sequences as CUDA Graphs
  • Implements CudaMGContext for managing multiple smoother graphs across
    grid levels
  • Implements CudaVCycleGraph for capturing entire V-cycle as single CUDA
    Graph with recursive level handling
  • Includes fused periodic BC kernel to eliminate separate BC pass
    overhead
  • Provides kernel launch wrappers with proper grid/block sizing and
    stream management
+933/-0 
solver.cpp
GPU memory mapping fixes and robust MG convergence diagnostics

src/solver.cpp

  • Changed 60+ OpenMP target pragmas from map(present:...) to
    map(present, alloc:...) to assert arrays are already device-resident
    while suppressing implicit tofrom transfers
  • Added robust MG convergence criteria with three independent
    convergence checks: absolute tolerance, RHS-relative, and
    initial-residual relative
  • Implemented comprehensive Poisson solver diagnostics with detailed
    residual norms and convergence ratios for MG solver analysis
  • Added post-projection divergence diagnostic output to measure actual
    projection quality (max|div(u)|) after velocity correction
  • Configured PoissonConfig with new MG tuning parameters (nu1, nu2,
    Chebyshev degree, fixed/adaptive cycles, CUDA Graph support)
+161/-88
mg_cuda_kernels.hpp
CUDA Graph infrastructure for multigrid V-cycle acceleration

include/mg_cuda_kernels.hpp

  • New 332-line header defining CUDA kernel infrastructure for multigrid
    Poisson solver acceleration
  • Provides CudaSmootherGraph class for capturing individual level
    smoothers and CudaVCycleGraph for full V-cycle graph capture
  • Defines VCycleGraphFingerprint struct for validity checking when BC or
    grid parameters change
  • Declares kernel launch functions for Chebyshev smoothing, boundary
    conditions, residual computation, restriction, and prolongation
+332/-0 
poisson_solver_multigrid.hpp
MG Solver Residual Tracking and CUDA Graph Support             

include/poisson_solver_multigrid.hpp

  • Adds forward declarations for CUDA Graph context classes
  • Introduces new residual tracking methods: residual_l2(), rhs_norm(),
    rhs_norm_l2(), initial_residual(), initial_residual_l2()
  • Adds CUDA Graph support infrastructure with context and V-cycle graph
    pointers
  • Implements has_nullspace() and fix_nullspace() methods for robust
    null-space handling
  • Adds compute_residual_and_norms() for fused residual computation
+73/-4   
poisson_solver.hpp
Poisson Solver Configuration Enhancements                               

include/poisson_solver.hpp

  • Expands PoissonConfig with robust convergence criteria: tol_abs,
    tol_rhs, tol_rel, check_interval, use_l2_norm, linf_safety_factor
  • Adds fixed-cycle mode configuration: fixed_cycles, adaptive_cycles,
    check_after
  • Introduces MG smoother tuning parameters: nu1, nu2, chebyshev_degree
  • Adds CUDA Graph acceleration flag: use_vcycle_graph
  • Marks legacy tol field as deprecated in favor of new tolerance
    criteria
+33/-1   
main_taylor_green_3d.cpp
Taylor-Green 3D Simulation Time Control Flexibility           

app/main_taylor_green_3d.cpp

  • Modifies T_final initialization to support config file and
    command-line precedence
  • Adds T_final_from_cmdline flag to track explicit command-line
    specification
  • Implements three-tier fallback: command-line > config file > default
    (10.0); a sketch of the precedence logic follows this entry
  • Allows config file to directly set max_iter for profiling without
    T_final override
+20/-3   
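
A minimal sketch of the three-tier precedence described above (all names hypothetical):

double resolve_T_final(bool from_cmdline, double cmdline_val,
                       bool in_config,   double config_val) {
    if (from_cmdline) return cmdline_val;  // command line wins
    if (in_config)    return config_val;   // then the config file
    return 10.0;                           // documented default
}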
gpu_utils.hpp
GPU Utilities for Async Operations and Pointer Mapping     

include/gpu_utils.hpp

  • Adds get_device_ptr() template function to convert OpenMP-mapped host
    pointers to device pointers
  • Implements sync() function for GPU work synchronization using OpenMP
    taskwait
  • Adds GPU_PARALLEL_FOR_ASYNC macro for asynchronous kernel launches
    with nowait; the async-launch pattern is sketched after this entry
  • Provides CPU no-op versions for non-GPU builds
+27/-0   
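
A minimal sketch of the async-launch pattern these utilities enable, assuming p[0:n] was mapped earlier so the implicit zero-length-section lookup resolves it on the device; the macro body is illustrative and may differ from the real header:

#include <omp.h>

#define GPU_PARALLEL_FOR_ASYNC \
    _Pragma("omp target teams distribute parallel for nowait")

inline void sync() {
    #pragma omp taskwait  // wait for outstanding nowait target regions
}

inline void scale_async(double* p, long n, double a) {
    GPU_PARALLEL_FOR_ASYNC
    for (long i = 0; i < n; ++i) p[i] *= a;  // enqueued, returns immediately
    sync();                                  // block until the kernel finishes
}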
config.cpp
Configuration File Support for MG Convergence Parameters 

src/config.cpp

  • Adds loading of T_final from configuration file
  • Loads new MG convergence parameters: poisson_tol_abs, poisson_tol_rhs,
    poisson_tol_rel, poisson_check_interval, poisson_use_l2_norm,
    poisson_linf_safety
  • Loads fixed-cycle configuration: poisson_fixed_cycles
  • Enables config file-based tuning of robust convergence criteria
+10/-0   
Formatting
1 file
gpu_kernels.cpp
Standardize GPU memory mapping clauses across kernels       

src/gpu_kernels.cpp

  • Updated all OpenMP target pragmas from map(present:...) to
    map(present, alloc:...) for consistency with GPU memory management
    improvements
  • Applied changes across gradient computation, MLP feature extraction,
    TBNN processing, and turbulence closure kernels
  • Updated comments to reflect new mapping clause semantics
+17/-17 
Tests
10 files
ci.sh
Add V-cycle CUDA Graph stress test to CI pipeline               

scripts/ci.sh

  • Added new CI test test_vcycle_graph_stress to validate V-cycle CUDA
    Graph functionality
  • Test checks BC alternation (graph recapture), convergence parity, and
    anisotropic grid handling
  • Runs with MG_USE_VCYCLE_GRAPH=1 environment variable to enable graph
    mode
+4/-0     
test_vcycle_graph_stress.cpp
V-cycle CUDA Graph stress tests for BC alternation and convergence

tests/test_vcycle_graph_stress.cpp

  • New comprehensive stress test suite for V-cycle CUDA Graph
    functionality with 360 lines of test code
  • Test 1: BC type alternation (Dirichlet↔Neumann↔Periodic) verifying
    graph recapture across 10 iterations
  • Test 2: Convergence curve parity comparing graphed vs non-graphed
    solver paths across 1-8 V-cycles
  • Test 3: Mixed BCs on anisotropic grids (64×32×16 with 4:2:1 aspect
    ratio) validating robustness
+360/-0 
bench_mg_tuning.cpp
MG tuning parameter benchmark with quality and performance metrics

tests/bench_mg_tuning.cpp

  • New 330-line benchmark suite for MG tuning parameter optimization
    (nu1, nu2, Chebyshev degree, cycle counts)
  • Measures wall time per step, post-projection divergence (L2 and L∞),
    and kinetic energy drift with statistical analysis
  • Tests 9 configurations including baseline, asymmetric smoothing, and
    adaptive cycle modes on configurable grid sizes
  • Provides mean±stddev statistics over multiple repeats for robust
    performance characterization
+330/-0 
profile_comprehensive.cpp
Comprehensive NVTX profiling suite for solver configurations

app/profile_comprehensive.cpp

  • New 387-line comprehensive profiling suite for NVTX-based performance
    analysis across multiple configurations
  • Tests 14 configurations covering BC variations (periodic, channel,
    duct), Poisson solvers (MG, MG+Graph, FFT, FFT1D, HYPRE), and
    turbulence models (laminar, Smagorinsky, SST k-omega)
  • Includes warmup step, timed execution, and throughput reporting
    (Mcells/s) for each configuration
  • Designed for use with nsys profile for GPU kernel timeline analysis
    and CUDA Graph effectiveness measurement
+387/-0 
test_mg_physics_match.cpp
MG Physics Equivalence Verification Test                                 

tests/test_mg_physics_match.cpp

  • New comprehensive test verifying fixed-cycle MG produces equivalent
    physics to converged MG
  • Implements Taylor-Green vortex simulation with both solver modes and
    compares KE, divergence, and velocity
  • Includes helper functions for computing kinetic energy, max
    divergence, and max velocity magnitude
  • Validates physics match with relative tolerance thresholds and
    detailed step-by-step comparison
+238/-0 
bench_mg_cuda_graphs.cpp
MG CUDA Graphs Performance Benchmark                                         

tests/bench_mg_cuda_graphs.cpp

  • New benchmark for MG solver performance with/without CUDA Graphs on
    large 3D grids
  • Supports both GPU path using solve_device() with persistent GPU data
    and CPU fallback
  • Includes JIT warmup and configurable fixed-cycle vs convergence modes
  • Benchmarks grids from 64³ to 192³ with timing statistics
+192/-0 
bench_mg_bc_sweep.cpp
MG Boundary Condition Robustness Sweep                                     

tests/bench_mg_bc_sweep.cpp

  • New benchmark sweeping MG smoother parameters (nu1, nu2) across
    different BC types
  • Tests Channel (PWP) and Duct (PWW) configurations with various
    smoother combinations
  • Measures performance and divergence quality metrics for tuning
    validation
  • Compares new (nu1=3,nu2=1) vs baseline (nu1=2,nu2=2) configurations
+191/-0 
bench_fft_vs_mg.cpp
FFT vs MG Poisson Solver Performance Comparison                   

tests/bench_fft_vs_mg.cpp

  • New benchmark comparing FFT vs MG Poisson solver performance on
    periodic grids
  • Runs Taylor-Green vortex with both solvers and measures full timestep
    cost
  • Tests grids from 64³ to 192³ with configurable fixed-cycle mode
  • Provides performance ratio comparison for solver selection guidance
+122/-0 
bench_256.cpp
Large-Scale MG Solver Benchmark (256³)                                     

tests/bench_256.cpp

  • New benchmark for MG solver on large 256³ grid
  • Initializes RHS with sin pattern and runs 5 trials with convergence
    tolerance
  • Measures average solve time and throughput in Mcells/s
  • Provides baseline performance data for large-scale grids
+63/-0   
test_poisson_unified.cpp
Poisson Test Updates for New Convergence Criteria               

tests/test_poisson_unified.cpp

  • Updates test_nullspace_periodic() to explicitly set new convergence
    tolerance fields
  • Sets tol_abs, tol_rhs, tol_rel and use_l2_norm for consistent behavior
  • Updates test_3d_gpu_convergence() with same tolerance configuration
  • Ensures tests work with new robust convergence criteria
+10/-2   
Configuration changes
1 file
config.hpp
Configuration parameters for robust MG convergence and CUDA Graph
support

include/config.hpp

  • Added T_final parameter for unsteady simulations (alternative to
    max_iter-based termination)
  • Introduced robust MG convergence criteria: poisson_tol_abs,
    poisson_tol_rhs, poisson_tol_rel with independent checks
  • Added MG smoother tuning parameters: poisson_nu1, poisson_nu2,
    poisson_chebyshev_degree with auto-tuning support
  • Added adaptive fixed-cycle mode: poisson_adaptive_cycles,
    poisson_check_after for hybrid convergence checking
  • Added CUDA Graph control: poisson_use_vcycle_graph flag to
    enable/disable V-cycle graph acceleration
  • Increased default poisson_max_iter from 5 to 20 as safety limit for
    convergence checking
+27/-3   
Bug fix
2 files
poisson_solver_fft.cpp
FFT Solver GPU Memory Map Clause Updates                                 

src/poisson_solver_fft.cpp

  • Updates OpenMP target map clauses from map(present:...) to
    map(present, alloc:...)
  • Applied to 4 kernel regions: pack_rhs_with_sum, pack_rhs,
    unpack_solution, unpack_and_apply_bc
  • Also updates apply_bc_device for X, Y, Z boundary kernels
  • Avoids implicit tofrom transfers for persistent device-resident data
+7/-7     
poisson_solver_fft2d.cpp
FFT2D Solver GPU Memory Map Clause Updates                             

src/poisson_solver_fft2d.cpp

  • Updates OpenMP target map clauses in solve_device() from
    map(present:...) to map(present, alloc:...)
  • Applied to 2 kernel regions: pack RHS and unpack solution operations
  • Improves GPU memory management consistency with 3D FFT solver
+2/-2     
Documentation
3 files
run_nsys_profiles.sh
Nsight Systems Profiling Automation Script                             

scripts/run_nsys_profiles.sh

  • New comprehensive profiling script for NNCFD Poisson solvers using
    Nsight Systems
  • Profiles 4 cases: Taylor-Green FFT, Channel HYPRE, Duct FFT1D,
    Taylor-Green MG
  • Includes build verification, nsys availability checks, and automated
    report generation
  • Extracts NVTX ranges, CUDA kernels, and memory operation statistics
+265/-0 
taylor_green_fft_stats.txt
FFT Solver Profiling Results and Statistics                           

profiles/taylor_green_fft_stats.txt

  • Profiling results from Nsight Systems for Taylor-Green FFT solver at
    312³ grid
  • Contains NVTX range summary showing poisson_solve dominates at 45.5%
    of time
  • Includes CUDA API statistics with cudaStreamSynchronize overhead
    analysis
  • Provides CUDA kernel timing breakdown and GPU memory operation
    statistics
+208/-0 
profiling_results_128cubed.md
Comprehensive MG Solver Profiling Results Documentation   

docs/profiling_results_128cubed.md

  • Comprehensive profiling documentation for 128³ grid configurations
  • Demonstrates CUDA Graph provides 4.9× speedup for MG solver
  • Includes detailed timing breakdown by phase and comparison
    with/without graphs
  • Provides recommendations for solver selection and CUDA Graph usage
+145/-0 
Additional files
4 files
CMakeLists.txt +29/-1   
taylor_green_fft_kernels.txt +79/-0   
taylor_green_fft_memops.txt +17/-0   
taylor_green_fft_nvtx.txt +24/-0   

Summary by CodeRabbit

  • New Features

    • GPU offload enhanced with CUDA‑Graph V‑cycle support, richer multigrid tuning (multi‑criteria tolerances, fixed/adaptive cycles, smoother degree), improved solver diagnostics, and final-time precedence (CLI > config > default).
  • Tests

    • Added benchmarks and stress tests covering MG, CUDA‑Graph paths, BC sweeps, tuning, FFT vs MG comparisons, and physics equivalence checks.
  • Documentation

    • New profiling reports, profiling scripts, and user-facing docs describing CUDA‑Graph V‑cycle optimization and reproduction steps.


sbryngelson and others added 16 commits January 7, 2026 17:04
Implement tolerance-based early termination for multigrid Poisson solver
with dt-independent convergence behavior:

- Add relative tolerance criteria to PoissonConfig:
  - tol_rhs: ||r||/||b|| (RHS-relative, primary criterion)
  - tol_rel: ||r||/||r0|| (initial-residual relative, backup)
  - tol_abs: absolute tolerance (disabled by default)
  - check_interval: check every N V-cycles to reduce overhead

- Update MG solver (CPU and GPU paths):
  - Compute ||b||_∞ and ||r0||_∞ at start of solve
  - Check convergence every check_interval cycles
  - Exit early when any tolerance criterion is met
  - Store norms in member variables for diagnostics

- Add T_final config option for unsteady simulations

- Enhanced diagnostics (NNCFD_POISSON_DIAGNOSTICS=1):
  - Log ||b||, ||r0||, ||r||/||b||, ||r||/||r0|| per solve
  - Add post-projection divergence check

Performance improvement for 312³ Taylor-Green:
- Before (50 fixed cycles): 328 ms/solve
- After (8-10 cycles with tol_rhs=1e-3): 66 ms/solve (5x speedup)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Performance optimizations for multigrid Poisson solver:

1. Fused residual computation + norm calculation
   - Single GPU kernel computes r = f - L(u), ||r||_∞, and ||r||_2
   - Replaces separate compute_residual() + compute_max_residual()
   - Reduces memory bandwidth and kernel launch overhead

2. L2 norm convergence option (default: enabled)
   - cfg.use_l2_norm selects L2 vs L∞ for convergence checking
   - L2 is smoother, less sensitive to single "hot cells"
   - Converges in fewer V-cycles (~8 vs 9-10 for L∞)

3. L∞ safety cap to prevent L2 from hiding bad cells
   - cfg.linf_safety_factor (default: 10) enforces loose L∞ bound
   - Even with L2 convergence, ||r||_∞/||b||_∞ ≤ tol_rhs * 10

Timing (192³ = 7M points, tol_rhs=1e-3):
- L2 convergence: 106ms, 8 cycles, ||r||₂/||b|| = 5.9e-4
- L∞ convergence: 112ms, 10 cycles, ||r||∞/||b|| = 2.1e-4
- ~5% speedup from L2 convergence

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
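A minimal sketch of what such a fused pass can look like, assuming a 7-point Laplacian on a ghost-padded 3D grid whose arrays are already mapped to the device; all names are illustrative, not the PR's compute_residual_and_norms() itself:

#include <cmath>
#include <cstddef>

void residual_and_norms(const double* u, const double* f, double* r,
                        int Nx, int Ny, int Nz, int Ng,
                        int stride, int plane,
                        double inv_dx2, double inv_dy2, double inv_dz2,
                        double& r_inf, double& r_l2) {
    const std::size_t n = static_cast<std::size_t>(plane) * (Nz + 2 * Ng);
    double max_r = 0.0, sum_sq = 0.0;
    // One kernel: residual r = f - L(u) plus both norm reductions.
    #pragma omp target teams distribute parallel for collapse(3) \
        map(present, alloc: u[0:n], f[0:n], r[0:n]) \
        reduction(max: max_r) reduction(+: sum_sq)
    for (int k = Ng; k < Nz + Ng; ++k)
      for (int j = Ng; j < Ny + Ng; ++j)
        for (int i = Ng; i < Nx + Ng; ++i) {
            const int idx = k * plane + j * stride + i;
            const double Lu =
                inv_dx2 * (u[idx - 1]      - 2.0 * u[idx] + u[idx + 1]) +
                inv_dy2 * (u[idx - stride] - 2.0 * u[idx] + u[idx + stride]) +
                inv_dz2 * (u[idx - plane]  - 2.0 * u[idx] + u[idx + plane]);
            const double res = f[idx] - Lu;
            r[idx] = res;
            const double a = (res >= 0.0) ? res : -res;  // device-friendly abs
            if (a > max_r) max_r = a;
            sum_sq += res * res;
        }
    r_inf = max_r;
    r_l2 = std::sqrt(sum_sq);
}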
Since the fused residual+norms computation is now nearly free,
checking convergence every V-cycle avoids overshooting and
costs almost nothing. Timing shows <1ms difference between
check_interval=1 and check_interval=4 at 7M points.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Implements CUDA Graph capture/replay for the multigrid smoother to reduce
kernel launch overhead. Enabled via MG_USE_CUDA_GRAPHS=1 environment variable.

Key changes:
- Add mg_cuda_kernels.hpp/cpp with CUDA-native Chebyshev smoother kernel
- Add CudaSmootherGraph class for graph capture and replay
- Add CudaMGContext for dedicated MG stream management
- Add gpu::get_device_ptr() utility for OpenMP-to-CUDA pointer conversion
- Integrate CUDA Graph path in smooth_chebyshev() with automatic fallback

Note: CUDA Graphs only work for 3D cases due to different array indexing
between OpenMP 2D (idx = j*stride + i) and CUDA 3D (idx = k*plane + j*stride + i).
2D cases automatically fall back to the OpenMP path.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
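The capture/replay pattern itself, as a self-contained sketch with a stand-in kernel (the PR's real graphs wrap the Chebyshev smoother and BC kernels; the CUDA 12 cudaGraphInstantiate signature is assumed):

#include <cuda_runtime.h>

__global__ void axpy(double a, const double* x, double* y, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] += a * x[i];
}

int main() {
    const int n = 1 << 20;
    double *x, *y;
    cudaMalloc(&x, n * sizeof(double));
    cudaMalloc(&y, n * sizeof(double));
    cudaMemset(x, 0, n * sizeof(double));
    cudaMemset(y, 0, n * sizeof(double));

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Capture a fixed kernel sequence once...
    cudaGraph_t graph;
    cudaGraphExec_t exec;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    for (int s = 0; s < 4; ++s)  // e.g. four smoothing sweeps
        axpy<<<(n + 255) / 256, 256, 0, stream>>>(0.5, x, y, n);
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&exec, graph, 0);

    // ...then replay: one launch call stands in for the whole sequence.
    cudaGraphLaunch(exec, stream);
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(exec);
    cudaGraphDestroy(graph);
    cudaStreamDestroy(stream);
    cudaFree(x); cudaFree(y);
    return 0;
}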
Benchmarks comparing MG solver performance with/without CUDA Graphs:
- bench_mg_cuda_graphs: Tests 64³ to 192³ grids
- bench_256: Tests 256³ grid

Results show 1.1x-3.2x speedup depending on grid size:
- 64³:  3.15x (kernel launch overhead dominates)
- 128³: 1.52x
- 256³: 1.12x (compute time dominates)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Add fixed_cycles option to skip convergence checks and D→H transfers
- Add chebyshev_3d_periodic_kernel with wrap indexing (no separate BC pass)
- Update CUDA Graph capture to use fused kernel for fully-periodic grids
- Update benchmark to test fixed_cycles mode

Performance at 192³ (8 V-cycles):
- Baseline (convergence mode): 105.25ms
- Fixed cycles + fused BC: 74.17ms (1.42x faster)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
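For intuition, the wrap-indexing idea with a single ghost layer (Ng = 1, as the PR assumes): interior cells occupy [Ng, N + Ng), so on a periodic axis the -1 neighbour of the first interior cell is the last interior cell and vice versa. Helper names are illustrative:

inline int periodic_left(int i, int N, int Ng) {
    return (i == Ng) ? (N + Ng - 1) : (i - 1);  // wrap to last interior cell
}
inline int periodic_right(int i, int N, int Ng) {
    return (i == N + Ng - 1) ? Ng : (i + 1);    // wrap to first interior cell
}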
The fused periodic kernel uses wrap indexing to read from interior cells
directly, avoiding ghost cell reads WITHIN the smoother iterations.
However, other MG operations (compute_residual, restrict, prolongate)
still use standard neighbor indexing (idx-1, idx+stride, etc.) which
reads ghost cells.

Previously, the final BC kernel was skipped for all-periodic cases:
  if (!all_periodic) { launch_bc_kernel(stream); }

This caused garbage ghost cell values after the smoother, leading to
NaN in compute_residual when called from solve_device().

Fix: Always apply BC after the smoother completes, regardless of
whether the fused periodic kernel was used. This ensures ghost cells
are valid for subsequent MG operations.

Performance impact on standalone MG solve() with CUDA Graphs:
  64³:  3.27x speedup (13.56ms → 4.15ms)
  192³: 1.22x speedup (105.16ms → 86.39ms)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
OpenMP's map(present: ptr[0:n]) defaults to tofrom semantics, causing
the runtime to do H→D transfers even when data is already on device.
This was causing ~8GB of unnecessary transfers per benchmark run.

Fix: Change all map(present:) to map(present, alloc:) which tells the
runtime "data must be present, map-type is alloc (no transfer)".

Files modified:
- src/poisson_solver_multigrid.cpp (39 pragmas)
- src/solver.cpp (78 pragmas)
- src/gpu_kernels.cpp (17 pragmas)
- src/poisson_solver_fft.cpp (7 pragmas)
- src/poisson_solver_fft2d.cpp (2 pragmas)

Note: Large H→D transfers still occur during solver initialization
(uploading initial conditions) which is expected behavior.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
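A sketch of the changed clause in context; the 'present' modifier asserts the data is already mapped, and the 'alloc' map-type suppresses the implicit transfer that the default 'tofrom' map-type would otherwise imply:

#include <cstddef>

// Zero a device-resident buffer without any H→D/D→H traffic.
void zero_device(double* p, std::size_t n) {
    #pragma omp target teams distribute parallel for map(present, alloc: p[0:n])
    for (std::size_t idx = 0; idx < n; ++idx)
        p[idx] = 0.0;
}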
Use ompx_get_cuda_stream() (NVHPC extension) to get the CUDA stream
that OpenMP uses for target regions. By launching CUDA Graphs on this
same stream, we eliminate cross-stream synchronization overhead:

- Previous: cudaGraphLaunch on internal stream → cudaStreamSync → OpenMP
- Now: cudaGraphLaunch on OpenMP's stream → no explicit sync needed

The subsequent OpenMP target regions naturally wait for the graph to
complete since they're all on the same stream. This should significantly
reduce the ~42% time previously spent in cudaStreamSynchronize.

Added smooth(level, cudaStream_t) overload to CudaMGContext to enable
stream-specific graph execution.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
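A sketch of the same-stream launch; the ompx_get_cuda_stream declaration below is an assumption inferred from the call shown in this PR (NVHPC ships the real one):

#include <omp.h>
#include <cuda_runtime.h>

extern "C" void* ompx_get_cuda_stream(int device, int sync);  // assumed decl

void replay_on_omp_stream(cudaGraphExec_t exec) {
    cudaStream_t s = reinterpret_cast<cudaStream_t>(
        ompx_get_cuda_stream(omp_get_default_device(), /*sync=*/0));
    cudaGraphLaunch(exec, s);
    // No cudaStreamSynchronize: later OpenMP target regions share this
    // stream, so they order naturally after the graph.
}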
Add CUDA kernels for all V-cycle operations (residual, restrict,
prolongate, zero) and implement CudaVCycleGraph class that captures
the entire multigrid V-cycle as a single CUDA Graph for massive
reduction in kernel launch overhead.

Performance gains (8 V-cycles, solve_device path):
- 64³:  11.5x faster (10.56ms → 0.92ms)
- 128³: 6.4x faster  (15.05ms → 2.36ms)
- 192³: 3.8x faster  (22.73ms → 6.01ms)

Enable via MG_USE_VCYCLE_GRAPH=1 environment variable.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Add robustness improvements to the V-cycle CUDA graph:

1. VCycleGraphFingerprint struct captures all parameters affecting graph
   validity (level count/sizes, coefficients, BCs, nu1/nu2, degree)
2. Automatic graph recapture when BCs change via set_bc()
3. Debug assertion for stream consistency (null stream detection)
4. Stress test (test_vcycle_graph_stress.cpp) covering:
   - BC type alternation (Periodic↔Neumann↔Dirichlet)
   - Convergence curve parity (graphed vs non-graphed)
   - Mixed BCs on anisotropic grids (4:2:1 aspect ratio)
5. JIT variability mitigation documentation in benchmark

Co-Authored-By: Claude Opus 4.5 <[email protected]>
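A sketch of the fingerprint idea; the field list follows the commit text, while the types and the defaulted C++20 comparison are illustrative:

#include <vector>

struct VCycleGraphFingerprint {
    std::vector<int>    level_nx, level_ny, level_nz;  // per-level sizes
    std::vector<double> level_dx, level_dy, level_dz;  // per-level spacing
    int bc[6];                                         // x/y/z lo and hi
    int nu1, nu2, degree;
    bool operator==(const VCycleGraphFingerprint&) const = default;
};

// Before each launch: if the current fingerprint differs from the one the
// graph was captured with, destroy and recapture instead of replaying.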
…elper

- Add level_dx/dy/dz to VCycleGraphFingerprint for anisotropic grid support
- Replace debug assert with runtime stream fallback (warns + uses non-graphed path)
- Add has_nullspace() helper with documentation for edge cases
- Add fix_nullspace() to consolidate 4 duplicated nullspace handling blocks

Co-Authored-By: Claude Opus 4.5 <[email protected]>
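A sketch of the null-space handling being consolidated, for the all-periodic (or all-Neumann) case where the solution is defined only up to a constant; names are illustrative:

#include <cstddef>

bool has_nullspace(bool any_dirichlet_face) {
    return !any_dirichlet_face;  // nothing anchors the constant mode
}

void fix_nullspace(double* u, std::size_t n) {
    double sum = 0.0;
    for (std::size_t i = 0; i < n; ++i) sum += u[i];
    const double mean = sum / static_cast<double>(n);
    for (std::size_t i = 0; i < n; ++i) u[i] -= mean;  // pin the constant
}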
Add #ifdef USE_GPU_OFFLOAD guards around solve_device() calls with
CPU fallback using regular solve() with ScalarField.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Register test_vcycle_graph_stress as ctest with MG_USE_VCYCLE_GRAPH=1
- Add to ci.sh GPU test suite (runs after GPU Utilization test)
- Tests BC alternation, convergence parity, anisotropic grids

This enables CI coverage for the graphed MG solver path.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Benchmark-driven optimization of multigrid smoother parameters:

Performance at 128³ channel (PWP):
- nu1=3,nu2=1,cyc=8: 12% faster AND 10× better div_L2 vs baseline
- Validated across Channel (PWP) and Duct (PWW) boundary conditions

Changes:
- Add nu1, nu2, chebyshev_degree config parameters (defaults: 3, 1, 4)
- Add adaptive_cycles mode: run check_after cycles, check residual, add more if needed
- Fix adaptive mode CUDA error: use CPU reduction with D→H transfer
  instead of OpenMP target reduction (avoids CUDA/OpenMP stream conflicts)
- Add bench_mg_tuning.cpp: statistical benchmark with mean±stddev
- Add bench_mg_bc_sweep.cpp: validates tuning across BC configurations

Quality contract maintained: div_L2 < 1e-5, div_Linf < 1e-3

Co-Authored-By: Claude Opus 4.5 <[email protected]>
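A sketch of the adaptive_cycles control flow described above; the vcycle and residual_norm callbacks stand in for the solver's own methods:

#include <functional>

int solve_adaptive(const std::function<void()>& vcycle,
                   const std::function<double()>& residual_norm,
                   int check_after, int max_iter,
                   double tol_rhs, double b_norm) {
    int cycles = 0;
    for (; cycles < check_after; ++cycles) vcycle();   // unconditional batch
    while (cycles < max_iter && residual_norm() > tol_rhs * b_norm) {
        vcycle();                                      // add cycles as needed
        ++cycles;
    }
    return cycles;
}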
Changes:
- V-cycle graph now enabled by default (was opt-in via env var)
- Add poisson_use_vcycle_graph config option (default: true)
- Environment variable MG_USE_VCYCLE_GRAPH=0 can still disable
- Update nu1/nu2 documentation: optimal for wall BCs is nu1=3, nu2=1
- Remove unused #include <cstring>

The V-cycle graph provides massive speedup by capturing the entire
multigrid V-cycle as a single CUDA graph, eliminating kernel launch
overhead. Users who need to disable it (debugging, compatibility)
can set poisson_use_vcycle_graph=false in config or MG_USE_VCYCLE_GRAPH=0.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
qodo-code-review bot commented Jan 9, 2026

PR Code Suggestions ✨

Latest suggestions up to 26ed4bd

Category | Suggestion | Impact
Incremental [*]
Fix ghost-layer edge ownership checks
Suggestion Impact: The commit updates the Y-low/Y-high and Z-low/Z-high face edge-ownership skip conditions from equality checks (Ng==1-only) to range checks that correctly skip all ghost layers for arbitrary Ng, matching the intent of the suggestion and preventing race conditions. (It also refactors Chebyshev bounds into file-level constants, which is unrelated to the suggestion.)

code diff:

@@ -196,8 +206,8 @@
         int ik = remaining;
         int i = ik % (Nx + 2*Ng);
         int k = ik / (Nx + 2*Ng);
-        // Skip cells owned by x-faces (i=0 or i=Nx+Ng)
-        if (i == 0 || i == Nx + Ng) return;
+        // Skip cells owned by x-faces (all x-ghost layers, robust for any Ng)
+        if (i < Ng || i >= Nx + Ng) return;
         int idx = k * plane_stride + 0 * stride + i;
         int idx_int = k * plane_stride + Ng * stride + i;
         int idx_wrap = k * plane_stride + (Ny + Ng - 1) * stride + i;
@@ -218,8 +228,8 @@
         int ik = remaining;
         int i = ik % (Nx + 2*Ng);
         int k = ik / (Nx + 2*Ng);
-        // Skip cells owned by x-faces (i=0 or i=Nx+Ng)
-        if (i == 0 || i == Nx + Ng) return;
+        // Skip cells owned by x-faces (all x-ghost layers, robust for any Ng)
+        if (i < Ng || i >= Nx + Ng) return;
         int idx = k * plane_stride + (Ny + Ng) * stride + i;
         int idx_int = k * plane_stride + (Ny + Ng - 1) * stride + i;
         int idx_wrap = k * plane_stride + Ng * stride + i;
@@ -240,8 +250,8 @@
         int ij = remaining;
         int i = ij % (Nx + 2*Ng);
         int j = ij / (Nx + 2*Ng);
-        // Skip cells owned by x-faces or y-faces
-        if (i == 0 || i == Nx + Ng || j == 0 || j == Ny + Ng) return;
+        // Skip cells owned by x-faces or y-faces (all ghost layers, robust for any Ng)
+        if (i < Ng || i >= Nx + Ng || j < Ng || j >= Ny + Ng) return;
         int idx = 0 * plane_stride + j * stride + i;
         int idx_int = Ng * plane_stride + j * stride + i;
         int idx_wrap = (Nz + Ng - 1) * plane_stride + j * stride + i;
@@ -262,8 +272,8 @@
         int ij = remaining;
         int i = ij % (Nx + 2*Ng);
         int j = ij / (Nx + 2*Ng);
-        // Skip cells owned by x-faces or y-faces
-        if (i == 0 || i == Nx + Ng || j == 0 || j == Ny + Ng) return;
+        // Skip cells owned by x-faces or y-faces (all ghost layers, robust for any Ng)
+        if (i < Ng || i >= Nx + Ng || j < Ng || j >= Ny + Ng) return;
         int idx = (Nz + Ng) * plane_stride + j * stride + i;
         int idx_int = (Nz + Ng - 1) * plane_stride + j * stride + i;
         int idx_wrap = Ng * plane_stride + j * stride + i;

Update the edge-skip checks in bc_3d_kernel to be robust for any number of ghost
cells (Ng >= 1), not just one. Change conditions like i == 0 to i < Ng to
correctly handle ownership of ghost layers and prevent race conditions.

src/mg_cuda_kernels.cpp [194-278]

 // Y-low face (skip x-edges to avoid race with x-faces)
 if (remaining < face_xz) {
     int ik = remaining;
     int i = ik % (Nx + 2*Ng);
     int k = ik / (Nx + 2*Ng);
-    // Skip cells owned by x-faces (i=0 or i=Nx+Ng)
-    if (i == 0 || i == Nx + Ng) return;
+    // Skip cells owned by x-faces (all x-ghost layers)
+    if (i < Ng || i >= Nx + Ng) return;
     int idx = k * plane_stride + 0 * stride + i;
     int idx_int = k * plane_stride + Ng * stride + i;
     int idx_wrap = k * plane_stride + (Ny + Ng - 1) * stride + i;
 ...
 // Y-high face (skip x-edges to avoid race with x-faces)
 if (remaining < face_xz) {
     int ik = remaining;
     int i = ik % (Nx + 2*Ng);
     int k = ik / (Nx + 2*Ng);
-    // Skip cells owned by x-faces (i=0 or i=Nx+Ng)
-    if (i == 0 || i == Nx + Ng) return;
+    // Skip cells owned by x-faces (all x-ghost layers)
+    if (i < Ng || i >= Nx + Ng) return;
     int idx = k * plane_stride + (Ny + Ng) * stride + i;
     int idx_int = k * plane_stride + (Ny + Ng - 1) * stride + i;
     int idx_wrap = k * plane_stride + Ng * stride + i;
 ...
 // Z-low face (skip x/y-edges to avoid races)
 if (remaining < face_xy) {
     int ij = remaining;
     int i = ij % (Nx + 2*Ng);
     int j = ij / (Nx + 2*Ng);
-    // Skip cells owned by x-faces or y-faces
-    if (i == 0 || i == Nx + Ng || j == 0 || j == Ny + Ng) return;
+    // Skip cells owned by x-faces or y-faces (all ghost layers)
+    if (i < Ng || i >= Nx + Ng || j < Ng || j >= Ny + Ng) return;
     int idx = 0 * plane_stride + j * stride + i;
     int idx_int = Ng * plane_stride + j * stride + i;
     int idx_wrap = (Nz + Ng - 1) * plane_stride + j * stride + i;
 ...
 // Z-high face (skip x/y-edges to avoid races)
 if (remaining < face_xy) {
     int ij = remaining;
     int i = ij % (Nx + 2*Ng);
     int j = ij / (Nx + 2*Ng);
-    // Skip cells owned by x-faces or y-faces
-    if (i == 0 || i == Nx + Ng || j == 0 || j == Ny + Ng) return;
+    // Skip cells owned by x-faces or y-faces (all ghost layers)
+    if (i < Ng || i >= Nx + Ng || j < Ng || j >= Ny + Ng) return;
     int idx = (Nz + Ng) * plane_stride + j * stride + i;
     int idx_int = (Nz + Ng - 1) * plane_stride + j * stride + i;
     int idx_wrap = Ng * plane_stride + j * stride + i;

[To ensure code accuracy, apply this suggestion manually]

Suggestion importance[1-10]: 8


Why: The suggestion correctly identifies a bug in bc_3d_kernel where edge-skip logic for avoiding race conditions is only correct for Ng=1, making the kernel fail for multiple ghost layers.

Medium
Verify device mapping before conversion

Add a check with omp_target_is_present before calling omp_get_mapped_ptr to
ensure the pointer is actually mapped to the device, improving robustness.

include/gpu_utils.hpp [248-255]

 template<typename T>
 inline T* get_device_ptr(T* host_ptr) {
     if (host_ptr == nullptr) return nullptr;
-    int device = omp_get_default_device();
+
+    const int device = omp_get_default_device();
+    if (!omp_target_is_present(host_ptr, device)) return nullptr;
+
     void* dev_ptr = omp_get_mapped_ptr(host_ptr, device);
-    // omp_get_mapped_ptr returns nullptr if pointer is not mapped
     return static_cast<T*>(dev_ptr);
 }
Suggestion importance[1-10]: 7


Why: The suggestion correctly identifies a subtle robustness issue with omp_get_mapped_ptr and proposes a valid safeguard using omp_target_is_present, which prevents potential crashes from passing a host pointer to a device kernel.

Medium
Avoid null graph usage by default

Initialize use_vcycle_graph_ to false by default and only set it to true after
the vcycle_graph_ object is successfully created to prevent potential null
pointer dereferences.

include/poisson_solver_multigrid.hpp [145-149]

 std::unique_ptr<mg_cuda::CudaVCycleGraph> vcycle_graph_;  // Full V-cycle graph
-bool use_vcycle_graph_ = true;       // Full V-cycle graph (DEFAULT ON, disable via config)
+bool use_vcycle_graph_ = false;      // Enable only after successful initialization
 int vcycle_graph_nu1_ = 2;           // Pre-smoothing iterations for graphed V-cycle
 int vcycle_graph_nu2_ = 2;           // Post-smoothing iterations for graphed V-cycle
 int vcycle_graph_degree_ = 4;        // Chebyshev degree for graphed V-cycle
Suggestion importance[1-10]: 7


Why: This is a good defensive programming suggestion that prevents a potential null pointer dereference by ensuring the use_vcycle_graph_ flag is only enabled after the graph object is successfully initialized.

Medium
Possible issue
Prevent boundary indexing overflow

In bc_3d_kernel, use 64-bit size_t for calculating total_boundary and for tid to
prevent potential integer overflow on large grids.

src/mg_cuda_kernels.cpp [129-146]

 __global__ void bc_3d_kernel(
     double* __restrict__ u,
     int Nx, int Ny, int Nz, int Ng,
     int bc_x_lo, int bc_x_hi,
     int bc_y_lo, int bc_y_hi,
     int bc_z_lo, int bc_z_hi,
     double dirichlet_val)
 {
     int stride = Nx + 2 * Ng;
     int plane_stride = stride * (Ny + 2 * Ng);
 
     // Thread covers all boundary cells
-    int tid = blockIdx.x * blockDim.x + threadIdx.x;
-    int total_boundary = (Ny + 2*Ng) * (Nz + 2*Ng) * 2 +  // x faces
-                         (Nx + 2*Ng) * (Nz + 2*Ng) * 2 +  // y faces
-                         (Nx + 2*Ng) * (Ny + 2*Ng) * 2;   // z faces
+    size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+    size_t total_boundary =
+        static_cast<size_t>(Ny + 2*Ng) * static_cast<size_t>(Nz + 2*Ng) * 2 +  // x faces
+        static_cast<size_t>(Nx + 2*Ng) * static_cast<size_t>(Nz + 2*Ng) * 2 +  // y faces
+        static_cast<size_t>(Nx + 2*Ng) * static_cast<size_t>(Ny + 2*Ng) * 2;   // z faces
 
     if (tid >= total_boundary) return;
Suggestion importance[1-10]: 8

__

Why: The suggestion correctly identifies a potential integer overflow bug in calculating total_boundary for large grids, which could lead to incorrect behavior or memory corruption.

Medium
Compute diagnostics on device

Fix the divergence diagnostic by computing max_div on the device using an OpenMP
reduction to avoid reading stale host data in GPU builds.

src/solver.cpp [3420-3448]

 {
     static bool div_diagnostics = (std::getenv("NNCFD_POISSON_DIAGNOSTICS") != nullptr);
     static int div_diagnostics_interval = []() {
         const char* env = std::getenv("NNCFD_POISSON_DIAGNOSTICS_INTERVAL");
         int v = env ? std::atoi(env) : 1;
         return (v > 0) ? v : 1;
     }();
     if (div_diagnostics && (iter_ % div_diagnostics_interval == 0)) {
         compute_divergence(VelocityWhich::Current, div_velocity_);  // Divergence of corrected velocity
         double max_div = 0.0;
+
+#ifdef USE_GPU_OFFLOAD
+        double* div_ptr = div_velocity_ptr_;
+        const size_t n = field_total_size_;
+        #pragma omp target teams distribute parallel for \
+            map(present: div_ptr[0:n]) reduction(max:max_div)
+        for (size_t idx = 0; idx < n; ++idx) {
+            double v = div_ptr[idx];
+            double av = (v >= 0.0) ? v : -v;
+            if (av > max_div) max_div = av;
+        }
+#else
         if (mesh_->is2D()) {
             for (int j = mesh_->j_begin(); j < mesh_->j_end(); ++j) {
                 for (int i = mesh_->i_begin(); i < mesh_->i_end(); ++i) {
                     max_div = std::max(max_div, std::abs(div_velocity_(i, j)));
                 }
             }
         } else {
             for (int k = mesh_->k_begin(); k < mesh_->k_end(); ++k) {
                 for (int j = mesh_->j_begin(); j < mesh_->j_end(); ++j) {
                     for (int i = mesh_->i_begin(); i < mesh_->i_end(); ++i) {
                         max_div = std::max(max_div, std::abs(div_velocity_(i, j, k)));
                     }
                 }
             }
         }
+#endif
+
         std::cout << "[Projection] max|div(u)|=" << std::scientific << std::setprecision(6)
                   << max_div << " dt*max|div|=" << current_dt_ * max_div << "\n";
     }
 }
Suggestion importance[1-10]: 7

__

Why: The suggestion correctly identifies a bug in new diagnostic code where device data is not synced to the host before being read, leading to incorrect output.

Medium
Check kernel launch failures early

Add CUDA_CHECK(cudaPeekAtLastError()) after kernel launches to detect
asynchronous errors early, which is especially useful during CUDA graph capture.

src/mg_cuda_kernels.cpp [459-461]

 chebyshev_3d_kernel<<<grid, block, 0, stream>>>(
     u, f, tmp, Nx, Ny, Nz, Ng,
     inv_dx2, inv_dy2, inv_dz2, inv_coeff, omega);
+CUDA_CHECK(cudaPeekAtLastError());
Suggestion importance[1-10]: 6

__

Why: The suggestion correctly points out a good practice for robust CUDA programming by adding asynchronous error checks after kernel launches, which significantly improves debuggability.

Low
General
Normalize L2 norms by size

Normalize the L2 norm calculation by the number of cells to make absolute
tolerances grid-independent and more robust.

src/poisson_solver_multigrid.cpp [1029-1030]

+const double n_cells = static_cast<double>(Nx) * static_cast<double>(Ny) * static_cast<double>(is_2d ? 1 : Nz);
+
 r_inf = max_res;
-r_l2 = std::sqrt(sum_sq);
+r_l2 = std::sqrt(sum_sq / std::max(1.0, n_cells));
Suggestion importance[1-10]: 6

__

Why: The suggestion correctly points out that using a non-normalized L2 norm makes absolute tolerances grid-dependent, and proposes using an RMS norm for better robustness.

Low

Previous suggestions

✅ Suggestions up to commit ae03607
Category | Suggestion | Impact
Possible issue
Eliminate boundary write races
Suggestion Impact: The commit adds early-return guards in the Y-low/Y-high faces to skip x-edge cells, and in the Z-low/Z-high faces to skip both x- and y-edge cells, matching the proposed write-precedence scheme to avoid double-writes. It also adjusts periodic wrap indices for x-, y-, and z-faces to use (N + Ng - 1) offsets as in the suggestion. (Other changes like Chebyshev constants are unrelated.)

code diff:

@@ -158,7 +168,7 @@
         int k = jk / (Ny + 2*Ng);
         int idx = k * plane_stride + j * stride + 0;
         int idx_int = k * plane_stride + j * stride + Ng;
-        int idx_wrap = k * plane_stride + j * stride + Nx;
+        int idx_wrap = k * plane_stride + j * stride + (Nx + Ng - 1);
 
         if (bc_x_lo == 2) { // Periodic
             u[idx] = u[idx_wrap];
@@ -191,14 +201,16 @@
     }
     remaining -= face_yz;
 
-    // Y-low face
+    // Y-low face (skip x-edges to avoid race with x-faces)
     if (remaining < face_xz) {
         int ik = remaining;
         int i = ik % (Nx + 2*Ng);
         int k = ik / (Nx + 2*Ng);
+        // Skip cells owned by x-faces (i=0 or i=Nx+Ng)
+        if (i == 0 || i == Nx + Ng) return;
         int idx = k * plane_stride + 0 * stride + i;
         int idx_int = k * plane_stride + Ng * stride + i;
-        int idx_wrap = k * plane_stride + Ny * stride + i;
+        int idx_wrap = k * plane_stride + (Ny + Ng - 1) * stride + i;
 
         if (bc_y_lo == 2) { // Periodic
             u[idx] = u[idx_wrap];
@@ -211,11 +223,13 @@
     }
     remaining -= face_xz;
 
-    // Y-high face
+    // Y-high face (skip x-edges to avoid race with x-faces)
     if (remaining < face_xz) {
         int ik = remaining;
         int i = ik % (Nx + 2*Ng);
         int k = ik / (Nx + 2*Ng);
+        // Skip cells owned by x-faces (i=0 or i=Nx+Ng)
+        if (i == 0 || i == Nx + Ng) return;
         int idx = k * plane_stride + (Ny + Ng) * stride + i;
         int idx_int = k * plane_stride + (Ny + Ng - 1) * stride + i;
         int idx_wrap = k * plane_stride + Ng * stride + i;
@@ -231,14 +245,16 @@
     }
     remaining -= face_xz;
 
-    // Z-low face
+    // Z-low face (skip x/y-edges to avoid races)
     if (remaining < face_xy) {
         int ij = remaining;
         int i = ij % (Nx + 2*Ng);
         int j = ij / (Nx + 2*Ng);
+        // Skip cells owned by x-faces or y-faces
+        if (i == 0 || i == Nx + Ng || j == 0 || j == Ny + Ng) return;
         int idx = 0 * plane_stride + j * stride + i;
         int idx_int = Ng * plane_stride + j * stride + i;
-        int idx_wrap = Nz * plane_stride + j * stride + i;
+        int idx_wrap = (Nz + Ng - 1) * plane_stride + j * stride + i;
 
         if (bc_z_lo == 2) { // Periodic
             u[idx] = u[idx_wrap];
@@ -251,11 +267,13 @@
     }
     remaining -= face_xy;
 
-    // Z-high face
+    // Z-high face (skip x/y-edges to avoid races)
     if (remaining < face_xy) {
         int ij = remaining;
         int i = ij % (Nx + 2*Ng);
         int j = ij / (Nx + 2*Ng);
+        // Skip cells owned by x-faces or y-faces
+        if (i == 0 || i == Nx + Ng || j == 0 || j == Ny + Ng) return;
         int idx = (Nz + Ng) * plane_stride + j * stride + i;
         int idx_int = (Nz + Ng - 1) * plane_stride + j * stride + i;
         int idx_wrap = Ng * plane_stride + j * stride + i;
@@ -603,11 +621,9 @@

Prevent data races in bc_3d_kernel by ensuring edge and corner ghost cells are
written to by only one thread. This is achieved by establishing a precedence
where x-faces handle their edges, and y-faces handle their remaining edges,
preventing concurrent writes.

src/mg_cuda_kernels.cpp [194-251]

 // Y-low face
 if (remaining < face_xz) {
     int ik = remaining;
     int i = ik % (Nx + 2*Ng);
     int k = ik / (Nx + 2*Ng);
+
+    // Avoid double-writes: x-faces own x-edges/corners
+    if (i == 0 || i == (Nx + Ng)) return;
+
     int idx = k * plane_stride + 0 * stride + i;
     int idx_int = k * plane_stride + Ng * stride + i;
-    int idx_wrap = k * plane_stride + Ny * stride + i;
+    int idx_wrap = k * plane_stride + (Ny + Ng - 1) * stride + i;
 
     if (bc_y_lo == 2) { // Periodic
         u[idx] = u[idx_wrap];
     } else if (bc_y_lo == 1) { // Neumann
         u[idx] = u[idx_int];
     } else { // Dirichlet
         u[idx] = 2.0 * dirichlet_val - u[idx_int];
     }
     return;
 }
 remaining -= face_xz;
 
 // Z-low face
 if (remaining < face_xy) {
     int ij = remaining;
     int i = ij % (Nx + 2*Ng);
     int j = ij / (Nx + 2*Ng);
+
+    // Avoid double-writes: x-faces own x-edges/corners; y-faces own y-edges/corners
+    if (i == 0 || i == (Nx + Ng) || j == 0 || j == (Ny + Ng)) return;
+
     int idx = 0 * plane_stride + j * stride + i;
     int idx_int = Ng * plane_stride + j * stride + i;
-    int idx_wrap = Nz * plane_stride + j * stride + i;
+    int idx_wrap = (Nz + Ng - 1) * plane_stride + j * stride + i;
 
     if (bc_z_lo == 2) { // Periodic
         u[idx] = u[idx_wrap];
     } else if (bc_z_lo == 1) { // Neumann
         u[idx] = u[idx_int];
     } else { // Dirichlet
         u[idx] = 2.0 * dirichlet_val - u[idx_int];
     }
     return;
 }
Suggestion importance[1-10]: 9


Why: This suggestion correctly identifies a critical data race condition in the bc_3d_kernel where threads assigned to different faces can write to the same edge or corner ghost cells simultaneously. The proposed fix correctly establishes a write precedence, ensuring each boundary cell is updated by exactly one thread, which is crucial for correctness and determinism.

High
Avoid unintended device reallocation

Replace map(present, alloc: ...) with map(present: ...) for solver-managed
device buffers to enforce the contract that data must already be mapped and
avoid potential data corruption. This change should be applied to all similar
instances in the PR.

src/poisson_solver_fft.cpp [609-610]

 #pragma omp target teams distribute parallel for collapse(3) reduction(+:sum) \
-    map(present, alloc: rhs_ptr[0:total_size]) is_device_ptr(packed)
+    map(present: rhs_ptr[0:total_size]) is_device_ptr(packed)
Suggestion importance[1-10]: 9


Why: This suggestion correctly identifies a critical and subtle correctness issue introduced by systematically replacing map(present: ...) with map(present, alloc: ...), which could lead to silent data inconsistencies on the GPU.

High
Reset the correct device buffer

In the benchmark, use map(present: p_ptr[...]) instead of map(present, alloc:
p_ptr[...]) in the reset kernels. This ensures the zeroing operation targets the
correct, already-mapped device buffer.

tests/bench_mg_cuda_graphs.cpp [71-80]

 #pragma omp target enter data map(to: rhs_ptr[0:total_size], p_ptr[0:total_size])
 
 // Warmup (2 solves)
 for (int w = 0; w < 2; ++w) {
-    #pragma omp target teams distribute parallel for map(present, alloc: p_ptr[0:total_size])
+    #pragma omp target teams distribute parallel for map(present: p_ptr[0:total_size])
     for (size_t idx = 0; idx < total_size; ++idx) {
         p_ptr[idx] = 0.0;
     }
     mg.solve_device(rhs_ptr, p_ptr, cfg);
 }
Suggestion importance[1-10]: 9


Why: This suggestion correctly identifies a critical bug in the new benchmark code where using map(present, alloc: ...) could cause the benchmark to run on uninitialized data, invalidating its results.

High
Fix periodic wrap indexing
Suggestion Impact: Updated the periodic wrap index on the X-low face from Nx to (Nx + Ng - 1), matching the suggested fix for Ng>1. The commit also applied the same Ng-aware wrap correction to Y-low and Z-low periodic faces (Ny/Nz -> Ny+Ng-1 / Nz+Ng-1) and added edge-skipping logic to avoid race conditions, extending beyond the original suggestion.

code diff:

@@ -158,7 +168,7 @@
         int k = jk / (Ny + 2*Ng);
         int idx = k * plane_stride + j * stride + 0;
         int idx_int = k * plane_stride + j * stride + Ng;
-        int idx_wrap = k * plane_stride + j * stride + Nx;
+        int idx_wrap = k * plane_stride + j * stride + (Nx + Ng - 1);
 
         if (bc_x_lo == 2) { // Periodic
             u[idx] = u[idx_wrap];
@@ -191,14 +201,16 @@
     }
     remaining -= face_yz;
 
-    // Y-low face
+    // Y-low face (skip x-edges to avoid race with x-faces)
     if (remaining < face_xz) {
         int ik = remaining;
         int i = ik % (Nx + 2*Ng);
         int k = ik / (Nx + 2*Ng);
+        // Skip cells owned by x-faces (i=0 or i=Nx+Ng)
+        if (i == 0 || i == Nx + Ng) return;
         int idx = k * plane_stride + 0 * stride + i;
         int idx_int = k * plane_stride + Ng * stride + i;
-        int idx_wrap = k * plane_stride + Ny * stride + i;
+        int idx_wrap = k * plane_stride + (Ny + Ng - 1) * stride + i;
 
         if (bc_y_lo == 2) { // Periodic
             u[idx] = u[idx_wrap];
@@ -211,11 +223,13 @@
     }
     remaining -= face_xz;
 
-    // Y-high face
+    // Y-high face (skip x-edges to avoid race with x-faces)
     if (remaining < face_xz) {
         int ik = remaining;
         int i = ik % (Nx + 2*Ng);
         int k = ik / (Nx + 2*Ng);
+        // Skip cells owned by x-faces (i=0 or i=Nx+Ng)
+        if (i == 0 || i == Nx + Ng) return;
         int idx = k * plane_stride + (Ny + Ng) * stride + i;
         int idx_int = k * plane_stride + (Ny + Ng - 1) * stride + i;
         int idx_wrap = k * plane_stride + Ng * stride + i;
@@ -231,14 +245,16 @@
     }
     remaining -= face_xz;
 
-    // Z-low face
+    // Z-low face (skip x/y-edges to avoid races)
     if (remaining < face_xy) {
         int ij = remaining;
         int i = ij % (Nx + 2*Ng);
         int j = ij / (Nx + 2*Ng);
+        // Skip cells owned by x-faces or y-faces
+        if (i == 0 || i == Nx + Ng || j == 0 || j == Ny + Ng) return;
         int idx = 0 * plane_stride + j * stride + i;
         int idx_int = Ng * plane_stride + j * stride + i;
-        int idx_wrap = Nz * plane_stride + j * stride + i;
+        int idx_wrap = (Nz + Ng - 1) * plane_stride + j * stride + i;
 

Correct the periodic boundary condition indexing in bc_3d_kernel to support more
than one ghost cell (Ng > 1). The current implementation assumes Ng=1, which
could lead to incorrect behavior if this changes.

src/mg_cuda_kernels.cpp [154-171]

 // X-low face
 if (remaining < face_yz) {
     int jk = remaining;
     int j = jk % (Ny + 2*Ng);
     int k = jk / (Ny + 2*Ng);
     int idx = k * plane_stride + j * stride + 0;
     int idx_int = k * plane_stride + j * stride + Ng;
-    int idx_wrap = k * plane_stride + j * stride + Nx;
+    int idx_wrap = k * plane_stride + j * stride + (Nx + Ng - 1);
 
     if (bc_x_lo == 2) { // Periodic
         u[idx] = u[idx_wrap];
     } else if (bc_x_lo == 1) { // Neumann
         u[idx] = u[idx_int];
     } else { // Dirichlet
         u[idx] = 2.0 * dirichlet_val - u[idx_int];
     }
     return;
 }
Suggestion importance[1-10]: 8


Why: The suggestion correctly identifies a bug in the periodic boundary condition logic for bc_3d_kernel that would cause incorrect behavior if Ng > 1. While the current codebase uses Ng=1, this change makes the kernel more robust and prevents future errors.

Medium
Handle unmapped pointer sentinel
Suggestion Impact: The commit modified get_device_ptr to be more defensive (returns nullptr when host_ptr is nullptr) and clarified behavior for unmapped pointers, but it did not add the suggested omp_map_fail check; it instead assumes unmapped pointers yield nullptr.

code diff:

-/// Returns nullptr if pointer is not mapped
+/// Returns nullptr if pointer is not mapped or host_ptr is nullptr
 template<typename T>
 inline T* get_device_ptr(T* host_ptr) {
+    if (host_ptr == nullptr) return nullptr;
     int device = omp_get_default_device();
     void* dev_ptr = omp_get_mapped_ptr(host_ptr, device);
+    // omp_get_mapped_ptr returns nullptr if pointer is not mapped
     return static_cast<T*>(dev_ptr);
 }

In get_device_ptr, check if omp_get_mapped_ptr returns omp_map_fail. If it does,
return nullptr to prevent using an invalid device pointer.

include/gpu_utils.hpp [248-253]

 template<typename T>
 inline T* get_device_ptr(T* host_ptr) {
     int device = omp_get_default_device();
     void* dev_ptr = omp_get_mapped_ptr(host_ptr, device);
+    if (dev_ptr == omp_map_fail) return nullptr;
     return static_cast<T*>(dev_ptr);
 }
Suggestion importance[1-10]: 8


Why: The suggestion correctly points out that not checking for omp_map_fail can lead to hard-to-debug errors from invalid device pointers, and the fix makes the function more robust and aligned with its documentation.

Medium
Prevent reuse of destroyed graph
Suggestion Impact: After calling vcycle_graph_->destroy() when boundary conditions change, the commit now also calls vcycle_graph_.reset(), ensuring the destroyed graph cannot be accidentally reused and will be recaptured later.

code diff:

 #ifdef USE_GPU_OFFLOAD
-    // Re-initialize CUDA Graphs with new BCs
-    if (use_cuda_graphs_) {
-        initialize_cuda_graphs();
-    }
     // Invalidate V-cycle graph so it gets recaptured with new BCs
     if (vcycle_graph_) {
         vcycle_graph_->destroy();
+        vcycle_graph_.reset();
     }
 #endif
 }
@@ -112,13 +109,10 @@
     bc_z_hi_ = z_hi;
 
 #ifdef USE_GPU_OFFLOAD
-    // Re-initialize CUDA Graphs with new BCs
-    if (use_cuda_graphs_) {
-        initialize_cuda_graphs();
-    }
     // Invalidate V-cycle graph so it gets recaptured with new BCs
     if (vcycle_graph_) {
         vcycle_graph_->destroy();
+        vcycle_graph_.reset();
     }
 #endif
 }

Reset the vcycle_graph_ smart pointer to nullptr after calling destroy() to
prevent accidental reuse of the destroyed graph object.

src/poisson_solver_multigrid.cpp [92-101]

 #ifdef USE_GPU_OFFLOAD
     // Re-initialize CUDA Graphs with new BCs
     if (use_cuda_graphs_) {
         initialize_cuda_graphs();
     }
     // Invalidate V-cycle graph so it gets recaptured with new BCs
     if (vcycle_graph_) {
         vcycle_graph_->destroy();
+        vcycle_graph_.reset();
     }
 #endif
Suggestion importance[1-10]: 7


Why: This is a good defensive programming suggestion that improves robustness by resetting the vcycle_graph_ smart pointer after destruction, preventing potential use-after-free bugs.

Medium
Remove non-portable pi constant

Replace the non-standard M_PI macro with a constexpr double constant to improve
code portability across different compilers and platforms.

src/mg_cuda_kernels.cpp [621-624]

+constexpr double kPi = 3.141592653589793238462643383279502884;
 for (int k = 0; k < degree_; ++k) {
     // Chebyshev-optimal weight
-    double theta = M_PI * (2.0 * k + 1.0) / (2.0 * degree_);
+    double theta = kPi * (2.0 * k + 1.0) / (2.0 * degree_);
     double omega = 1.0 / (d - c * std::cos(theta));
Suggestion importance[1-10]: 4


Why: The suggestion correctly points out that M_PI is non-standard and could cause portability issues. Replacing it with a constexpr constant is a good practice for improving code portability and adhering to the C++ standard.

Low
Incremental [*]
Prevent loop index overflow

Change the loop index variable from int to size_t and remove the cast from
size_c to prevent potential overflow when zeroing the coarse grid solution on
the GPU.

src/poisson_solver_multigrid.cpp [1458-1462]

 #pragma omp target teams distribute parallel for \
     map(present: u_coarse[0:size_c])
-for (int idx = 0; idx < (int)size_c; ++idx) {
+for (size_t idx = 0; idx < size_c; ++idx) {
     u_coarse[idx] = 0.0;
 }
Suggestion importance[1-10]: 7


Why: This is a valid correctness fix that prevents a potential integer overflow on large problem sizes, which could lead to incorrect behavior of the multigrid solver.

Medium
Make reductions GPU-portable

Replace std::max and std::abs with manual arithmetic (ternary operator for abs,
if statement for max) inside the OpenMP target region to ensure GPU portability.

src/poisson_solver_multigrid.cpp [2059-2068]

 #pragma omp target teams distribute parallel for collapse(2) \
     map(present: f_dev[0:total_size]) reduction(max: b_inf_local) reduction(+: b_sum_sq)
 for (int j = Ng; j < Ny_g + Ng; ++j) {
     for (int i = Ng; i < Nx_g + Ng; ++i) {
         int idx = j * stride_gpu + i;
         double val = f_dev[idx];
-        b_inf_local = std::max(b_inf_local, std::abs(val));
+        double abs_val = (val >= 0.0) ? val : -val;
+        if (abs_val > b_inf_local) b_inf_local = abs_val;
         b_sum_sq += val * val;
     }
 }
Suggestion importance[1-10]: 6


Why: This suggestion improves portability and robustness by replacing std::max and std::abs with device-friendly arithmetic, which is good practice for GPU code.

Low
Suggestions up to commit 55bba54
Category | Suggestion | Impact
Possible issue
Prevent data race on static warning flag

Replace the static bool warned flag with std::call_once and std::once_flag (both from
<mutex>, which must be included) to prevent a data race when vcycle_graphed() is called
concurrently from multiple threads.

src/poisson_solver_multigrid.cpp [2422-2438]

 cudaStream_t omp_stream = reinterpret_cast<cudaStream_t>(
     ompx_get_cuda_stream(omp_get_default_device(), /*sync=*/0));
 
 // Runtime stream validation with graceful fallback
 // A null stream would launch on CUDA default stream, causing potential race conditions
 // with OpenMP target regions. Instead of crashing in production, fall back to the
 // non-graphed path which is slower but correct.
 if (omp_stream == nullptr) {
-    static bool warned = false;
-    if (!warned) {
+    static std::once_flag warned_flag;
+    std::call_once(warned_flag, [](){
         std::cerr << "[MG] WARNING: OpenMP CUDA stream is null - falling back to non-graphed V-cycle\n"
                   << "    This may indicate a runtime issue. Performance will be degraded.\n";
-        warned = true;
-    }
+    });
     vcycle(0, vcycle_graph_nu1_, vcycle_graph_nu2_, vcycle_graph_degree_);
     return;
 }
Suggestion importance[1-10]: 9


Why: The suggestion correctly identifies a data race on the static warned flag in a multi-threaded context and proposes a robust, thread-safe solution using std::call_once, which is a critical correctness fix.

High
Fix performance bug in diagnostics

To fix a performance issue, move the post-projection divergence check from the
CPU to the GPU using an OpenMP target directive with a reduction(max:...)
clause.

src/solver.cpp [3427-3448]

 if (div_diagnostics && (iter_ % div_diagnostics_interval == 0)) {
     compute_divergence(VelocityWhich::Current, div_velocity_);  // Divergence of corrected velocity
     double max_div = 0.0;
-    if (mesh_->is2D()) {
-        for (int j = mesh_->j_begin(); j < mesh_->j_end(); ++j) {
-            for (int i = mesh_->i_begin(); i < mesh_->i_end(); ++i) {
-                max_div = std::max(max_div, std::abs(div_velocity_(i, j)));
-            }
-        }
-    } else {
-        for (int k = mesh_->k_begin(); k < mesh_->k_end(); ++k) {
-            for (int j = mesh_->j_begin(); j < mesh_->j_end(); ++j) {
-                for (int i = mesh_->i_begin(); i < mesh_->i_end(); ++i) {
-                    max_div = std::max(max_div, std::abs(div_velocity_(i, j, k)));
-                }
+    const double* div_ptr = div_velocity_.data().data();
+    const int i_begin = mesh_->i_begin(), i_end = mesh_->i_end();
+    const int j_begin = mesh_->j_begin(), j_end = mesh_->j_end();
+    const int k_begin = mesh_->k_begin(), k_end = mesh_->k_end();
+    const int stride = div_velocity_.stride();
+    const int plane_stride = div_velocity_.plane_stride();
+
+    #pragma omp target teams distribute parallel for reduction(max:max_div) collapse(3) map(tofrom:max_div)
+    for (int k = k_begin; k < k_end; ++k) {
+        for (int j = j_begin; j < j_end; ++j) {
+            for (int i = i_begin; i < i_end; ++i) {
+                const double val = div_ptr[k * plane_stride + j * stride + i];
+                max_div = std::max(max_div, std::abs(val));
             }
         }
     }
     std::cout << "[Projection] max|div(u)|=" << std::scientific << std::setprecision(6)
               << max_div << " dt*max|div|=" << current_dt_ * max_div << "\n";
 }
Suggestion importance[1-10]: 9


Why: The suggestion correctly identifies a major performance bottleneck where host-side loops access device data element-by-element, and proposes an efficient GPU-side reduction which is critical for performance when diagnostics are enabled.

High
Guard NVHPC stream API

Wrap the call to the NVHPC-specific function ompx_get_cuda_stream in an
__NVCOMPILER preprocessor guard to ensure portability and prevent compilation
errors with other compilers.

src/poisson_solver_multigrid.cpp [2422-2423]

-    cudaStream_t omp_stream = reinterpret_cast<cudaStream_t>(
-        ompx_get_cuda_stream(omp_get_default_device(), /*sync=*/0));
+  #ifdef __NVCOMPILER
+      cudaStream_t omp_stream = static_cast<cudaStream_t>(
+          ompx_get_cuda_stream(omp_get_default_device(), /*sync=*/0));
+  #else
+      cudaStream_t omp_stream = nullptr;
+  #endif
Suggestion importance[1-10]: 8


Why: This suggestion correctly identifies a portability issue where the code would fail to compile with non-NVHPC compilers and provides a standard solution using preprocessor guards, which is important for code robustness.

Medium
Fix flawed logic in parity test

Fix the flawed logic in test_convergence_parity by adding a mechanism to
explicitly disable CUDA graph usage for one of the solver instances, ensuring
the test correctly compares the graphed and non-graphed code paths.

tests/test_vcycle_graph_stress.cpp [150-238]

 bool test_convergence_parity() {
     ...
-    // Test with graphed path (MG_USE_VCYCLE_GRAPH=1)
+    // Test with graphed path (default)
     MultigridPoissonSolver mg_graphed(mesh);
     mg_graphed.set_bc(PoissonBC::Periodic, PoissonBC::Periodic,
                        PoissonBC::Periodic, PoissonBC::Periodic,
                        PoissonBC::Periodic, PoissonBC::Periodic);
 
-    // Test with non-graphed path
+    // Test with non-graphed path (explicitly disabled)
     MultigridPoissonSolver mg_nongraphed(mesh);
     mg_nongraphed.set_bc(PoissonBC::Periodic, PoissonBC::Periodic,
                           PoissonBC::Periodic, PoissonBC::Periodic,
                           PoissonBC::Periodic, PoissonBC::Periodic);
+    mg_nongraphed.disable_vcycle_graph(true); // Force non-graphed path
 
     ...
     for (int cycles = 1; cycles <= 8; ++cycles) {
         ...
 #ifdef USE_GPU_OFFLOAD
         mg_graphed.solve_device(rhs_ptr, pg_ptr, cfg);
         mg_nongraphed.solve_device(rhs_ptr, png_ptr, cfg);
         #pragma omp target update from(pg_ptr[0:total_size], png_ptr[0:total_size])
 #else
         ...
 #endif
         ...
     }
     ...
 }
Suggestion importance[1-10]: 8


Why: The suggestion correctly identifies a fundamental flaw in the test logic, where both test subjects follow the same code path, rendering the comparison invalid. The proposed fix is necessary for the test to be meaningful.

Medium
High-level
Simplify solver logic by separating execution modes

Refactor the monolithic solve and solve_device methods by separating the
distinct solver modes (fixed-cycle, adaptive, convergence-based) and execution
paths (OpenMP vs. CUDA Graph) into smaller, dedicated functions to improve
maintainability.

Examples:

src/poisson_solver_multigrid.cpp [1670-1852]
    assert(gpu_ready_ && "GPU must be initialized");
    sync_level_to_gpu(0);
#endif

    apply_bc(0);

    // ========================================================================
    // Fixed-cycle mode: run exactly N V-cycles without convergence checks
    // This is the fastest mode for projection - no D→H transfers mid-solve
    // ========================================================================

 ... (clipped 173 lines)
src/poisson_solver_multigrid.cpp [1859-2149]
    
    // Device-resident solve using Model 1 (host pointer + present mapping)
    // Parameters are host pointers that caller has already mapped via `target enter data`.
    // We use map(present, alloc: ...) to access the device copies without additional transfers.
    
    auto& finest = *levels_[0];
    const int Nx = finest.Nx;
    const int Ny = finest.Ny;
    const int Nz = finest.Nz;
    // Total size includes ghost cells: (Nx+2)*(Ny+2)*(Nz+2) for 3D, or *3 for 2D (Nz=1)

 ... (clipped 281 lines)

Solution Walkthrough:

Before:

int MultigridPoissonSolver::solve_device(...) {
    // ... setup ...
    if (cfg.fixed_cycles > 0) {
        // Fixed-cycle mode logic
        const bool use_graph = use_vcycle_graph_ && cfg.use_vcycle_graph;
        auto run_cycles = [&](int n) {
            if (use_graph) {
                // ... graph initialization and execution ...
            } else {
                // ... standard v-cycle execution ...
            }
        };

        if (cfg.adaptive_cycles) {
            // Adaptive sub-mode logic with convergence checks
            run_cycles(...);
            // ... check residual and run more cycles ...
        } else {
            // Pure fixed mode logic
            run_cycles(max_cycles);
        }
        // ... finalization ...
    } else {
        // Convergence-based mode logic
        for (int cycle = 0; cycle < max_cycles; ++cycle) {
            vcycle(...);
            if (/* check interval */) {
                // ... check convergence ...
                if (converged) break;
            }
        }
        // ... finalization ...
    }
    return cycles_used;
}

After:

int MultigridPoissonSolver::solve_device(...) {
    // ... common setup ...
    if (cfg.fixed_cycles > 0) {
        if (cfg.adaptive_cycles) {
            return solve_device_fixed_adaptive(cfg);
        } else {
            return solve_device_fixed(cfg);
        }
    } else {
        return solve_device_converged(cfg);
    }
}

int MultigridPoissonSolver::solve_device_fixed(const PoissonConfig& cfg) {
    // Logic for pure fixed-cycle mode (with graph/no-graph paths)
    // ...
    return cycles_run;
}

int MultigridPoissonSolver::solve_device_fixed_adaptive(const PoissonConfig& cfg) {
    // Logic for adaptive fixed-cycle mode
    // ...
    return cycles_run;
}

int MultigridPoissonSolver::solve_device_converged(const PoissonConfig& cfg) {
    // Logic for convergence-based mode
    // ...
    return cycles_used;
}
Suggestion importance[1-10]: 8


Why: The suggestion correctly identifies that the solve and solve_device methods have become monolithic and complex by handling multiple execution modes, which significantly hurts maintainability and readability.

Medium
General
Synchronize specific stream instead of entire device

Replace the device-wide cudaDeviceSynchronize() with a more efficient
cudaStreamSynchronize() on the specific OpenMP stream to avoid unnecessary
performance stalls.

src/poisson_solver_multigrid.cpp [1937-1940]

 // CRITICAL: Sync all device work before OpenMP target reduction
 // Both CUDA graph and OpenMP target regions may use different streams.
-// DeviceSynchronize ensures all async GPU work completes before reduction.
-cudaDeviceSynchronize();
+// Synchronize the specific stream used for the graph to ensure its work is complete.
+cudaStream_t omp_stream = reinterpret_cast<cudaStream_t>(ompx_get_cuda_stream(omp_get_default_device(), 0));
+if (omp_stream) {
+    cudaStreamSynchronize(omp_stream);
+} else {
+    // Fallback for safety if the stream is null
+    cudaDeviceSynchronize();
+}
Suggestion importance[1-10]: 8


Why: This is a significant performance optimization that replaces a blocking device-wide synchronization with a more targeted stream-specific one, which is crucial in a high-performance GPU context.

Medium
Refactor boundary kernel for better performance

Refactor the bc_3d_kernel into three separate, simpler kernels for the X, Y, and
Z faces, each launched with a 2D grid. This will improve performance and
maintainability by simplifying logic and reducing thread divergence.

src/mg_cuda_kernels.cpp [129-272]

-__global__ void bc_3d_kernel(
-    double* __restrict__ u,
-    int Nx, int Ny, int Nz, int Ng,
-    int bc_x_lo, int bc_x_hi,
-    int bc_y_lo, int bc_y_hi,
-    int bc_z_lo, int bc_z_hi,
-    double dirichlet_val)
-{
+__global__ void bc_3d_x_faces_kernel(double* u, int Nx, int Ny, int Nz, int Ng, int bc_lo, int bc_hi, double val) {
+    int j = blockIdx.x * blockDim.x + threadIdx.x;
+    int k = blockIdx.y * blockDim.y + threadIdx.y;
+    if (j >= Ny + 2 * Ng || k >= Nz + 2 * Ng) return;
+
     int stride = Nx + 2 * Ng;
     int plane_stride = stride * (Ny + 2 * Ng);
 
-    // Thread covers all boundary cells
-    int tid = blockIdx.x * blockDim.x + threadIdx.x;
-    int total_boundary = (Ny + 2*Ng) * (Nz + 2*Ng) * 2 +  // x faces
-                         (Nx + 2*Ng) * (Nz + 2*Ng) * 2 +  // y faces
-                         (Nx + 2*Ng) * (Ny + 2*Ng) * 2;   // z faces
-
-    if (tid >= total_boundary) return;
-
-    // Decode which boundary face and which cell
-    int remaining = tid;
-    int face_yz = (Ny + 2*Ng) * (Nz + 2*Ng);
-    int face_xz = (Nx + 2*Ng) * (Nz + 2*Ng);
-    int face_xy = (Nx + 2*Ng) * (Ny + 2*Ng);
-
     // X-low face
-    if (remaining < face_yz) {
-        ...
-        return;
-    }
-    remaining -= face_yz;
+    int idx_lo = k * plane_stride + j * stride + 0;
+    int idx_int_lo = k * plane_stride + j * stride + Ng;
+    if (bc_lo == 2) u[idx_lo] = u[k * plane_stride + j * stride + Nx]; // Periodic
+    else if (bc_lo == 1) u[idx_lo] = u[idx_int_lo]; // Neumann
+    else u[idx_lo] = 2.0 * val - u[idx_int_lo]; // Dirichlet
 
     // X-high face
-    if (remaining < face_yz) {
-        ...
-        return;
-    }
-    remaining -= face_yz;
-
-    // Y-low face
-    ...
+    int idx_hi = k * plane_stride + j * stride + (Nx + Ng);
+    int idx_int_hi = k * plane_stride + j * stride + (Nx + Ng - 1);
+    if (bc_hi == 2) u[idx_hi] = u[k * plane_stride + j * stride + Ng]; // Periodic
+    else if (bc_hi == 1) u[idx_hi] = u[idx_int_hi]; // Neumann
+    else u[idx_hi] = 2.0 * val - u[idx_int_hi]; // Dirichlet
 }
 
+// Similar kernels for Y (bc_3d_y_faces_kernel) and Z (bc_3d_z_faces_kernel) faces would be implemented.
+// The launch function launch_bc_3d would be updated to launch these three kernels.
+
Suggestion importance[1-10]: 7


Why: The suggestion correctly identifies that the single "mega-kernel" for boundary conditions is suboptimal for performance and maintainability, and proposes a valid refactoring into three separate, more efficient kernels.

Medium
Use tolerance for floating comparisons

Use a tolerance for floating-point comparisons in
VCycleGraphFingerprint::operator== to prevent unnecessary CUDA graph recaptures
due to minor rounding differences in grid spacing values.

include/mg_cuda_kernels.hpp [251-264]

 bool operator==(const VCycleGraphFingerprint& other) const {
-    return num_levels == other.num_levels &&
-           level_sizes == other.level_sizes &&
-           level_coeffs == other.level_coeffs &&
-           level_dx == other.level_dx &&
-           level_dy == other.level_dy &&
-           level_dz == other.level_dz &&
-           degree == other.degree &&
+    if (num_levels != other.num_levels || level_sizes != other.level_sizes) return false;
+    auto close = [](double a, double b) { return std::abs(a - b) < 1e-12; };
+    for (size_t i = 0; i < level_coeffs.size(); ++i) {
+        if (!close(level_coeffs[i], other.level_coeffs[i])) return false;
+    }
+    for (size_t i = 0; i < level_dx.size(); ++i) {
+        if (!close(level_dx[i], other.level_dx[i]) ||
+            !close(level_dy[i], other.level_dy[i]) ||
+            !close(level_dz[i], other.level_dz[i])) {
+            return false;
+        }
+    }
+    return degree == other.degree &&
            nu1 == other.nu1 && nu2 == other.nu2 &&
            bc_x_lo == other.bc_x_lo && bc_x_hi == other.bc_x_hi &&
            bc_y_lo == other.bc_y_lo && bc_y_hi == other.bc_y_hi &&
            bc_z_lo == other.bc_z_lo && bc_z_hi == other.bc_z_hi &&
            coarse_iters == other.coarse_iters;
 }
Suggestion importance[1-10]: 6


Why: This suggestion correctly identifies that exact floating-point comparisons can be brittle and cause unnecessary CUDA graph recaptures, improving the robustness of the performance optimization.

Low
Use std::abs for clarity and performance

Replace the ternary operator (r >= 0.0) ? r : -r with std::abs(r) to improve
code clarity and leverage potential compiler optimizations for calculating the
absolute value.

src/poisson_solver_multigrid.cpp [1155-1156]

-double abs_r = (r >= 0.0) ? r : -r;
+double abs_r = std::abs(r);
 if (abs_r > max_res) max_res = abs_r;
Suggestion importance[1-10]: 3


Why: The suggestion correctly points out that std::abs() is more idiomatic and readable, which is a valid but minor improvement for code maintainability.

Low

Copilot AI left a comment

Pull request overview

This PR adds robust multigrid convergence with relative tolerance criteria to the NNCFD Poisson solver. It introduces multiple convergence modes (absolute, RHS-relative, initial-residual-relative), CUDA Graph support for V-cycle operations, and comprehensive benchmarking infrastructure.

Key changes:

  • Adds robust convergence criteria with three tolerance modes for the multigrid solver
  • Introduces fixed-cycle and adaptive-cycle modes for projection steps
  • Implements CUDA Graph acceleration for V-cycle operations to reduce kernel launch overhead
  • Updates OpenMP target map clauses from map(present:) to map(present, alloc:) for better GPU offload semantics
  • Adds comprehensive test suite and benchmarking tools for MG parameter tuning

Reviewed changes

Copilot reviewed 30 out of 30 changed files in this pull request and generated no comments.

Show a summary per file
| File | Description |
| --- | --- |
| tests/test_vcycle_graph_stress.cpp | New stress tests for V-cycle graph: BC alternation, convergence parity, anisotropic grids |
| tests/test_mg_physics_match.cpp | New test comparing fixed-cycle vs converged MG physics |
| tests/bench_mg_tuning.cpp | New benchmark for MG parameter tuning with quality metrics |
| tests/bench_mg_cuda_graphs.cpp | New benchmark for CUDA graph performance comparison |
| tests/bench_mg_bc_sweep.cpp | New benchmark for BC robustness testing |
| tests/bench_fft_vs_mg.cpp | New benchmark comparing FFT and MG solvers |
| tests/bench_256.cpp | Simple 256³ grid benchmark |
| tests/test_poisson_unified.cpp | Updated tolerance parameters for backward compatibility |
| src/poisson_solver_multigrid.cpp | Major refactoring with new convergence modes and CUDA graph support |
| src/solver.cpp | Integration updates and map clause changes throughout |
| src/mg_cuda_kernels.cpp | New CUDA kernel implementations for MG operations |
| include/poisson_solver.hpp | New configuration parameters for robust convergence |
| src/config.cpp | Configuration loading for new parameters |
| scripts/run_nsys_profiles.sh | New profiling script |
| scripts/ci.sh | CI integration for new tests |


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 19

Note

Due to the large number of review comments, Critical and Major severity comments were prioritized as inline comments.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (5)
src/solver.cpp (2)

4867-4925: Critical: duplicate OpenMP mappings for velocity_old_* (map(to) then map(alloc)).
You target enter data map(to: velocity_old_u_ptr_[...]) and later again map(alloc: velocity_old_u_ptr_[...]) (same for v/w). That’s a high risk for runtime mapping errors or undefined behavior. Pick one. Since you want device-authoritative old-velocity, keep only alloc + explicit init/copies.

Proposed fix (remove the earlier `map(to:)` for old-velocity)
-    #pragma omp target enter data map(to: velocity_old_u_ptr_[0:u_total_size])
-    #pragma omp target enter data map(to: velocity_old_v_ptr_[0:v_total_size])
@@
-    if (!mesh_->is2D()) {
-        const size_t w_total_size = velocity_.w_total_size();
-        #pragma omp target enter data map(to: velocity_old_w_ptr_[0:w_total_size])
-    }
+    // old velocity is device-authoritative: allocate on device only (no initial H->D copy)
@@
     #pragma omp target enter data map(alloc: velocity_old_u_ptr_[0:u_total_size])
     #pragma omp target enter data map(alloc: velocity_old_v_ptr_[0:v_total_size])
@@
     if (!mesh_->is2D()) {
         const size_t w_total_size = velocity_.w_total_size();
         #pragma omp target enter data map(alloc: velocity_old_w_ptr_[0:w_total_size])
     }

1651-1743: Use map(present: ...) for read-only inputs; data is guaranteed present via initialize_gpu_buffers(). The map(present, alloc: ...) pattern is redundant here—u_ptr and v_ptr are guaranteed to exist in the device data environment (mapped in initialize_gpu_buffers()), so the alloc fallback is unreachable and adds confusion. Use map(present: ...) alone for clarity. Additionally, per coding guidelines, prefer OmpDeviceBuffer wrapper for GPU buffer management instead of manual pointer + target enter/exit data pattern.

Also applies to: 1964-2044, 2089-2166

include/config.hpp (1)

83-149: Confirm behavior-changing defaults for operational stability.

The defaults poisson_fixed_cycles=8 and poisson_use_vcycle_graph=true switch MG solver behavior:

  • poisson_fixed_cycles: Changes from convergence-based (0) to fixed V-cycle mode (8), bypassing tolerance checks unless poisson_adaptive_cycles=true. Existing configs using convergence mode will silently flip behavior.
  • poisson_use_vcycle_graph: Enables GPU CUDA graph optimization by default, affecting only GPU targets.

Precedence is clear but asymmetric:

  • MG_USE_VCYCLE_GRAPH environment variable (lines 70–76, poisson_solver_multigrid.cpp) can disable VCYCLE_GRAPH via "0" or "false", but cannot override a disabled config value (AND logic at line 1903).
  • poisson_fixed_cycles has no environment variable override—only config file can change it.

Missing safeguards:

  • No deprecation warning or backward-compat guidance when these new defaults change existing run behavior.
  • Users must actively set poisson_fixed_cycles=0 in config files to restore convergence-based behavior.

Confirm this is intentional, document the required config file overrides for existing workflows, and consider adding a startup warning when poisson_fixed_cycles > 0 (fixed-cycle mode active).
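
A minimal sketch of the resulting precedence, assuming the env-var spelling and the AND combination described above (the function name is illustrative):

```cpp
// Sketch only: the env var can further disable, but never re-enable, a
// config-disabled graph path, because the two flags are ANDed together.
#include <cstdlib>
#include <cstring>

bool vcycle_graph_enabled(bool config_flag /* poisson_use_vcycle_graph */) {
    bool env_flag = true;  // unset env var leaves the config value in charge
    if (const char* s = std::getenv("MG_USE_VCYCLE_GRAPH")) {
        env_flag = !(std::strcmp(s, "0") == 0 || std::strcmp(s, "false") == 0);
    }
    return config_flag && env_flag;
}
```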

include/poisson_solver_multigrid.hpp (1)

48-61: Fix solve_device docs: size expression is wrong for 3D.
The comment says Nx+2 * Ny+2 but the implementation uses (Nx+2)*(Ny+2)*(Nz+2).

Proposed fix
@@
-    /// @param rhs_present Host pointer to RHS array (must be present-mapped, size = Nx+2 * Ny+2)
-    /// @param p_present Host pointer to solution array (must be present-mapped, size = Nx+2 * Ny+2)
+    /// @param rhs_present Host pointer to RHS array (must be present-mapped, size = (Nx+2)*(Ny+2)*(Nz+2))
+    /// @param p_present Host pointer to solution array (must be present-mapped, size = (Nx+2)*(Ny+2)*(Nz+2))
src/poisson_solver_multigrid.cpp (1)

2151-2230: GPU buffer management isn’t exception-safe + violates the “use OmpDeviceBuffer” guidance.
If any level allocation fails mid-loop, earlier target enter data mappings and omp_target_alloc buffers leak (constructor throws, destructor won’t run). This is a real robustness problem on large grids / memory pressure. Prefer RAII (OmpDeviceBuffer or small local guard objects) so partial initialization unwinds cleanly.
Based on learnings, use OmpDeviceBuffer wrapper for GPU buffer management instead of manual memory management.
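
As a sketch of the RAII direction (an OmpDeviceBuffer-style guard; the repo's actual wrapper may differ):

```cpp
// Sketch only: pairs omp_target_alloc/omp_target_free so partially
// constructed levels unwind cleanly if a later allocation throws.
#include <cstddef>
#include <new>
#include <omp.h>

class DeviceAlloc {
public:
    DeviceAlloc(std::size_t bytes, int device)
        : ptr_(omp_target_alloc(bytes, device)), device_(device) {
        if (!ptr_) throw std::bad_alloc();  // ctor throws -> nothing leaks
    }
    ~DeviceAlloc() { if (ptr_) omp_target_free(ptr_, device_); }
    DeviceAlloc(const DeviceAlloc&) = delete;
    DeviceAlloc& operator=(const DeviceAlloc&) = delete;
    DeviceAlloc(DeviceAlloc&& o) noexcept
        : ptr_(o.ptr_), device_(o.device_) { o.ptr_ = nullptr; }
    void* get() const { return ptr_; }
private:
    void* ptr_;
    int device_;
};
```

Held in a std::vector<DeviceAlloc> per level, an exception mid-loop destroys the already-built elements and frees their device memory automatically.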

🤖 Fix all issues with AI agents
In @app/profile_comprehensive.cpp:
- Around line 195-212: The timing code computes ms_per_step using (cfg.nsteps -
1) which will divide by zero when cfg.nsteps == 1; before computing elapsed_ms
and ms_per_step (around the high_resolution_clock block and variables
start/end/ms_per_step), add a guard that checks cfg.nsteps >= 2 and handle the
single-step case (e.g., skip per-step and throughput calculations or set
ms_per_step to elapsed_ms and throughput to 0 or print a suitable message) so
you never divide by (cfg.nsteps - 1).
- Around line 111-146: The code currently toggles MG graph mode via environment
variables in run_profile which conflicts with Config::poisson_use_vcycle_graph;
instead remove the setenv/unsetenv block and explicitly set
config.poisson_use_vcycle_graph = cfg.use_vcycle_graph so the solver respects
the Config flag (reference run_profile, Config and cfg.use_vcycle_graph, and
remove calls to setenv/unsetenv and MG_USE_VCYCLE_GRAPH).

In @include/gpu_utils.hpp:
- Around line 283-287: GPU_PARALLEL_FOR_ASYNC is only defined for GPU builds and
missing in the CPU fallback, causing compile errors; add a CPU-side definition
in the existing #else block that expands to a regular for loop with the same
signature (for (int var = start; var < end; ++var)) so code using the macro
compiles on CPU, preserve the comment about asynchronous behavior (and note that
gpu::sync() is a no-op or not required on CPU) and ensure the macro name and
parameter order exactly match the GPU version (a fallback sketch follows after this file's items).
- Around line 245-253: Document that get_device_ptr(T* host_ptr) requires OpenMP
5.1 (omp_get_mapped_ptr) and that host_ptr must be an address inside an
OpenMP-mapped data region (i.e., memory allocated/registered with OpenMP mapping
directives); clarify that the function returns nullptr when the host pointer is
not mapped and callers must check for nullptr before using the returned device
pointer (especially before passing to CUDA kernels). Mark the function
[[nodiscard]] to prevent ignoring the result, add a brief comment why
mapping->device pointer conversion is needed (e.g., for passing device pointers
to GPU kernels), and add guidance to CMake/README to enforce a minimum OpenMP
version (or document minimum compiler versions like GCC/libgomp and Intel oneAPI
that support omp_get_mapped_ptr).
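
For the GPU_PARALLEL_FOR_ASYNC item above, a minimal sketch of the CPU fallback (the macro name and (var, start, end) order come from the instruction; the GPU-side definition is elided):

```cpp
#ifdef USE_GPU_OFFLOAD
// ... existing GPU definition of GPU_PARALLEL_FOR_ASYNC ...
#else
// On GPU builds this macro runs asynchronously and requires gpu::sync();
// on CPU it degenerates to a plain loop and no sync is needed.
#define GPU_PARALLEL_FOR_ASYNC(var, start, end) \
    for (int var = (start); var < (end); ++var)
#endif
```

For the get_device_ptr item, a sketch of the documented contract (omp_get_mapped_ptr is OpenMP 5.1):

```cpp
#include <omp.h>

/// Translate a host pointer inside an OpenMP-mapped region to its device
/// address, e.g. for handing to a CUDA kernel. Returns nullptr when the
/// host pointer is not mapped; callers must check before use.
template <typename T>
[[nodiscard]] T* get_device_ptr(T* host_ptr) {
    return static_cast<T*>(
        omp_get_mapped_ptr(host_ptr, omp_get_default_device()));
}
```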

In @include/poisson_solver_multigrid.hpp:
- Around line 63-80: Update the public doc-comments for the new getter methods
to fully describe what each returns: specify the norm type (infinity norm or L2
norm), the time it is computed (e.g., computed at the end of solve() for
residual()/residual_l2(), computed at the start of solve() for
rhs_norm()/rhs_norm_l2()), the units (same units as solution/vector entries),
and the behavior in fixed-cycle mode (value is 0 when the quantity was not
computed). Apply this to residual(), residual_l2(), rhs_norm(), rhs_norm_l2(),
initial_residual(), and initial_residual_l2(), and ensure you clarify that
residual_l2()/rhs_norm_l2() are raw L2 norms (not normalized) and that "0 means
not computed" is explicitly documented.

In @profiles/taylor_green_fft_stats.txt:
- Around line 1-8: This generated profiling output contains absolute user/host
paths (e.g.
"/storage/home/hcoda1/6/sbryngelson3/scratch/nncfd/profiles/taylor_green_fft.sqlite"
and the nvtx_sum report) and must be removed from the repo; delete
profiles/taylor_green_fft_stats.txt from the commit, add a rule to ignore such
generated profiling outputs (e.g. add an entry to .gitignore for /profiles/*.txt
or your profiler output pattern), and replace it with a sanitized, stable
summary document (redacting usernames/host paths) or instructions on how to
regenerate the full report locally.

In @src/mg_cuda_kernels.cpp:
- Around line 10-23: Include the missing header by adding #include <string> so
the CUDA_CHECK macro can use std::string and std::to_string, and replace the
non-portable M_PI uses with a portable expression or constexpr (e.g., use
std::acos(-1.0) or define constexpr double PI = std::acos(-1.0)) where M_PI is
referenced (the occurrences flagged in this file), ensuring any functions or
macros that relied on M_PI now use the new PI symbol or std::acos(-1.0).
- Around line 127-272: bc_3d_kernel currently assigns threads to faces causing
write races on edge/corner ghost cells (multiple threads write the same u[idx]);
change the kernel to have each thread cover a unique ghost cell by mapping tid
over the full (Nx+2*Ng)*(Ny+2*Ng)*(Nz+2*Ng) grid and skipping interior cells,
compute i,j,k once and then deterministically apply boundary logic for that
single cell (check if i<Ng or i>=Nx+Ng, j<Ng or j>=Ny+Ng, k<Ng or k>=Nz+Ng)
resolving conflicts with a defined precedence (e.g., apply x-boundary then y
then z) or by composing periodic/Neumann/Dirichlet effects in that order so only
one thread writes each u[idx]; update indices (stride, plane_stride, idx,
idx_int, idx_wrap) accordingly and remove the per-face partitioning logic in
bc_3d_kernel.
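
A minimal sketch of the one-thread-per-ghost-cell scheme, showing only the x faces and assuming a single ghost layer (Ng == 1) in the index arithmetic; the BC encoding (0=Dirichlet, 1=Neumann, 2=Periodic) follows the existing kernel and the x-then-y-then-z precedence is the illustrative choice named above:

```cuda
__global__ void bc_3d_ghost_kernel(double* u,
                                   int Nx, int Ny, int Nz, int Ng,
                                   int bc_x_lo, int bc_x_hi,
                                   double dirichlet_val)
{
    const int sx = Nx + 2 * Ng, sy = Ny + 2 * Ng, sz = Nz + 2 * Ng;
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= sx * sy * sz) return;

    const int i = tid % sx;
    const int j = (tid / sx) % sy;
    const int k = tid / (sx * sy);

    const bool ghost_x = (i < Ng) || (i >= Nx + Ng);
    const bool ghost_y = (j < Ng) || (j >= Ny + Ng);
    const bool ghost_z = (k < Ng) || (k >= Nz + Ng);
    if (!ghost_x && !ghost_y && !ghost_z) return;  // interior: nothing to do

    const int stride = sx, plane = sx * sy;

    // x takes precedence, so edge/corner ghosts get exactly one writer.
    if (ghost_x) {
        const bool lo     = (i < Ng);
        const int  bc     = lo ? bc_x_lo : bc_x_hi;
        const int  i_int  = lo ? Ng : (Nx + Ng - 1);   // nearest interior cell
        const int  i_wrap = lo ? (Nx + Ng - 1) : Ng;   // periodic partner
        const int  src = k * plane + j * stride + (bc == 2 ? i_wrap : i_int);
        const int  idx = k * plane + j * stride + i;
        u[idx] = (bc == 2) ? u[src]                    // periodic copy
               : (bc == 1) ? u[src]                    // Neumann mirror
               : 2.0 * dirichlet_val - u[src];         // Dirichlet ghost
    }
    // else if (ghost_y) { ... same pattern with bc_y_* ... }
    // else if (ghost_z) { ... same pattern with bc_z_* ... }
}
```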

In @src/poisson_solver_fft.cpp:
- Line 610: This code uses raw cudaMalloc/cudaMallocManaged and explicit
cudaFree for class-owned GPU buffers; replace each class-owned raw device
pointer with an OmpDeviceBuffer<T> member (allocate in the class constructor or
initialization path and rely on its destructor for cleanup), remove manual
cudaFree calls, and update any OpenMP map clauses so they no longer use alloc
for those members (pass the device pointer via is_device_ptr(buffer.data()) or
appropriate accessor). Keep existing map(present, alloc: rhs_ptr[...]) / p_ptr
behavior for externally owned parameters (rhs_ptr, p_ptr) unchanged. Ensure all
places that referenced the raw pointers (e.g., reads/writes, kernel launches,
FFT calls) are updated to use the OmpDeviceBuffer accessors or .data() so RAII
and consistent error handling are enforced.

In @src/poisson_solver_multigrid.cpp:
- Around line 2411-2442: The call to the NVHPC-only function
ompx_get_cuda_stream inside MultigridPoissonSolver::vcycle_graphed causes
compile errors on non-NVHPC toolchains; wrap the ompx_get_cuda_stream usage (and
the omp_stream variable declaration/logic) in an #ifdef __NVCOMPILER / #endif
guard (matching the guard already used in smooth_chebyshev) so that when
__NVCOMPILER is not defined the code falls back to the non-graphed vcycle path;
ensure the runtime null-check and warning remain inside the guarded block and
that vcycle(…) is called unconditionally outside the guard if
ompx_get_cuda_stream is not available (e.g., when USE_GPU_OFFLOAD is set but
__NVCOMPILER is false).

In @tests/bench_256.cpp:
- Around line 38-41: The benchmark sets the deprecated cfg.tol; replace that
with the new convergence criteria by removing cfg.tol and instead assign the new
fields on the PoissonConfig instance (e.g., cfg.tol_abs = 1e-10; cfg.tol_rhs =
1e-10; cfg.tol_rel = 0.0) while keeping cfg.max_iter and cfg.check_interval
unchanged so the benchmark uses the new API (PoissonConfig cfg and its
tol_abs/tol_rhs/tol_rel members) rather than the legacy tol field.

In @tests/bench_mg_cuda_graphs.cpp:
- Around line 25-44: In benchmark_grid, guard against trials <= 0 before any
access to the times array (or times[0]) by validating the trials parameter and
returning early or throwing; locate the trials usage in the function
benchmark_grid and add a check like if (trials <= 0) return; (or adjust to set
trials = 1) to avoid undefined behavior when indexing times; apply the same
guard to the analogous code block later (lines around the second benchmark
invocation at ~135-154) that also accesses times[0].
- Around line 13-24: The file uses std::getenv but lacks the <cstdlib> include;
add #include <cstdlib> to the top of tests/bench_mg_cuda_graphs.cpp (near the
other includes) so getenv is declared, and apply the same fix for the other
occurrences noted (around the 156–174 region) to avoid relying on transitive
includes.

In @tests/bench_mg_tuning.cpp:
- Around line 18-25: Add the missing headers and replace non-portable M_PI
usages: include <sstream> for std::ostringstream and <cstdlib> for std::atoi at
the top of tests/bench_mg_tuning.cpp, and replace every M_PI occurrence with a
portable constant (e.g., constexpr double PI = std::acos(-1); then use PI) or
directly use std::acos(-1) where needed; update all occurrences mentioned (the
include block and the other spots around the later uses) so builds on non‑GNU
toolchains succeed.

In @tests/test_mg_physics_match.cpp:
- Around line 194-237: The current pass/fail uses overly strict pure-relative
thresholds (ke_rel_diff, vel_rel_diff < 1e-6) and a fragile div_ratio =
fixed.final_div / ref.final_div that can explode when ref.final_div ≈ 0; change
the checks to combined absolute+relative comparisons with floors: replace
ke_rel_diff and vel_rel_diff tests with something like abs(a-b) < eps_abs +
eps_rel * abs(ref) (referencing fixed.final_KE/ref.final_KE and
fixed.final_max_vel/ref.final_max_vel) and replace the div_ratio test by
comparing abs(fixed.final_div - ref.final_div) < eps_div_abs + eps_div_rel *
max(abs(ref.final_div), small_floor), or use a small_floor for denominator
instead of dividing by ref.final_div; update the boolean passed and the failure
messages to report the absolute and relative thresholds used (symbols:
ke_rel_diff, vel_rel_diff, div_ratio, fixed.final_div, ref.final_div,
fixed.final_KE, ref.final_KE, fixed.final_max_vel, ref.final_max_vel).
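
A sketch of the combined check described above (the epsilon defaults are placeholders, not tuned thresholds):

```cpp
#include <cmath>

// Passes when either the absolute floor or the scaled relative gap covers
// the difference, so a near-zero reference no longer blows up the ratio.
bool close_abs_rel(double value, double ref,
                   double eps_abs = 1e-12, double eps_rel = 1e-6) {
    return std::abs(value - ref) < eps_abs + eps_rel * std::abs(ref);
}
```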

In @tests/test_vcycle_graph_stress.cpp:
- Around line 149-245: Both solvers were constructed under the same
MG_USE_VCYCLE_GRAPH setting so they follow identical code paths; switch the mode
per-solve by setting the PoissonConfig flag instead. Before calling
mg_graphed.solve_device(...) set cfg.use_vcycle_graph = true, and before calling
mg_nongraphed.solve_device(...) set cfg.use_vcycle_graph = false (keeping
cfg.fixed_cycles as is); call the two solve_device calls with their respective
cfg instances so the graphed vs non-graphed behavior is controlled at solve time
rather than at solver construction.
🟡 Minor comments (5)
tests/test_vcycle_graph_stress.cpp-150-245 (1)

150-245: CPU build behavior is “pass via skip”; please print “SKIPPED” and return explicitly.
Right now it prints and breaks, but returns true by default—easy to miss in CI logs. Make the skip explicit in the return path.

Also applies to: 343-360

tests/bench_mg_cuda_graphs.cpp-45-99 (1)

45-99: Add runtime GPU device check via gpu::verify_device_available() in the GPU code path.

The GPU path needs early device availability validation. Add gpu::verify_device_available(); immediately after line 45 (the #ifdef USE_GPU_OFFLOAD directive) in benchmark_grid() to fail fast with a clear error message if no GPU device is available, rather than letting #pragma omp target enter data fail with a cryptic runtime error.

The existing gpu::verify_device_available() utility in src/gpu_init.cpp does exactly this—it checks omp_get_num_devices() > 0 and throws a descriptive error, following the pattern already used throughout the codebase (e.g., poisson_solver_multigrid.cpp, solver.cpp).

Also applies to: 156-191

src/poisson_solver_multigrid.cpp-1114-1187 (1)

1114-1187: Clarify ||r||_2 definition (Euclidean vs RMS) and align docs/thresholds.
r_l2 = sqrt(sum(r^2)) (not divided by N). That’s totally fine if b_l2_ uses the same convention (it does), but please document it explicitly in the API comments and config semantics to avoid users assuming RMS.
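
For reference, the two conventions differ only by a factor of √N (N = number of interior cells):

```latex
\|r\|_2 = \sqrt{\sum_i r_i^2},
\qquad
\mathrm{RMS}(r) = \frac{\|r\|_2}{\sqrt{N}}
```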

scripts/run_nsys_profiles.sh-168-188 (1)

168-188: Use grep -E (or avoid \|) for portable alternation.
Some greps don’t treat \| as alternation unless using ERE.

Proposed fix
@@
-                grep -i "poisson\|mg:\|fft" "$nvtx_file" 2>/dev/null | head -20 || echo "  No Poisson NVTX data"
+                grep -Ei "poisson|mg:|fft" "$nvtx_file" 2>/dev/null | head -20 || echo "  No Poisson NVTX data"
include/mg_cuda_kernels.hpp-27-32 (1)

27-32: Add documentation explaining the relationship to PoissonBC and clarify the enum's purpose.

The enum values are correctly aligned with PoissonBC (Dirichlet=0, Neumann=1, Periodic=2). However, as a public header interface, the enum needs more complete documentation. Explain why this separate enum exists (CUDA kernel specialization), document parameters and usage, and note any special handling for periodic boundary conditions.

🧹 Nitpick comments (29)
src/poisson_solver_fft2d.cpp (2)

132-140: Consider migrating to OmpDeviceBuffer for GPU memory management.

The code uses manual cudaMallocManaged calls. Per project guidelines, using the OmpDeviceBuffer wrapper would improve resource safety and consistency with the rest of the codebase.

Based on learnings, the project prefers OmpDeviceBuffer for GPU buffer management to ensure RAII-based cleanup and better interoperability with OpenMP target directives.


20-39: Add runtime device availability check.

The constructor initializes GPU resources without verifying device availability at runtime. Consider adding a check for omp_get_num_devices() > 0 before GPU initialization.

🔧 Suggested device availability check
 FFT2DPoissonSolver::FFT2DPoissonSolver(const Mesh& mesh)
     : mesh_(&mesh)
     , Nx_(mesh.Nx)
     , Ny_(mesh.Ny)
     , dx_(mesh.dx)
     , dy_(mesh.dy)
 {
     if (!mesh.is2D()) {
         throw std::runtime_error("FFT2DPoissonSolver requires 2D mesh (Nz=1)");
     }
 
+    if (omp_get_num_devices() <= 0) {
+        throw std::runtime_error("FFT2DPoissonSolver requires GPU device (USE_GPU_OFFLOAD)");
+    }
+
     N_modes_ = Nx_ / 2 + 1;

As per coding guidelines, GPU offload code should verify device availability at runtime.

src/gpu_kernels.cpp (1)

43-45: Same concern: alloc in map(present, alloc:) contradicts "already mapped" comments.

These kernels have identical pattern to src/poisson_solver_fft.cpp: comments claim arrays are "already mapped by solver/turbulence model" (lines 40, 104, 190, 239, 364, 866, 959), yet map clauses include alloc. This introduces ambiguity in the data lifecycle contract.

Recommendation: Ensure consistent data mapping strategy across all GPU kernels. If buffers are pre-mapped, use strict map(present:) to catch lifecycle bugs early.

Also applies to: 106-111, 192-192, 241-245, 368-370, 433-436, 804-804, 869-873, 961-963

include/poisson_solver.hpp (3)

18-32: Excellent documentation for the new convergence criteria.

The three-way convergence check (absolute, RHS-relative, initial-relative) with clear documentation is well-designed. The backwards compatibility via deprecated tol field is appropriate.

Minor observation: The tol_abs = 0.0 convention for "disabled" is documented but non-obvious. Consider whether a negative sentinel (e.g., -1.0) or an explicit boolean flag would be clearer for future maintainers.


34-43: Verify the adaptive cycle mode interaction and performance claims.

The fixed-cycle and adaptive-cycle modes introduce a moderately complex interaction pattern:

  • fixed_cycles = 0: convergence-based termination
  • fixed_cycles > 0 and adaptive_cycles = false: exactly N cycles
  • fixed_cycles > 0 and adaptive_cycles = true: run check_after cycles, check, add more if needed

The comments claim specific performance improvements ("16% faster + 58% better divergence"). Ensure these are representative across problem types and not just one specific case.

Consider adding runtime validation to prevent conflicting configurations (e.g., adaptive_cycles = true but fixed_cycles = 0). Also verify that the performance claims in comments are documented with the test conditions that produced them.
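
A minimal sketch of such a validation guard (the stub struct stands in for the real PoissonConfig; the exception type is an assumption):

```cpp
#include <stdexcept>

struct PoissonConfigStub {            // stand-in for the PR's PoissonConfig
    int  fixed_cycles    = 8;
    bool adaptive_cycles = false;
};

void validate_cycle_config(const PoissonConfigStub& cfg) {
    // Adaptive mode refines a fixed-cycle budget, so it needs one to refine.
    if (cfg.adaptive_cycles && cfg.fixed_cycles <= 0) {
        throw std::invalid_argument(
            "adaptive_cycles=true requires fixed_cycles > 0");
    }
}
```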


46-50: Clarify the "auto" behavior for smoother parameters.

Using 0 to indicate automatic selection for nu1 and nu2 is reasonable, but the documentation states "0 = auto: 3 for wall BCs" without specifying what value is chosen for other BC types.

Consider documenting all auto-selection cases or pointing to where this logic is implemented for maintainability.

tests/bench_256.cpp (1)

43-63: Consider adding solution validation to the benchmark.

While timing is the primary focus, the benchmark doesn't verify that the solver actually converged or that the solution is correct for the given RHS. This could mask solver issues.

Consider adding a simple correctness check, such as:

  • Verify the solver converged (iterations < max_iter)
  • Check final residual is below tolerance
  • For this specific RHS (trig function), verify the solution satisfies the expected analytical form

As per coding guidelines: "Add tests for new features."

profiles/taylor_green_fft_memops.txt (1)

1-17: Reconsider committing generated profiling artifacts to version control.

This file contains generated Nsight Systems profiling output. Generally, generated artifacts should not be committed to the repository because:

  1. They become stale as code evolves
  2. They expose absolute file paths from the developer's environment (Line 2-3)
  3. They can't be easily regenerated by other developers without the exact same setup
  4. They increase repository size without providing code functionality

If these profiling results are intended as documentation or reference data:

  • Move them to a documentation directory with context about when/how they were generated
  • Add a README explaining the profiling methodology and how to regenerate
  • Consider using relative paths or sanitizing absolute paths
  • Alternatively, document the profiling process and provide scripts to generate fresh profiles rather than committing the output

If these are temporary debugging artifacts, they should be git-ignored instead.

app/main_taylor_green_3d.cpp (1)

169-185: The T_final precedence logic is correct and well-structured.

The three-way precedence (command line > config file > default) is properly implemented and preserves backward compatibility. The conditional max_iter override (lines 180-185) is a good design choice, allowing config files to directly set max_iter for profiling without interference.

Minor observation: The logic assumes config.T_final > 0 distinguishes "set" from "unset." This is consistent with the default of -1.0 mentioned in the changes, but consider adding a validation check that the final T_final value is positive before use.

profiles/taylor_green_fft_kernels.txt (1)

1-79: Reconsider committing generated profiling artifacts to version control.

Similar to taylor_green_fft_memops.txt, this file contains generated Nsight Systems profiling output with kernel timing data. The same concerns apply regarding committing generated artifacts to version control.

See the recommendations in the review of profiles/taylor_green_fft_memops.txt. If profiling data is needed for documentation:

  • Provide context about the test conditions, hardware, and software versions
  • Include scripts to regenerate the profiles
  • Consider summarizing key findings in markdown documentation instead of raw tool output
  • If keeping raw output, ensure paths are sanitized and generation methodology is documented
docs/profiling_results_128cubed.md (1)

80-84: Add language specifier to fenced code block.

The fenced code block should specify a language identifier for proper syntax highlighting and rendering.

📝 Suggested fix
-```
+```text
 Graph executions: 600 (6 configs × 10 steps × 10 V-cycles)
 Average per graph: 299.5 μs (actual GPU compute)
 Graph launch overhead: ~10 μs

tests/bench_mg_cuda_graphs.cpp (1)

`45-99`: **Prefer `OmpDeviceBuffer` (avoid raw `target enter/exit data` in tests/bench).**  
This benchmark manually manages device mappings; the repo guideline prefers `OmpDeviceBuffer` for GPU buffers. Based on learnings, use `OmpDeviceBuffer` to reduce mapping lifetime/mismatch risks.

tests/test_vcycle_graph_stress.cpp (2)

`19-41`: **`compute_l2_residual()` is isotropic-only; keep it aligned with the discretization being tested.**  
It assumes a single spacing `h`. That’s OK for the cubic tests here, but it’s easy to reuse incorrectly later (you already had to special-case anisotropic residual in Test 3). Consider an overload taking `(hx,hy,hz)` and reusing it for Test 3 too.
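
A sketch of the (hx, hy, hz) overload for the standard 7-point Laplacian with one ghost layer (array layout and the Euclidean-norm convention mirror the tests; names are illustrative):

```cpp
#include <cmath>

double compute_l2_residual(const double* u, const double* f,
                           int Nx, int Ny, int Nz,
                           double hx, double hy, double hz) {
    const int Ng  = 1;
    const int sx  = Nx + 2 * Ng;          // x stride
    const int sxy = sx * (Ny + 2 * Ng);   // plane stride
    auto at = [&](int i, int j, int k) { return u[k * sxy + j * sx + i]; };

    const double ix2 = 1.0 / (hx * hx);
    const double iy2 = 1.0 / (hy * hy);
    const double iz2 = 1.0 / (hz * hz);

    double sum = 0.0;
    for (int k = Ng; k < Nz + Ng; ++k)
        for (int j = Ng; j < Ny + Ng; ++j)
            for (int i = Ng; i < Nx + Ng; ++i) {
                const double lap =
                    (at(i - 1, j, k) - 2.0 * at(i, j, k) + at(i + 1, j, k)) * ix2 +
                    (at(i, j - 1, k) - 2.0 * at(i, j, k) + at(i, j + 1, k)) * iy2 +
                    (at(i, j, k - 1) - 2.0 * at(i, j, k) + at(i, j, k + 1)) * iz2;
                const double r = f[k * sxy + j * sx + i] - lap;
                sum += r * r;
            }
    return std::sqrt(sum);                // Euclidean, not RMS
}
```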

---

`44-147`: **Prefer `OmpDeviceBuffer` and reduce H↔D copies in the stress tests.**  
You frequently `update to/from` full arrays. For tests this is acceptable, but `OmpDeviceBuffer` would make lifetime + “present” assumptions safer and reduce boilerplate. Based on learnings, prefer `OmpDeviceBuffer`.  



Also applies to: 247-341

tests/test_mg_physics_match.cpp (1)

`148-161`: **GPU path sync cost: `sync_from_gpu()` every step is heavy.**  
If you only need velocity for metrics, consider `sync_solution_from_gpu()` (or a narrower sync) to reduce transfer volume.

app/profile_comprehensive.cpp (1)

`218-221`: **Use `[[maybe_unused]]` instead of `(void)argc/argv`.**  
Matches the repo guideline for intentionally unused variables.

src/solver.cpp (1)

`4803-5030`: **Consider migrating manual enter/exit mappings to `OmpDeviceBuffer`.**  
This file now has a lot of low-level mapping lifetime management; the repo guideline prefers `OmpDeviceBuffer` to reduce mapping duplication bugs (like the `velocity_old_*` issue) and make ownership clearer. Based on learnings, prefer `OmpDeviceBuffer`.

tests/bench_mg_tuning.cpp (1)

`141-172`: **Benchmark labels vs IC/BC mismatch (can skew “quality” metrics).**  
You label this as Taylor–Green (div-free) but enforce no-slip walls on `y` (channel BC). That makes the “divergence-free” claim misleading and can inject divergence via BC enforcement, affecting `div_L2/Linf` comparisons between configs.



Also applies to: 151-173

include/poisson_solver_multigrid.hpp (1)

`204-214`: **Prefer `OmpDeviceBuffer` over raw device pointers for MG buffers (safety/RAII).**  
The header exposes raw pointer vectors (`u_ptrs_`, `f_ptrs_`, `r_ptrs_`, `tmp_ptrs_`), which drives manual `enter data` + `omp_target_alloc/free` in the .cpp. Based on learnings, consider wrapping these in `OmpDeviceBuffer` to avoid leaks on partial initialization and simplify cleanup.  
Based on learnings, use `OmpDeviceBuffer` wrapper for GPU buffer management instead of manual memory management.

src/poisson_solver_multigrid.cpp (3)

`60-77`: **Avoid unconditional stdout in library code path (hurts benchmarks/tests).**  
`std::cout << "[MG] Full V-cycle CUDA Graph enabled ..."` will print for every solver construction when `USE_GPU_OFFLOAD` is on and the env var isn’t set.

---

`1411-1421`: **Jacobi path uses `degree` as “iterations per pass” (naming/intent mismatch).**  
With `MG_SMOOTHER=jacobi`, `degree` becomes Jacobi iteration count, not Chebyshev degree. That can be surprising when tuning `poisson_chebyshev_degree`. Consider a separate config knob (or rename parameter to something neutral like `smoother_param`).

---

`1955-1970`: **Avoid one-time D→H copy for `b_l2_` in `solve_device` adaptive mode (can be computed on device).**  
Right now you `target update from(f_ptrs_[0]...)` then do a CPU reduction. You already rely on OpenMP target reductions elsewhere; consider doing a device reduction for `b_sum_sq` after a `cudaDeviceSynchronize()` (or via the same stream discipline used for residual checks) to keep the adaptive path truly device-resident.
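
A sketch of that device-side reduction, mirroring the interior-only iteration of the fused kernel quoted earlier (variable names are taken from that snippet):

```cpp
double b_sum_sq = 0.0;
#pragma omp target teams distribute parallel for collapse(2) \
    map(present: f_dev[0:total_size]) reduction(+: b_sum_sq)
for (int j = Ng; j < Ny_g + Ng; ++j) {
    for (int i = Ng; i < Nx_g + Ng; ++i) {
        const double v = f_dev[j * stride_gpu + i];
        b_sum_sq += v * v;
    }
}
// b_l2_ = std::sqrt(b_sum_sq);  // same Euclidean convention as the residual
```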

src/mg_cuda_kernels.cpp (1)

`431-561`: **Consider CUDA error checks after launches (especially during graph capture).**  
During capture/instantiate, a silent kernel launch failure can surface later as a confusing `cudaGraphInstantiate` error. Even a debug-only `CUDA_CHECK(cudaPeekAtLastError())` after key launches can save time.

include/mg_cuda_kernels.hpp (6)

`35-47`: **Consider clarifying ownership semantics for device pointers.**

The struct uses raw `double*` pointers for device memory (`u`, `f`, `r`, `tmp`). Per coding guidelines, `OmpDeviceBuffer` is preferred for GPU buffer management. If `LevelConfig` is merely a non-owning view/configuration passed to kernels, document this explicitly. If it owns the memory, consider using `OmpDeviceBuffer` or a similar RAII wrapper.

Based on learnings, the codebase prefers `OmpDeviceBuffer` wrapper for GPU buffer management.



📝 Suggested documentation clarification
 /// Per-level MG configuration for CUDA kernels
+/// Note: This struct holds non-owning views to device memory managed elsewhere.
+/// The caller is responsible for lifetime management of the pointed-to arrays.
 struct LevelConfig {

49-100: Enhance documentation for public API completeness.

Per coding guidelines, public functions should document parameters, return values, side effects, and include usage examples for complex functions. The class is missing:

  1. Return value documentation for is_valid() (currently implicit)
  2. Side effects documentation for initialize() (captures graph, allocates resources)
  3. Usage example showing typical lifecycle
📝 Suggested documentation enhancement
 /// CUDA Graph-based smoother for a single MG level
+///
+/// Usage example:
+/// @code
+///   CudaSmootherGraph smoother;
+///   smoother.initialize(config, degree, bc_x_lo, bc_x_hi, ...);
+///   if (smoother.is_valid()) {
+///       smoother.execute(stream);
+///   }
+///   smoother.destroy();  // or let destructor handle it
+/// @endcode
 class CudaSmootherGraph {

     /// Initialize for a given level configuration
+    /// @note Captures a CUDA graph internally; call once before execute()
     /// @param config Level parameters (grid size, pointers, etc.)
     /// @param degree Chebyshev polynomial degree
     /// @param bc_x_lo/hi, bc_y_lo/hi, bc_z_lo/hi Boundary conditions
+    /// @post Graph is captured and ready for execution via execute()

     /// Check if graph is initialized and valid
+    /// @return true if graph has been successfully captured and is ready to execute
     bool is_valid() const { return graph_exec_ != nullptr; }

125-133: Document bounds checking behavior for smooth() methods.

The smooth(int level) methods should document the valid range for level parameter and the behavior when an invalid level is passed (exception, assertion, undefined behavior).

📝 Suggested documentation
     /// Execute smoother for a given level (uses internal stream)
+    /// @param level Level index (0 = finest, must be < graphs initialized)
+    /// @pre level < smoother_graphs_.size()
     void smooth(int level);
 
     /// Execute smoother for a given level on specified stream
     /// Use this with OpenMP's stream to avoid cross-stream sync overhead
+    /// @param level Level index (0 = finest, must be < graphs initialized)
+    /// @param stream CUDA stream for execution
+    /// @pre level < smoother_graphs_.size()
     void smooth(int level, cudaStream_t stream);

212-226: Consider unifying LevelConfig and VCycleLevelConfig to reduce duplication.

Both structs share most fields (Nx, Ny, Nz, Ng, dx2, dy2, dz2, coeff, total_size, u, f, r, tmp). VCycleLevelConfig adds only inv_dx2, inv_dy2, inv_dz2. Consider:

  1. Adding the inv_* fields to LevelConfig and using one struct, or
  2. Having VCycleLevelConfig embed a LevelConfig

This would reduce maintenance burden and ensure consistency.

♻️ Option: Embed LevelConfig
 struct VCycleLevelConfig {
+    LevelConfig base;                  // Inherit common fields
     int Nx, Ny, Nz;           // Interior grid dimensions
-    int Ng;                   // Number of ghost cells
-    double inv_dx2, inv_dy2, inv_dz2;  // Inverse grid spacing squared
-    double dx2, dy2, dz2;     // Grid spacing squared (for smoother)
-    double coeff;             // Diagonal coefficient for Jacobi
-    size_t total_size;        // Total array size including ghosts
-
-    // Array pointers (device memory)
-    double* u;                // Solution array
-    double* f;                // RHS array
-    double* r;                // Residual array
-    double* tmp;              // Scratch buffer for smoother
+    double inv_dx2, inv_dy2, inv_dz2;  // Additional: inverse spacing squared
 };

270-327: Add usage example and clarify member defaults.

This is a complex class that would benefit from a usage example per coding guidelines. Also, the private member defaults (degree_ = 4, nu1_ = 2, nu2_ = 2) are overwritten by initialize(), which is fine, but worth noting.

📝 Suggested documentation enhancement
 /// Full V-cycle CUDA Graph - captures entire V-cycle for single-launch execution
+///
+/// Usage example:
+/// @code
+///   CudaVCycleGraph vcycle_graph;
+///   vcycle_graph.initialize(levels, degree, nu1, nu2, bc_x_lo, bc_x_hi, ...);
+///   
+///   // Check if recapture needed (e.g., grid changed)
+///   if (vcycle_graph.needs_recapture(current_fingerprint)) {
+///       vcycle_graph.destroy();
+///       vcycle_graph.initialize(...);
+///   }
+///   
+///   vcycle_graph.execute(stream);
+/// @endcode
 class CudaVCycleGraph {

289-290: Document stream requirements for execute().

Should document whether a null/default stream is acceptable and any synchronization requirements (e.g., must caller synchronize before accessing results?).

📝 Suggested documentation
     /// Execute the captured V-cycle graph
+    /// @param stream CUDA stream for execution (must not be nullptr)
+    /// @note Caller must synchronize stream before reading results
     void execute(cudaStream_t stream);
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 51b27c5 and 55bba54.

📒 Files selected for processing (30)
  • CMakeLists.txt
  • app/main_taylor_green_3d.cpp
  • app/profile_comprehensive.cpp
  • docs/profiling_results_128cubed.md
  • include/config.hpp
  • include/gpu_utils.hpp
  • include/mg_cuda_kernels.hpp
  • include/poisson_solver.hpp
  • include/poisson_solver_multigrid.hpp
  • profiles/taylor_green_fft_kernels.txt
  • profiles/taylor_green_fft_memops.txt
  • profiles/taylor_green_fft_nvtx.txt
  • profiles/taylor_green_fft_stats.txt
  • scripts/ci.sh
  • scripts/run_nsys_profiles.sh
  • src/config.cpp
  • src/gpu_kernels.cpp
  • src/mg_cuda_kernels.cpp
  • src/poisson_solver_fft.cpp
  • src/poisson_solver_fft2d.cpp
  • src/poisson_solver_multigrid.cpp
  • src/solver.cpp
  • tests/bench_256.cpp
  • tests/bench_fft_vs_mg.cpp
  • tests/bench_mg_bc_sweep.cpp
  • tests/bench_mg_cuda_graphs.cpp
  • tests/bench_mg_tuning.cpp
  • tests/test_mg_physics_match.cpp
  • tests/test_poisson_unified.cpp
  • tests/test_vcycle_graph_stress.cpp
🧰 Additional context used
📓 Path-based instructions (7)
**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

**/*.cpp: Use [[maybe_unused]] attribute for intentionally unused variables in assertions instead of suppressing warnings with compiler flags
Use RAII for resource management with smart pointers or RAII wrappers; avoid manual new/delete or malloc/free
Check return values and use exceptions for error conditions; provide informative error messages
Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible
Always call set_body_force() for driven flows; initialize velocity field before solving; set turbulence model before first solver step
For GPU offload, verify USE_GPU_OFFLOAD is defined and check omp_get_num_devices() > 0 at runtime

Files:

  • src/poisson_solver_fft2d.cpp
  • app/profile_comprehensive.cpp
  • src/poisson_solver_fft.cpp
  • app/main_taylor_green_3d.cpp
  • tests/bench_mg_tuning.cpp
  • tests/test_vcycle_graph_stress.cpp
  • src/gpu_kernels.cpp
  • tests/bench_256.cpp
  • tests/bench_mg_cuda_graphs.cpp
  • tests/test_poisson_unified.cpp
  • tests/bench_mg_bc_sweep.cpp
  • src/config.cpp
  • tests/test_mg_physics_match.cpp
  • tests/bench_fft_vs_mg.cpp
  • src/solver.cpp
  • src/mg_cuda_kernels.cpp
  • src/poisson_solver_multigrid.cpp
**/*.{cpp,hpp,h}

📄 CodeRabbit inference engine (.cursorrules)

**/*.{cpp,hpp,h}: Fix all compiler warnings before pushing; do not suppress warnings with flags
Use const for read-only references and mark methods const if they don't modify state in C++ code
Use OmpDeviceBuffer wrapper for GPU buffer management instead of manual memory management
Add explanatory comments for WHY code does something, not WHAT it does; document numerical algorithms and non-obvious optimizations

Files:

  • src/poisson_solver_fft2d.cpp
  • app/profile_comprehensive.cpp
  • src/poisson_solver_fft.cpp
  • app/main_taylor_green_3d.cpp
  • tests/bench_mg_tuning.cpp
  • include/poisson_solver.hpp
  • tests/test_vcycle_graph_stress.cpp
  • src/gpu_kernels.cpp
  • include/gpu_utils.hpp
  • tests/bench_256.cpp
  • tests/bench_mg_cuda_graphs.cpp
  • tests/test_poisson_unified.cpp
  • tests/bench_mg_bc_sweep.cpp
  • include/config.hpp
  • src/config.cpp
  • tests/test_mg_physics_match.cpp
  • tests/bench_fft_vs_mg.cpp
  • include/mg_cuda_kernels.hpp
  • src/solver.cpp
  • src/mg_cuda_kernels.cpp
  • include/poisson_solver_multigrid.hpp
  • src/poisson_solver_multigrid.cpp
src/**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

src/**/*.cpp: Profile before optimizing and document complexity assumptions; use appropriate data structures
Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in Poisson solver

Files:

  • src/poisson_solver_fft2d.cpp
  • src/poisson_solver_fft.cpp
  • src/gpu_kernels.cpp
  • src/config.cpp
  • src/solver.cpp
  • src/mg_cuda_kernels.cpp
  • src/poisson_solver_multigrid.cpp
app/**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

Use solve_steady_with_snapshots() for automatic VTK output; specify num_snapshots in config; files are numbered sequentially plus final

Files:

  • app/profile_comprehensive.cpp
  • app/main_taylor_green_3d.cpp
tests/**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

tests/**/*.cpp: Do not add platform-specific tolerances with #ifdef __APPLE__ or similar platform checks in tests; fix the root cause instead
Do not relax tolerances for different compilers; tests should be numerically robust across all compilers
Tests must pass in both Debug and Release builds; if different behavior is needed, investigate the root cause rather than accept build-type dependent behavior
Avoid overly strict floating-point comparisons in tests; use appropriate tolerances based on algorithm rather than exact equality checks (e.g., avoid == for doubles)
For iterative solvers in tests, check residual convergence rather than exact iteration count
Ensure each test is independent and does not rely on execution order; clean up any state or files created during tests
Add tests for new features; update tests when changing behavior; do not commit broken tests

Files:

  • tests/bench_mg_tuning.cpp
  • tests/test_vcycle_graph_stress.cpp
  • tests/bench_256.cpp
  • tests/bench_mg_cuda_graphs.cpp
  • tests/test_poisson_unified.cpp
  • tests/bench_mg_bc_sweep.cpp
  • tests/test_mg_physics_match.cpp
  • tests/bench_fft_vs_mg.cpp
include/**/*.{hpp,h}

📄 CodeRabbit inference engine (.cursorrules)

Every public function needs a documentation comment explaining parameters, return values, side effects, and including usage examples for complex functions

Files:

  • include/poisson_solver.hpp
  • include/gpu_utils.hpp
  • include/config.hpp
  • include/mg_cuda_kernels.hpp
  • include/poisson_solver_multigrid.hpp
**/*.{sh,slurm,batch}

📄 CodeRabbit inference engine (.cursor/rules/lean-changes.mdc)

Always submit slurm jobs using the embers QOS

Files:

  • scripts/run_nsys_profiles.sh
  • scripts/ci.sh
🧠 Learnings (18)
📓 Common learnings
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.{cpp,hpp,h} : Use `OmpDeviceBuffer` wrapper for GPU buffer management instead of manual memory management
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.{cpp,hpp,h} : Use `OmpDeviceBuffer` wrapper for GPU buffer management instead of manual memory management

Applied to files:

  • src/poisson_solver_fft2d.cpp
  • src/poisson_solver_fft.cpp
  • src/gpu_kernels.cpp
  • include/gpu_utils.hpp
  • tests/bench_mg_cuda_graphs.cpp
  • include/mg_cuda_kernels.hpp
  • src/solver.cpp
  • src/mg_cuda_kernels.cpp
  • CMakeLists.txt
  • include/poisson_solver_multigrid.hpp
  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : For GPU offload, verify `USE_GPU_OFFLOAD` is defined and check `omp_get_num_devices() > 0` at runtime

Applied to files:

  • src/poisson_solver_fft2d.cpp
  • src/poisson_solver_fft.cpp
  • src/gpu_kernels.cpp
  • include/gpu_utils.hpp
  • tests/bench_mg_cuda_graphs.cpp
  • scripts/ci.sh
  • include/mg_cuda_kernels.hpp
  • src/solver.cpp
  • src/mg_cuda_kernels.cpp
  • CMakeLists.txt
  • include/poisson_solver_multigrid.hpp
  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to src/**/*.cpp : Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in Poisson solver

Applied to files:

  • src/poisson_solver_fft2d.cpp
  • app/profile_comprehensive.cpp
  • src/poisson_solver_fft.cpp
  • include/poisson_solver.hpp
  • tests/test_vcycle_graph_stress.cpp
  • src/gpu_kernels.cpp
  • tests/bench_256.cpp
  • tests/test_poisson_unified.cpp
  • tests/bench_mg_bc_sweep.cpp
  • include/config.hpp
  • tests/test_mg_physics_match.cpp
  • tests/bench_fft_vs_mg.cpp
  • src/solver.cpp
  • CMakeLists.txt
  • include/poisson_solver_multigrid.hpp
  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible

Applied to files:

  • src/poisson_solver_fft2d.cpp
  • src/poisson_solver_fft.cpp
  • src/gpu_kernels.cpp
  • include/gpu_utils.hpp
  • tests/bench_mg_cuda_graphs.cpp
  • include/mg_cuda_kernels.hpp
  • src/solver.cpp
  • src/mg_cuda_kernels.cpp
  • CMakeLists.txt
  • include/poisson_solver_multigrid.hpp
  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: For GPU-related changes, also run `./test_before_ci_gpu.sh` to validate GPU CI test suite locally including physics validation tests, turbulence model validation, and CPU/GPU consistency

Applied to files:

  • docs/profiling_results_128cubed.md
  • app/profile_comprehensive.cpp
  • tests/test_vcycle_graph_stress.cpp
  • tests/bench_mg_cuda_graphs.cpp
  • tests/bench_mg_bc_sweep.cpp
  • scripts/ci.sh
  • tests/test_mg_physics_match.cpp
  • tests/bench_fft_vs_mg.cpp
  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to src/**/*.cpp : Profile before optimizing and document complexity assumptions; use appropriate data structures

Applied to files:

  • app/profile_comprehensive.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to app/**/*.cpp : Use `solve_steady_with_snapshots()` for automatic VTK output; specify `num_snapshots` in config; files are numbered sequentially plus final

Applied to files:

  • app/profile_comprehensive.cpp
  • app/main_taylor_green_3d.cpp
  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/test_physics_validation*.cpp : Tests must verify the Navier-Stokes solver using comprehensive physics validation including Poiseuille Flow, Divergence-Free Constraint, Momentum Balance, Channel Symmetry, Cross-Model Consistency, and Sanity Checks

Applied to files:

  • app/profile_comprehensive.cpp
  • tests/test_vcycle_graph_stress.cpp
  • tests/bench_256.cpp
  • tests/test_poisson_unified.cpp
  • tests/bench_mg_bc_sweep.cpp
  • include/config.hpp
  • scripts/ci.sh
  • tests/test_mg_physics_match.cpp
  • src/solver.cpp
  • CMakeLists.txt
  • include/poisson_solver_multigrid.hpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : For iterative solvers in tests, check residual convergence rather than exact iteration count

Applied to files:

  • app/profile_comprehensive.cpp
  • tests/test_vcycle_graph_stress.cpp
  • tests/bench_256.cpp
  • tests/bench_mg_cuda_graphs.cpp
  • tests/test_poisson_unified.cpp
  • tests/bench_mg_bc_sweep.cpp
  • scripts/ci.sh
  • tests/test_mg_physics_match.cpp
  • tests/bench_fft_vs_mg.cpp
  • src/solver.cpp
  • include/poisson_solver_multigrid.hpp
  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Add tests for new features; update tests when changing behavior; do not commit broken tests

Applied to files:

  • tests/test_vcycle_graph_stress.cpp
  • tests/bench_256.cpp
  • tests/bench_mg_cuda_graphs.cpp
  • tests/test_poisson_unified.cpp
  • tests/bench_mg_bc_sweep.cpp
  • tests/test_mg_physics_match.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Avoid overly strict floating-point comparisons in tests; use appropriate tolerances based on algorithm rather than exact equality checks (e.g., avoid `==` for doubles)

Applied to files:

  • tests/test_vcycle_graph_stress.cpp
  • tests/bench_256.cpp
  • tests/test_poisson_unified.cpp
  • tests/bench_mg_bc_sweep.cpp
  • tests/test_mg_physics_match.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Ensure each test is independent and does not rely on execution order; clean up any state or files created during tests

Applied to files:

  • tests/test_vcycle_graph_stress.cpp
  • tests/test_mg_physics_match.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Tests must pass in both Debug and Release builds; if different behavior is needed, investigate the root cause rather than accept build-type dependent behavior

Applied to files:

  • tests/test_vcycle_graph_stress.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Do not relax tolerances for different compilers; tests should be numerically robust across all compilers

Applied to files:

  • tests/test_vcycle_graph_stress.cpp
  • tests/test_poisson_unified.cpp
  • tests/bench_mg_bc_sweep.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Do not add platform-specific tolerances with `#ifdef __APPLE__` or similar platform checks in tests; fix the root cause instead

Applied to files:

  • tests/test_vcycle_graph_stress.cpp
  • tests/test_poisson_unified.cpp
  • tests/bench_mg_bc_sweep.cpp
  • tests/test_mg_physics_match.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : Always call `set_body_force()` for driven flows; initialize velocity field before solving; set turbulence model before first solver step

Applied to files:

  • src/gpu_kernels.cpp
  • include/config.hpp
  • src/solver.cpp
  • CMakeLists.txt
  • include/poisson_solver_multigrid.hpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Run `./test_before_ci.sh` before every push to repository to test Debug build, Release build, and all unit tests

Applied to files:

  • scripts/ci.sh
🧬 Code graph analysis (12)
app/profile_comprehensive.cpp (7)
tests/test_runner.hpp (1)
  • NoSlip (166-170)
tests/bench_mg_bc_sweep.cpp (1)
  • solver (73-73)
tests/bench_mg_tuning.cpp (1)
  • solver (139-139)
include/poisson_solver_hypre.hpp (1)
  • cfg (130-130)
tests/test_framework.hpp (1)
  • cout (254-258)
include/mg_cuda_kernels.hpp (1)
  • config (59-62)
include/config.hpp (1)
  • argc (158-158)
tests/bench_mg_tuning.cpp (4)
include/mg_cuda_kernels.hpp (1)
  • config (59-62)
app/profile_comprehensive.cpp (1)
  • solver (147-147)
tests/bench_mg_bc_sweep.cpp (1)
  • solver (73-73)
include/config.hpp (1)
  • argc (158-158)
tests/bench_256.cpp (2)
tests/bench_mg_cuda_graphs.cpp (3)
  • rhs (102-102)
  • mg (31-31)
  • p (102-102)
include/poisson_solver_hypre.hpp (1)
  • cfg (130-130)
tests/bench_mg_cuda_graphs.cpp (2)
include/poisson_solver.hpp (5)
  • rhs (81-81)
  • rhs (101-101)
  • rhs (104-104)
  • rhs (107-107)
  • p (98-98)
include/poisson_solver_multigrid.hpp (1)
  • rhs (46-46)
tests/test_poisson_unified.cpp (1)
include/poisson_solver_hypre.hpp (1)
  • cfg (130-130)
tests/bench_mg_bc_sweep.cpp (5)
include/poisson_solver_multigrid.hpp (1)
  • nu1 (153-153)
include/mg_cuda_kernels.hpp (1)
  • config (59-62)
app/profile_comprehensive.cpp (1)
  • solver (147-147)
tests/bench_mg_tuning.cpp (1)
  • solver (139-139)
tests/test_framework.hpp (1)
  • cout (254-258)
scripts/run_nsys_profiles.sh (3)
include/solver.hpp (1)
  • which (254-254)
tests/bench_fft_vs_mg.cpp (2)
  • main (88-122)
  • main (88-88)
tests/test_mg_physics_match.cpp (2)
  • main (170-238)
  • main (170-170)
tests/test_mg_physics_match.cpp (4)
include/mg_cuda_kernels.hpp (1)
  • config (59-62)
tests/bench_fft_vs_mg.cpp (1)
  • solver (37-37)
src/solver.cpp (2)
  • step (2538-3544)
  • step (2538-2538)
tests/test_framework.hpp (1)
  • cout (254-258)
include/mg_cuda_kernels.hpp (1)
src/mg_cuda_kernels.cpp (4)
  • CudaSmootherGraph (566-568)
  • CudaMGContext (698-700)
  • CudaMGContext (702-706)
  • CudaVCycleGraph (753-755)
src/mg_cuda_kernels.cpp (2)
include/mg_cuda_kernels.hpp (30)
  • stream (66-66)
  • stream (87-87)
  • stream (90-90)
  • stream (93-93)
  • stream (96-96)
  • stream (99-99)
  • stream (290-290)
  • stream (320-320)
  • stream (323-323)
  • stream (326-326)
  • CudaSmootherGraph (52-52)
  • CudaSmootherGraph (53-53)
  • graph_exec_ (69-69)
  • graph_exec_ (293-293)
  • config (59-62)
  • launch_chebyshev_3d (149-154)
  • launch_bc_3d (157-164)
  • launch_copy (167-170)
  • CudaMGContext (105-105)
  • CudaMGContext (106-106)
  • smoother_graphs_ (136-136)
  • level (126-126)
  • level (130-130)
  • level (133-133)
  • CudaVCycleGraph (273-273)
  • CudaVCycleGraph (274-274)
  • fingerprint_ (301-301)
  • launch_residual_3d (184-188)
  • launch_restrict_3d (192-197)
  • launch_prolongate_3d (201-206)
src/poisson_solver_multigrid.cpp (2)
  • smooth (942-1055)
  • smooth (942-942)
include/poisson_solver_multigrid.hpp (4)
include/gpu_utils.hpp (5)
  • USE_GPU_OFFLOAD (88-95)
  • USE_GPU_OFFLOAD (98-104)
  • USE_GPU_OFFLOAD (107-113)
  • USE_GPU_OFFLOAD (116-123)
  • USE_GPU_OFFLOAD (188-203)
include/poisson_solver.hpp (1)
  • residual_ (84-84)
include/poisson_solver_fft2d.hpp (1)
  • residual_ (49-49)
include/mg_cuda_kernels.hpp (3)
  • level (126-126)
  • level (130-130)
  • level (133-133)
src/poisson_solver_multigrid.cpp (3)
include/gpu_utils.hpp (4)
  • n (174-186)
  • n (174-174)
  • get_device_ptr (249-253)
  • get_device_ptr (249-249)
include/poisson_solver_multigrid.hpp (14)
  • level (88-88)
  • level (89-89)
  • level (160-160)
  • level (161-161)
  • level (162-162)
  • level (163-163)
  • level (166-166)
  • level (167-167)
  • level (168-168)
  • level (174-174)
  • level (179-179)
  • level (181-181)
  • level (202-202)
  • nu1 (153-153)
include/mg_cuda_kernels.hpp (3)
  • level (126-126)
  • level (130-130)
  • level (133-133)
🪛 markdownlint-cli2 (0.18.1)
docs/profiling_results_128cubed.md

80-80: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Agent
  • GitHub Check: gpu-tests
  • GitHub Check: build-and-test (ubuntu-latest, Release)

Comment on lines 1 to 8

NOTICE: Existing SQLite export found: /storage/home/hcoda1/6/sbryngelson3/scratch/nncfd/profiles/taylor_green_fft.sqlite
It is assumed file was previously exported from: /storage/home/hcoda1/6/sbryngelson3/scratch/nncfd/profiles/taylor_green_fft.nsys-rep
Consider using --force-export=true if needed.

Processing [/storage/home/hcoda1/6/sbryngelson3/scratch/nncfd/profiles/taylor_green_fft.sqlite] with [/storage/pace-apps/manual/packages/nvhpc/25.5/Linux_x86_64/25.5/profilers/Nsight_Systems/host-linux-x64/reports/nvtx_sum.py]...

** NVTX Range Summary (nvtx_sum):

⚠️ Potential issue | 🟠 Major

Don’t commit generated profiling outputs containing user/host paths (PII + non-reproducible).
This file embeds absolute paths (including a username) and looks like a generated report; it should be removed from version control and regenerated locally, or sanitized into a stable, redacted summary doc.

🤖 Prompt for AI Agents
In @profiles/taylor_green_fft_stats.txt around lines 1 - 8, This generated
profiling output contains absolute user/host paths (e.g.
"/storage/home/hcoda1/6/sbryngelson3/scratch/nncfd/profiles/taylor_green_fft.sqlite"
and the nvtx_sum report) and must be removed from the repo; delete
profiles/taylor_green_fft_stats.txt from the commit, add a rule to ignore such
generated profiling outputs (e.g. add an entry to .gitignore for /profiles/*.txt
or your profiler output pattern), and replace it with a sanitized, stable
summary document (redacting usernames/host paths) or instructions on how to
regenerate the full report locally.

The lid-driven cavity test (2D) was failing with V-cycle graph enabled.
The 2D path in the V-cycle graph code has not been thoroughly tested.
Disable V-cycle graph for 2D meshes (Nz == 1) until it can be validated.

3D cases continue to use V-cycle graph by default for maximum performance.

Co-Authored-By: Claude Opus 4.5 <[email protected]>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @src/poisson_solver_multigrid.cpp:
- Line 2037: Replace the assert(cfg.max_iter > 0 && "PoissonConfig.max_iter must
be positive"); with an exception throw to match the error handling used
elsewhere (see the check at line 1735): validate cfg.max_iter and if <= 0 throw
std::invalid_argument("PoissonConfig.max_iter must be positive") (or
std::runtime_error if that is the project convention), ensuring the message
mentions PoissonConfig.max_iter so callers get a clear runtime error instead of
an assert that disappears in release builds.
- Line 1735: Replace the assert(cfg.max_iter > 0 && "PoissonConfig.max_iter must
be positive") with a runtime check that throws an exception: add an if
(cfg.max_iter <= 0) throw std::invalid_argument("PoissonConfig.max_iter must be
positive: " + std::to_string(cfg.max_iter)); so configuration errors are
reported at runtime instead of being compiled out in release builds; reference
the cfg.max_iter field and PoissonConfig in the thrown message for clarity.
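
Spelled out, the change both prompts describe is a few lines; a minimal sketch, assuming only the cfg.max_iter field named above (the reduced PoissonConfig here is illustrative):

    #include <stdexcept>
    #include <string>

    // Reduced stand-in for the real PoissonConfig; only the field the
    // prompts reference is modeled here.
    struct PoissonConfig { int max_iter = 0; };

    // Runtime validation that survives release builds, unlike assert().
    void validate_max_iter(const PoissonConfig& cfg) {
        if (cfg.max_iter <= 0) {
            throw std::invalid_argument(
                "PoissonConfig.max_iter must be positive: " +
                std::to_string(cfg.max_iter));
        }
    }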
🧹 Nitpick comments (2)
src/poisson_solver_multigrid.cpp (2)

1692-1698: Consider using a sentinel value for uncomputed residuals.

Setting residual_ = 0.0 when convergence isn't checked could mislead callers into thinking the solution converged perfectly. Consider using std::numeric_limits<double>::quiet_NaN() or -1.0 to clearly indicate these values weren't computed.

♻️ Proposed alternative
-        // Set residual to 0 to indicate we didn't compute it
-        residual_ = 0.0;
-        residual_l2_ = 0.0;
-        r0_ = 0.0;
-        r0_l2_ = 0.0;
-        b_inf_ = 0.0;
-        b_l2_ = 0.0;
+        // Use NaN to indicate residuals weren't computed in fixed-cycle mode
+        constexpr double NOT_COMPUTED = std::numeric_limits<double>::quiet_NaN();
+        residual_ = NOT_COMPUTED;
+        residual_l2_ = NOT_COMPUTED;
+        r0_ = NOT_COMPUTED;
+        r0_l2_ = NOT_COMPUTED;
+        b_inf_ = NOT_COMPUTED;
+        b_l2_ = NOT_COMPUTED;

2387-2395: Clean up existing graph when disabling due to null pointers.

When null device pointers are detected, the code sets use_vcycle_graph_ = false but doesn't clean up the existing vcycle_graph_ if it exists. This could leave a stale graph object that's never used again but still consumes resources.

♻️ Proposed fix
     // Verify device pointers are valid
     for (size_t lvl = 0; lvl < configs.size(); ++lvl) {
         if (!configs[lvl].u || !configs[lvl].f || !configs[lvl].r || !configs[lvl].tmp) {
             std::cerr << "[MG] ERROR: Null device pointer at level " << lvl
                       << " - V-cycle Graph disabled\n";
+            // Clean up existing graph if any
+            if (vcycle_graph_) {
+                vcycle_graph_->destroy();
+                vcycle_graph_.reset();
+            }
             use_vcycle_graph_ = false;
             return;
         }
     }

Apply similar logic in initialize_cuda_graphs() at lines 2293-2299.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 55bba54 and 485551f.

📒 Files selected for processing (1)
  • src/poisson_solver_multigrid.cpp
🧰 Additional context used
📓 Path-based instructions (3)
**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

**/*.cpp: Use [[maybe_unused]] attribute for intentionally unused variables in assertions instead of suppressing warnings with compiler flags
Use RAII for resource management with smart pointers or RAII wrappers; avoid manual new/delete or malloc/free
Check return values and use exceptions for error conditions; provide informative error messages
Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible
Always call set_body_force() for driven flows; initialize velocity field before solving; set turbulence model before first solver step
For GPU offload, verify USE_GPU_OFFLOAD is defined and check omp_get_num_devices() > 0 at runtime

Files:

  • src/poisson_solver_multigrid.cpp
**/*.{cpp,hpp,h}

📄 CodeRabbit inference engine (.cursorrules)

**/*.{cpp,hpp,h}: Fix all compiler warnings before pushing; do not suppress warnings with flags
Use const for read-only references and mark methods const if they don't modify state in C++ code
Use OmpDeviceBuffer wrapper for GPU buffer management instead of manual memory management
Add explanatory comments for WHY code does something, not WHAT it does; document numerical algorithms and non-obvious optimizations

Files:

  • src/poisson_solver_multigrid.cpp
src/**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

src/**/*.cpp: Profile before optimizing and document complexity assumptions; use appropriate data structures
Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in Poisson solver

Files:

  • src/poisson_solver_multigrid.cpp
🧠 Learnings (7)
📓 Common learnings
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: For GPU-related changes, also run `./test_before_ci_gpu.sh` to validate GPU CI test suite locally including physics validation tests, turbulence model validation, and CPU/GPU consistency
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : For GPU offload, verify `USE_GPU_OFFLOAD` is defined and check `omp_get_num_devices() > 0` at runtime
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.{cpp,hpp,h} : Use `OmpDeviceBuffer` wrapper for GPU buffer management instead of manual memory management
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to src/**/*.cpp : Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in Poisson solver

Applied to files:

  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : For GPU offload, verify `USE_GPU_OFFLOAD` is defined and check `omp_get_num_devices() > 0` at runtime

Applied to files:

  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.{cpp,hpp,h} : Use `OmpDeviceBuffer` wrapper for GPU buffer management instead of manual memory management

Applied to files:

  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible

Applied to files:

  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : Use `[[maybe_unused]]` attribute for intentionally unused variables in assertions instead of suppressing warnings with compiler flags

Applied to files:

  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : For iterative solvers in tests, check residual convergence rather than exact iteration count

Applied to files:

  • src/poisson_solver_multigrid.cpp
🧬 Code graph analysis (1)
src/poisson_solver_multigrid.cpp (2)
include/poisson_solver_multigrid.hpp (15)
  • level (88-88)
  • level (89-89)
  • level (160-160)
  • level (161-161)
  • level (162-162)
  • level (163-163)
  • level (166-166)
  • level (167-167)
  • level (168-168)
  • level (174-174)
  • level (179-179)
  • level (181-181)
  • level (202-202)
  • nu1 (153-153)
  • residual_ (64-64)
include/mg_cuda_kernels.hpp (3)
  • level (126-126)
  • level (130-130)
  • level (133-133)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: gpu-tests
  • GitHub Check: build-and-test (ubuntu-latest, Release)
🔇 Additional comments (8)
src/poisson_solver_multigrid.cpp (8)

1556-1572: LGTM! Nullspace handling is correct.

The logic correctly identifies that Poisson problems with pure Neumann or Periodic boundary conditions have a constant nullspace, which needs to be removed by subtracting the mean. Any Dirichlet boundary condition pins the solution and eliminates the nullspace.
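
For readers skimming the thread, a minimal sketch of the idea behind has_nullspace()/fix_nullspace(); the BC enum and container types here are illustrative, not the repo's actual API:

    #include <algorithm>
    #include <numeric>
    #include <vector>

    enum class BC { Dirichlet, Neumann, Periodic };  // illustrative stand-in

    // A pure Neumann/Periodic Poisson problem is singular: adding any
    // constant to the solution leaves it a solution. One Dirichlet face
    // pins it and removes the nullspace.
    bool has_nullspace(const std::vector<BC>& face_bcs) {
        return std::none_of(face_bcs.begin(), face_bcs.end(),
                            [](BC b) { return b == BC::Dirichlet; });
    }

    // Project out the constant mode by subtracting the mean.
    void fix_nullspace(std::vector<double>& u) {
        const double mean =
            std::accumulate(u.begin(), u.end(), 0.0) /
            static_cast<double>(u.size());
        for (double& v : u) v -= mean;
    }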


1792-1804: Verify denominator handling prevents division by zero.

The convergence checks add small epsilon values (1e-30) to denominators to prevent division by zero. However, verify this epsilon is appropriate for the expected range of norm values in your application.

Analysis:

  • Line 1795: r_norm <= cfg.tol_rhs * (b_norm + 1e-30)
  • Line 1798: r_norm <= cfg.tol_rel * (r0_norm + 1e-30)

If b_norm or r0_norm are genuinely zero (e.g., zero RHS problem), adding 1e-30 still allows convergence. However, 1e-30 is extremely small and might not prevent underflow/precision issues on all platforms.

Recommendation: Consider using std::numeric_limits<double>::epsilon() or a problem-scale-dependent epsilon instead of a hardcoded 1e-30.
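
A sketch of that recommendation, assuming the tol_rhs field quoted above; the floor scales with the RHS norm instead of relying on a fixed 1e-30:

    #include <algorithm>
    #include <limits>

    // Scale-aware RHS-relative check: denominators below machine precision
    // (relative to the problem scale) are treated as effectively zero.
    bool rhs_relative_converged(double r_norm, double b_norm, double tol_rhs) {
        const double eps = std::numeric_limits<double>::epsilon();
        const double floor_val = eps * std::max(1.0, b_norm);
        return r_norm <= tol_rhs * std::max(b_norm, floor_val);
    }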


1940-1942: Good synchronization practice before OpenMP reductions.

The explicit cudaDeviceSynchronize() call before OpenMP target reductions is critical when mixing CUDA Graphs with OpenMP target. CUDA Graphs may use a different stream than OpenMP target regions, so this synchronization prevents race conditions.
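
Schematically, the ordering being praised looks like the sketch below (a hypothetical max-reduction over a device-resident array; the actual code replays a CUDA Graph rather than a raw kernel):

    #include <cuda_runtime.h>
    #include <omp.h>

    // Make the graph's writes visible before an OpenMP target reduction
    // reads the same device buffer; the two may run on different streams.
    double device_max_abs(const double* r_dev, long n) {
        cudaDeviceSynchronize();  // barrier across all streams on the device

        double max_res = 0.0;
        #pragma omp target teams distribute parallel for \
            reduction(max: max_res) is_device_ptr(r_dev)
        for (long i = 0; i < n; ++i) {
            const double a = (r_dev[i] >= 0.0) ? r_dev[i] : -r_dev[i];
            if (a > max_res) max_res = a;
        }
        return max_res;
    }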


2427-2440: Excellent defensive programming for stream validation.

The runtime validation of the OpenMP CUDA stream with graceful fallback is well-implemented. The single warning flag prevents log spam while still alerting users to a potential runtime issue. This is a good pattern for production code where crashing isn't acceptable but users should be informed of degraded performance.


1114-1187: LGTM! Fused residual and norm computation is an excellent optimization.

Computing the residual and both L∞ and L2 norms in a single pass over memory is much more efficient than separate compute_residual() + compute_max_residual() calls. The OpenMP reductions for max and sum are correctly structured for both 2D and 3D cases.
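
Reduced to a sketch (a 1D three-point stencil standing in for the real 2D/3D Laplacian), the fused pattern writes the residual and accumulates both norms with combined OpenMP reductions in one sweep:

    #include <cmath>
    #include <omp.h>

    // One pass over device memory: r = f - A*u, L-inf norm, and squared
    // L2 norm, all in the same sweep.
    void residual_and_norms(const double* u, const double* f, double* r,
                            long n, double inv_h2,
                            double& r_inf, double& r_l2) {
        double max_res = 0.0, sum_sq = 0.0;
        #pragma omp target teams distribute parallel for \
            reduction(max: max_res) reduction(+: sum_sq) is_device_ptr(u, f, r)
        for (long i = 1; i < n - 1; ++i) {
            const double lap = (u[i - 1] - 2.0 * u[i] + u[i + 1]) * inv_h2;
            const double res = f[i] - lap;
            r[i] = res;
            const double a = (res >= 0.0) ? res : -res;
            if (a > max_res) max_res = a;
            sum_sq += res * res;
        }
        r_inf = max_res;
        r_l2 = std::sqrt(sum_sq);
    }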


196-200: Map clause changes are correct per OpenMP 5.0+ specification.

The changes from map(present: ...) to map(present, alloc: ...) throughout the file are correct. The present, alloc mapping mode indicates that:

  • The pointer must already be present on the device (mapped via target enter data or omp_target_alloc)
  • Only reference the device copy, don't transfer data

This is the proper syntax for accessing device-resident data without triggering unnecessary data transfers.

Based on coding guidelines: Use OmpDeviceBuffer wrapper for GPU buffer management instead of manual memory management.

Also applies to: 432-432, 451-451, 475-475, 731-735, 831-839, 973-973, 991-991, 1012-1012, 1033-1033, 1084-1084, 1098-1098, 1144-1144, 1164-1164, 1217-1217, 1237-1237, 1310-1310, 1343-1343, 1459-1459, 1510-1510, 1529-1529, 1593-1593, 1606-1606, 1617-1617, 1631-1631, 1880-1880, 2020-2020, 2060-2060, 2071-2071, 2145-2145, 2197-2197
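
To make the mapping semantics concrete, a standalone sketch; note that a later commit in this PR reverts to plain map(present: ...), since present already asserts the data is mapped and alloc adds nothing:

    #include <omp.h>
    #include <vector>

    int main() {
        const int n = 1024;
        std::vector<double> u(n, 0.0);
        double* u_ptr = u.data();

        // Map once and keep the buffer resident on the device.
        #pragma omp target enter data map(to: u_ptr[0:n])

        // 'present' asserts residency, so no host<->device transfer
        // happens here.
        #pragma omp target teams distribute parallel for map(present: u_ptr[0:n])
        for (int i = 0; i < n; ++i) {
            u_ptr[i] += 1.0;
        }

        #pragma omp target exit data map(from: u_ptr[0:n])
        return 0;
    }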


98-100: Verify nullness before calling destroy() on vcycle_graph_.

The code calls vcycle_graph_->destroy() without checking if vcycle_graph_ is nullptr. While vcycle_graph_ is initialized to nullptr in the header, it's only created under specific conditions (3D, enabled via env var). If set_bc() is called when vcycle_graph_ is nullptr, this will crash.

🐛 Proposed fix
 #ifdef USE_GPU_OFFLOAD
     // Re-initialize CUDA Graphs with new BCs
     if (use_cuda_graphs_) {
         initialize_cuda_graphs();
     }
     // Invalidate V-cycle graph so it gets recaptured with new BCs
-    if (vcycle_graph_) {
+    if (vcycle_graph_ && vcycle_graph_->is_valid()) {
         vcycle_graph_->destroy();
     }
 #endif

Apply the same fix to both set_bc() overloads (lines 98-100 and 120-122).

Also applies to: 120-122

⛔ Skipped due to learnings
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to src/**/*.cpp : Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in Poisson solver

2059-2081: This review comment is based on a mischaracterization of the code. The pragmas use OpenMP's built-in max reduction operator (reduction(max: b_inf_local)), which is fully portable across all OpenMP versions supporting target offloading. The std::max and std::abs calls appear in the loop body to compute values, not in the reduction clause specification itself. This is standard C++ and poses no portability concerns.

Likely an incorrect or invalid review comment.

- Replace map(present, alloc:) with map(present:) in MG solver
  The 'alloc' modifier is redundant/incorrect when 'present' is used
  since 'present' asserts data is already mapped to device

- Move benchmark executables inside if(BUILD_TESTS) guard
  bench_mg_cuda_graphs, bench_256, test_mg_physics_match, etc.
  were previously built unconditionally

- Wrap test_vcycle_graph_stress with if(USE_GPU_OFFLOAD)
  This GPU-specific test should only build when GPU is enabled

Performance verified: no regression (52.25ms -> 52.31ms, within noise)

Co-Authored-By: Claude Opus 4.5 <[email protected]>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/poisson_solver_multigrid.cpp (2)

1393-1421: vcycle(..., degree): semantics are reasonable, but guard against invalid degree and avoid the (int)size_c cast.

  • If degree <= 0, the Chebyshev path (and Jacobi-iterations-as-degree path) becomes a silent no-op; consider asserting degree > 0 (or clamping to 1).
  • (int)size_c risks overflow; safer to iterate with size_t.
Proposed fix
 void MultigridPoissonSolver::vcycle(int level, int nu1, int nu2, int degree) {
     NVTX_SCOPE_POISSON("mg:vcycle");
+    assert(degree > 0 && "chebyshev_degree must be positive");
@@
 #ifdef USE_GPU_OFFLOAD
         assert(gpu_ready_ && "GPU must be initialized");
         const size_t size_c = level_sizes_[level + 1];
         double* u_coarse = u_ptrs_[level + 1];
@@
         #pragma omp target teams distribute parallel for \
             map(present: u_coarse[0:size_c])
-        for (int idx = 0; idx < (int)size_c; ++idx) {
+        for (size_t idx = 0; idx < size_c; ++idx) {
             u_coarse[idx] = 0.0;
         }
 #else

Also applies to: 1428-1433, 1449-1473, 1476-1487


1888-2151: Same safety-cap bug exists in solve_device() convergence mode.

Apply the same cfg.tol_rhs > 0.0 guard there to avoid “never converge unless residual==0” behavior when only tol_rel/tol_abs are set.

Proposed fix
-            if (converged && cfg.use_l2_norm && cfg.linf_safety_factor > 0.0) {
+            if (converged && cfg.use_l2_norm && cfg.linf_safety_factor > 0.0 && cfg.tol_rhs > 0.0) {
                 double linf_ratio = residual_ / (b_inf_ + 1e-30);
                 double linf_cap = cfg.tol_rhs * cfg.linf_safety_factor;
                 if (linf_ratio > linf_cap) {
                     converged = false;  // L∞ still too high, keep iterating
                 }
             }

Also applies to: 2122-2130

🤖 Fix all issues with AI agents
In @CMakeLists.txt:
- Around line 518-519: The new executable target bench_mg_cuda_graphs must be
guarded by the GPU offload flag; move the add_executable and
target_link_libraries calls for bench_mg_cuda_graphs into the existing
conditional block that is controlled by USE_GPU_OFFLOAD (the same block that
contains test_vcycle_graph_stress) and remove the duplicate unguarded
add_executable/target_link_libraries entries so the CUDA-graph-dependent
benchmark is only built when USE_GPU_OFFLOAD is enabled.

In @src/poisson_solver_multigrid.cpp:
- Around line 672-695: The CUDA-graphed smoother currently captures/replays a
hard-coded Chebyshev degree of 4, so cuda_ctx_->smooth (which replays
smooth_chebyshev(level, degree)) can silently ignore cfg.chebyshev_degree;
update the capture logic to use cfg.chebyshev_degree when building the CUDA
graph and detect changes to cfg.chebyshev_degree so you either recapture the
graph or disable CUDA graphs if the degree differs, or alternatively assert/emit
a clear runtime error when MG_USE_CUDA_GRAPHS is enabled but
cfg.chebyshev_degree != 4; locate the graph-capture and replay code paths (calls
to cuda_ctx_->smooth and smooth_chebyshev(level, degree)) and ensure the degree
parameter is sourced from cfg and tracked for changes before replaying the
graph.
- Around line 7-31: Update the GPU Synchronization Note to reflect that the
V-cycle CUDA Graph path is now the default in GPU builds instead of a mere
“future option”: change the phrasing around CUDA Graphs from prospective to
present-tense (e.g., “CUDA Graphs: V-cycle captured and replayed to reduce
stream syncs — currently used in GPU builds”), keep the explanation why OpenMP
nowait is avoided (reference MG kernels and nowait), and mark the other
approaches (OpenMP depend clauses, custom streams) as alternative or future
options; ensure the note points readers to the V-cycle CUDA Graph implementation
used by the multigrid pipeline so they can find the related code paths.
- Around line 2413-2444: vcycle_graphed() calls the NVHPC-only function
ompx_get_cuda_stream() unguarded; wrap the call and any use of
ompx_get_cuda_stream/omp_stream in a compile-time guard (#ifdef __NVCOMPILER) so
non-NVHPC builds never reference that symbol, and in the #else path immediately
fall back to the non-graphed V-cycle by calling vcycle(0, vcycle_graph_nu1_,
vcycle_graph_nu2_, vcycle_graph_degree_); ensure the existing runtime
null-stream check remains inside the NVCOMPILER branch so only NVHPC builds
perform stream retrieval and graph execution via
vcycle_graph_->execute(omp_stream).
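
Assembled from the prompt above, a sketch of the guarded fallback; the member names follow this thread, and ompx_get_cuda_stream() is written as the thread references it (the exact NVHPC signature is an assumption):

    // Fragment of the method body suggested above, not a complete class.
    void MultigridPoissonSolver::vcycle_graphed() {
    #ifdef __NVCOMPILER
        // NVHPC-only runtime query for the stream OpenMP target regions use.
        cudaStream_t omp_stream =
            static_cast<cudaStream_t>(ompx_get_cuda_stream());
        if (omp_stream == nullptr) {
            // Runtime fallback: no usable stream, run the plain V-cycle.
            vcycle(0, vcycle_graph_nu1_, vcycle_graph_nu2_, vcycle_graph_degree_);
            return;
        }
        vcycle_graph_->execute(omp_stream);
    #else
        // Non-NVHPC builds never reference the NVHPC-specific symbol.
        vcycle(0, vcycle_graph_nu1_, vcycle_graph_nu2_, vcycle_graph_degree_);
    #endif
    }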
🧹 Nitpick comments (5)
CMakeLists.txt (1)

520-529: Consider test registration and naming consistency.

Two minor organizational observations:

  1. test_mg_physics_match (Line 522) uses the test_ prefix but is not registered with add_test(), so it won't be executed by CTest. If this is intentional (i.e., it's a benchmark), consider renaming it to bench_mg_physics_match for consistency.

  2. The bench_* executables are placed under BUILD_TESTS but aren't registered as tests. If these are meant to be manual benchmarks rather than automated tests, this is acceptable, but the organization could be clearer (e.g., a comment explaining they're benchmarks, not automated tests).

src/poisson_solver_multigrid.cpp (4)

60-77: Defaulting V-cycle graphs ON + unconditional std::cout may be noisy in production runs.

Suggest gating the banner behind a verbosity flag / logger, or printing only once when the graph is actually captured (vs. merely “enabled by default”).


92-101: Graph invalidation on BC changes: ensure state is fully reset (not just destroyed).

vcycle_graph_->destroy() without resetting the owning pointer can be fine if is_valid() reliably reflects the destroyed state, but it’s easy for drift between “destroyed” vs. “needs_recapture” logic to cause surprises. Consider vcycle_graph_.reset() (or an explicit “invalidate” API) to make the lifecycle unambiguous.

Also applies to: 114-123


2251-2308: CUDA graph init: duplicated BC conversion + fingerprinting looks fine; consider consolidating to_cuda_bc and documenting what forces recapture.

Small maintainability win: extract to_cuda_bc into a shared private helper to prevent drift between smoother graphs and V-cycle graphs.

Also applies to: 2310-2411
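
The consolidation can be a single private helper; a sketch with illustrative enum names (the repo's actual BC types live in the solver and mg_cuda_kernels headers):

    // Illustrative stand-ins for the solver-side and CUDA-kernel-side BC types.
    enum class SolverBC { Dirichlet, Neumann, Periodic };
    enum class CudaBC : int { Dirichlet = 0, Neumann = 1, Periodic = 2 };

    // One shared conversion point for smoother-graph and V-cycle-graph
    // setup, so the two capture paths cannot drift apart.
    static CudaBC to_cuda_bc(SolverBC bc) {
        switch (bc) {
            case SolverBC::Dirichlet: return CudaBC::Dirichlet;
            case SolverBC::Neumann:   return CudaBC::Neumann;
            case SolverBC::Periodic:  return CudaBC::Periodic;
        }
        return CudaBC::Dirichlet;  // unreachable; keeps -Wreturn-type quiet
    }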


1888-2026: Adaptive fixed-cycle GPU path: good attention to stream synchronization, but consider minimizing D→H transfers.

You currently target update from(f_ptrs_[0]...) once to compute b_l2_ on CPU. If this becomes a perf issue, you already have the on-device reduction logic used later in convergence-based mode—could reuse it here too. Based on coding guidelines, minimizing CPU↔GPU transfers is preferred.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 485551f and ae03607.

📒 Files selected for processing (2)
  • CMakeLists.txt
  • src/poisson_solver_multigrid.cpp
🧰 Additional context used
📓 Path-based instructions (3)
**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

**/*.cpp: Use [[maybe_unused]] attribute for intentionally unused variables in assertions instead of suppressing warnings with compiler flags
Use RAII for resource management with smart pointers or RAII wrappers; avoid manual new/delete or malloc/free
Check return values and use exceptions for error conditions; provide informative error messages
Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible
Always call set_body_force() for driven flows; initialize velocity field before solving; set turbulence model before first solver step
For GPU offload, verify USE_GPU_OFFLOAD is defined and check omp_get_num_devices() > 0 at runtime

Files:

  • src/poisson_solver_multigrid.cpp
**/*.{cpp,hpp,h}

📄 CodeRabbit inference engine (.cursorrules)

**/*.{cpp,hpp,h}: Fix all compiler warnings before pushing; do not suppress warnings with flags
Use const for read-only references and mark methods const if they don't modify state in C++ code
Use OmpDeviceBuffer wrapper for GPU buffer management instead of manual memory management
Add explanatory comments for WHY code does something, not WHAT it does; document numerical algorithms and non-obvious optimizations

Files:

  • src/poisson_solver_multigrid.cpp
src/**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

src/**/*.cpp: Profile before optimizing and document complexity assumptions; use appropriate data structures
Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in Poisson solver

Files:

  • src/poisson_solver_multigrid.cpp
🧠 Learnings (17)
📓 Common learnings
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: For GPU-related changes, also run `./test_before_ci_gpu.sh` to validate GPU CI test suite locally including physics validation tests, turbulence model validation, and CPU/GPU consistency
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : For GPU offload, verify `USE_GPU_OFFLOAD` is defined and check `omp_get_num_devices() > 0` at runtime
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.{cpp,hpp,h} : Use `OmpDeviceBuffer` wrapper for GPU buffer management instead of manual memory management
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to src/**/*.cpp : Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in Poisson solver

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : For GPU offload, verify `USE_GPU_OFFLOAD` is defined and check `omp_get_num_devices() > 0` at runtime

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.{cpp,hpp,h} : Use `OmpDeviceBuffer` wrapper for GPU buffer management instead of manual memory management

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : Use `[[maybe_unused]]` attribute for intentionally unused variables in assertions instead of suppressing warnings with compiler flags

Applied to files:

  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : Check return values and use exceptions for error conditions; provide informative error messages

Applied to files:

  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : For iterative solvers in tests, check residual convergence rather than exact iteration count

Applied to files:

  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: For GPU-related changes, also run `./test_before_ci_gpu.sh` to validate GPU CI test suite locally including physics validation tests, turbulence model validation, and CPU/GPU consistency

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Do not add platform-specific tolerances with `#ifdef __APPLE__` or similar platform checks in tests; fix the root cause instead

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Tests must pass in both Debug and Release builds; if different behavior is needed, investigate the root cause rather than accept build-type dependent behavior

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Avoid overly strict floating-point comparisons in tests; use appropriate tolerances based on algorithm rather than exact equality checks (e.g., avoid `==` for doubles)

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Add tests for new features; update tests when changing behavior; do not commit broken tests

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Do not relax tolerances for different compilers; tests should be numerically robust across all compilers

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : Always call `set_body_force()` for driven flows; initialize velocity field before solving; set turbulence model before first solver step

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to app/**/*.cpp : Use `solve_steady_with_snapshots()` for automatic VTK output; specify `num_snapshots` in config; files are numbered sequentially plus final

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/test_physics_validation*.cpp : Tests must verify the Navier-Stokes solver using comprehensive physics validation including Poiseuille Flow, Divergence-Free Constraint, Momentum Balance, Channel Symmetry, Cross-Model Consistency, and Sanity Checks

Applied to files:

  • CMakeLists.txt
🧬 Code graph analysis (1)
src/poisson_solver_multigrid.cpp (4)
include/gpu_utils.hpp (4)
  • n (174-186)
  • n (174-174)
  • get_device_ptr (249-253)
  • get_device_ptr (249-249)
include/poisson_solver_multigrid.hpp (15)
  • level (88-88)
  • level (89-89)
  • level (160-160)
  • level (161-161)
  • level (162-162)
  • level (163-163)
  • level (166-166)
  • level (167-167)
  • level (168-168)
  • level (174-174)
  • level (179-179)
  • level (181-181)
  • level (202-202)
  • nu1 (153-153)
  • residual_ (64-64)
include/mg_cuda_kernels.hpp (3)
  • level (126-126)
  • level (130-130)
  • level (133-133)
src/poisson_solver.cpp (2)
  • apply_bc (30-172)
  • apply_bc (30-30)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: gpu-tests
  • GitHub Check: build-and-test (ubuntu-latest, Release)
🔇 Additional comments (6)
CMakeLists.txt (2)

251-254: LGTM!

The expansion of GPU sources to include mg_cuda_kernels.cpp is correct and properly guarded by USE_GPU_OFFLOAD.


532-539: LGTM! Exemplary GPU-specific test configuration.

The test_vcycle_graph_stress executable demonstrates the correct pattern for GPU-specific tests:

  • Properly guarded by USE_GPU_OFFLOAD
  • Registered with CTest
  • Environment variable configured (MG_USE_VCYCLE_GRAPH=1)
  • Labeled as "gpu" for test categorization

This is the pattern that should be followed for other GPU-specific executables.

src/poisson_solver_multigrid.cpp (4)

1556-1572: Nullspace detection/fix looks correct; consider whether “any Dirichlet face” matches all supported mixed-BC cases.

If you have edge/corner-only Dirichlet pinning (or other exotic setups), has_nullspace() may be too coarse. If not supported, LGTM as-is. Based on learnings, periodic BC handling is a common source of subtle Poisson issues—this helper is a good step.


1642-1852: Test ask: please run the repo’s GPU CI preflight script given the scale of GPU/offload changes.

Based on learnings, this PR touches solver correctness + GPU execution order; it’s worth running the full local GPU suite (physics validation + CPU/GPU consistency).

Also applies to: 1854-2151


1676-1725: Critical: L∞ safety-cap can prevent convergence when tol_rhs==0 (even if tol_rel/tol_abs converged).

linf_cap = cfg.tol_rhs * cfg.linf_safety_factor becomes 0 when tol_rhs is disabled; then any nonzero residual fails the cap and forces iterations to max. Guard the cap on tol_rhs > 0.

Proposed fix
-            if (converged && cfg.use_l2_norm && cfg.linf_safety_factor > 0.0) {
+            if (converged && cfg.use_l2_norm && cfg.linf_safety_factor > 0.0 && cfg.tol_rhs > 0.0) {
                 double linf_ratio = residual_ / (b_inf_ + 1e-30);
                 double linf_cap = cfg.tol_rhs * cfg.linf_safety_factor;
                 if (linf_ratio > linf_cap) {
                     converged = false;  // L∞ still too high, keep iterating
                 }
             }

Also applies to: 1727-1852, 1806-1814

⛔ Skipped due to learnings
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : For iterative solvers in tests, check residual convergence rather than exact iteration count

2153-2210: GPU device availability check is properly implemented.

The gpu::verify_device_available() function in src/gpu_init.cpp correctly verifies omp_get_num_devices() > 0 at runtime, throwing a descriptive error if no GPU devices are found. This satisfies the GPU offload robustness requirement. The call at line 2158 ensures the check occurs before any device allocation.
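
The check itself is small; a sketch consistent with the description above (the function name comes from the comment, the body is assumed):

    #include <omp.h>
    #include <stdexcept>

    namespace gpu {
    // Fail fast with a descriptive error before any device allocation
    // is attempted.
    inline void verify_device_available() {
        if (omp_get_num_devices() <= 0) {
            throw std::runtime_error(
                "GPU offload requested but no OpenMP target devices were found");
        }
    }
    }  // namespace gpu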

Removed:
- Legacy smoother-only CUDA graphs (use_cuda_graphs_, MG_USE_CUDA_GRAPHS)
  Superseded by full V-cycle graphs which are more efficient
- Unused smooth() function (Red-Black Gauss-Seidel, ~115 lines)
  Chebyshev and Jacobi smoothers are used instead
- MG_USE_VCYCLE_GRAPH environment variable override
  Config option poisson_use_vcycle_graph is sufficient
- Redundant bench_256.cpp benchmark
- Noisy "[MG] Full V-cycle CUDA Graph enabled" constructor print

Cleanup reduces maintenance burden and improves signal-to-noise ratio.
All tests pass, no performance regression.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
@comp-physics comp-physics deleted a comment from qodo-code-review bot Jan 9, 2026
@comp-physics comp-physics deleted a comment from coderabbitai bot Jan 9, 2026

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
include/poisson_solver_multigrid.hpp (1)

48-60: solve_device doc: clarify array sizing for 3D and remove “not implemented” ambiguity.

The parameter-size notes currently read like 2D (Nx+2 * Ny+2) and can be misread due to missing parentheses. Suggest documenting as (Nx+2)*(Ny+2)*(Nz+2) (with Nz=1 in 2D), and either remove the “not implemented” note or point to a tracked issue.
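
One possible wording for the header, following that suggestion (parameter names are illustrative):

    /// Solve directly on device-resident arrays.
    /// @param rhs  Right-hand side, size (Nx+2)*(Ny+2)*(Nz+2) including
    ///             ghost layers; pass Nz = 1 for 2D problems.
    /// @param p    Solution array, same size and layout as rhs.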

src/poisson_solver_multigrid.cpp (1)

950-1023: Fused residual+norms: add NaN/Inf propagation like compute_max_residual() to avoid false convergence.

If r becomes NaN, sum_sq and hence r_l2 become NaN, and the convergence logic may behave unpredictably. compute_max_residual() already guards against this explicitly.

Target-friendly NaN guard (no std::isfinite)
                 double r = f_ptr[idx] - laplacian;
                 r_ptr[idx] = r;
-                // Compute norms
-                double abs_r = (r >= 0.0) ? r : -r;
-                if (abs_r > max_res) max_res = abs_r;
-                sum_sq += r * r;
+                // Compute norms (+ NaN detection)
+                if (!(r == r)) { // NaN
+                    max_res = 1e308;
+                    sum_sq = 1e308;
+                } else {
+                    double abs_r = (r >= 0.0) ? r : -r;
+                    if (abs_r > max_res) max_res = abs_r;
+                    sum_sq += r * r;
+                }
CMakeLists.txt (1)

251-259: Make CUDA Toolkit an explicit requirement for USE_GPU_OFFLOAD builds.

src/mg_cuda_kernels.cpp includes <cuda_runtime.h> and uses CUDA APIs (cudaError_t, cudaGetErrorString), but CMake only calls find_package(CUDAToolkit) later for optional FFT/HYPRE features. A pure GPU offload build (USE_GPU_OFFLOAD=ON without HYPRE) will fail to compile on systems without CUDA Toolkit because the required headers won't be found.

CMake fix
 if(USE_GPU_OFFLOAD)
+    find_package(CUDAToolkit REQUIRED)
     target_sources(nn_cfd_core PRIVATE
         src/gpu_init.cpp
         src/mg_cuda_kernels.cpp
     )

Also run ./test_before_ci_gpu.sh locally to catch CPU/GPU consistency regressions early, per repository guidelines.

🤖 Fix all issues with AI agents
In @src/poisson_solver_multigrid.cpp:
- Around line 1725-1863: Adaptive fixed-cycle branch computes only b_l2_ and
uses L2-relative checks regardless of cfg.use_l2_norm, leaving b_inf_ unset;
update the adaptive path in solve_device so it honors cfg.use_l2_norm by
computing the appropriate RHS norm (b_inf_ for L∞ or b_l2_ for L2) before
starting checks and ensure the residual check uses residual_l2_/b_l2_ or
residual_inf_/b_inf_ depending on cfg.use_l2_norm; specifically, after the host
sync of f_ptrs_[0] (the block that computes b_l2_) add computation of b_inf_
when cfg.use_l2_norm is false (or switch the check to use infinity norms), and
ensure any later calls to compute_residual_and_norms(0, residual_, residual_l2_)
are replaced or complemented by a call that produces the matching residual_inf_
(or use an overload that fills both) so b_inf_ and residual_inf_ are valid when
needed.
- Around line 1567-1652: The L∞ safety cap uses cfg.tol_rhs even when
tol_rhs==0, which can undo convergence triggered by tol_abs or tol_rel; fix by
applying the linf safety check only when RHS-relative tolerance is active
(cfg.tol_rhs > 0). Concretely, change the condition around the cap (the block
that checks cfg.use_l2_norm && cfg.linf_safety_factor > 0.0) to also require
cfg.tol_rhs > 0.0 before computing linf_cap = cfg.tol_rhs *
cfg.linf_safety_factor and comparing linf_ratio = residual_ / (b_inf_ + 1e-30),
and apply the same guard in the duplicate block around lines ~1928-1968.
- Around line 36-43: The unconditional inclusion of <omp.h> must be wrapped in
the same USE_GPU_OFFLOAD guard as gpu_utils.hpp, and any OpenMP-specific calls
should be compiled only when USE_GPU_OFFLOAD is defined; also protect the
NVHPC-specific ompx_get_cuda_stream() usage with an explicit NVHPC vendor guard
(e.g., check for the NVHPC compiler macro) so Clang/GCC offload toolchains that
lack ompx_get_cuda_stream() don't fail to build. Locate the top-level include of
<omp.h> and move it inside the #ifdef USE_GPU_OFFLOAD block, and wrap the call
to ompx_get_cuda_stream() with both USE_GPU_OFFLOAD and an NVHPC-specific macro
check (only call ompx_get_cuda_stream() under those guards), providing a
portable fallback or conditional-compilation path for other compilers and
toolchains (a sketch of the guard layering follows this list).
- Around line 2088-2222: The call to ompx_get_cuda_stream in vcycle_graphed()
must be guarded for NVHPC-only builds; wrap the runtime stream retrieval and
null-check in an #ifdef __NVCOMPILER block, and in the #else path immediately
fall back to the non-graphed path by calling vcycle(0, vcycle_graph_nu1_,
vcycle_graph_nu2_, vcycle_graph_degree_) and returning, so non-NVHPC builds
never reference ompx_get_cuda_stream (a sketch follows this list). Also wrap
the std::cout capture-status line in initialize_vcycle_graph() (the "[MG] Full
V-cycle CUDA Graph ..." message) behind the existing verbosity/logging flag
(e.g., verbose_ or a logger) so test output stays quiet. Apply these changes to
the vcycle_graphed(), initialize_vcycle_graph(), ompx_get_cuda_stream, and
vcycle(...) usages.
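For the last two items, a minimal structural sketch of the vendor guard plus fallback (member names follow the fix text; query_omp_cuda_stream() is a hypothetical wrapper standing in for the NVHPC-specific ompx_get_cuda_stream(), whose exact signature is not shown in this PR):

void PoissonSolverMultigrid::vcycle_graphed() {
#if defined(USE_GPU_OFFLOAD) && defined(__NVCOMPILER)
    cudaStream_t stream = query_omp_cuda_stream();  // hypothetical wrapper
    if (stream != nullptr && vcycle_graph_ && vcycle_graph_->is_valid()) {
        vcycle_graph_->execute(stream);
        return;
    }
#endif
    // Non-NVHPC builds (and null-stream cases) never reference the graph path.
    vcycle(0, vcycle_graph_nu1_, vcycle_graph_nu2_, vcycle_graph_degree_);
}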
🧹 Nitpick comments (5)
CMakeLists.txt (1)

341-344: Consider gating profiling executables behind an option (or exclude from default builds).

profile_comprehensive is always added. If this pulls in heavy deps (NVTX/CUDA symbols) or slows “normal” builds, consider an option like BUILD_PROFILES or NN_CFD_BUILD_TOOLS.

include/poisson_solver_multigrid.hpp (2)

62-79: Public getter docs should specify validity/semantics (“last solve”, fixed-cycle behavior).

Per header guidelines, please add brief API docs covering:

  • when these values are set (after solve() / solve_device()),
  • what happens in fixed-cycle mode (currently set to 0.0 in some paths),
  • which norm is returned (∞ vs L2).
    As per coding guidelines, public functions should document return value and side effects.

Also applies to: 134-140
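
For instance, a doc-comment along these lines (getter names illustrative, not the shipped header) would cover all three points:

/// L∞ residual from the most recent solve()/solve_device() call.
/// Side effects: none. In fixed-cycle mode the residual is not computed,
/// so this returns 0.0 meaning "not computed", not "converged".
double residual() const;

/// L2 residual from the most recent solve; the same fixed-cycle caveat applies.
double residual_l2() const;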


143-152: Default-on CUDA graph: document constraints (e.g., 3D-only) and fallback behavior.

The .cpp notes “V-cycle graph is 3D only (2D path not fully tested)”. Mirror that in the header-level member/docs so users don’t assume it’s always active/valid.

src/poisson_solver_multigrid.cpp (2)

76-81: Graph invalidation on BC changes is good; consider also resetting the pointer to force recapture.

vcycle_graph_->destroy() likely works, but vcycle_graph_.reset() makes the lifecycle unambiguous (and avoids any “destroyed but fingerprint unchanged” edge cases depending on is_valid() semantics).

Also applies to: 94-99
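
A sketch of the unambiguous teardown (a suggestion, not the current code):

if (vcycle_graph_) {
    vcycle_graph_->destroy();  // release graph + exec handles now
    vcycle_graph_.reset();     // null the owner so the next use must recapture
}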


1990-2047: GPU buffer management uses raw omp_target_alloc; consider moving to OmpDeviceBuffer wrapper per guidelines.

This is workable, but the repo guideline prefers OmpDeviceBuffer for RAII and fewer lifecycle hazards (especially if exceptions are thrown mid-initialization). As per coding guidelines.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ae03607 and 65702d4.

📒 Files selected for processing (3)
  • CMakeLists.txt
  • include/poisson_solver_multigrid.hpp
  • src/poisson_solver_multigrid.cpp
🧰 Additional context used
📓 Path-based instructions (4)
**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

**/*.cpp: Use [[maybe_unused]] attribute for intentionally unused variables in assertions instead of suppressing warnings with compiler flags
Use RAII for resource management with smart pointers or RAII wrappers; avoid manual new/delete or malloc/free
Check return values and use exceptions for error conditions; provide informative error messages
Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible
Always call set_body_force() for driven flows; initialize velocity field before solving; set turbulence model before first solver step
For GPU offload, verify USE_GPU_OFFLOAD is defined and check omp_get_num_devices() > 0 at runtime

Files:

  • src/poisson_solver_multigrid.cpp
**/*.{cpp,hpp,h}

📄 CodeRabbit inference engine (.cursorrules)

**/*.{cpp,hpp,h}: Fix all compiler warnings before pushing; do not suppress warnings with flags
Use const for read-only references and mark methods const if they don't modify state in C++ code
Use OmpDeviceBuffer wrapper for GPU buffer management instead of manual memory management
Add explanatory comments for WHY code does something, not WHAT it does; document numerical algorithms and non-obvious optimizations

Files:

  • src/poisson_solver_multigrid.cpp
  • include/poisson_solver_multigrid.hpp
src/**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

src/**/*.cpp: Profile before optimizing and document complexity assumptions; use appropriate data structures
Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in Poisson solver

Files:

  • src/poisson_solver_multigrid.cpp
include/**/*.{hpp,h}

📄 CodeRabbit inference engine (.cursorrules)

Every public function needs a documentation comment explaining parameters, return values, side effects, and including usage examples for complex functions

Files:

  • include/poisson_solver_multigrid.hpp
🧠 Learnings (18)
📓 Common learnings
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: For GPU-related changes, also run `./test_before_ci_gpu.sh` to validate GPU CI test suite locally including physics validation tests, turbulence model validation, and CPU/GPU consistency
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to src/**/*.cpp : Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in Poisson solver

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • CMakeLists.txt
  • include/poisson_solver_multigrid.hpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : Check return values and use exceptions for error conditions; provide informative error messages

Applied to files:

  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : For iterative solvers in tests, check residual convergence rather than exact iteration count

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • include/poisson_solver_multigrid.hpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.{cpp,hpp,h} : Use `OmpDeviceBuffer` wrapper for GPU buffer management instead of manual memory management

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • CMakeLists.txt
  • include/poisson_solver_multigrid.hpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • CMakeLists.txt
  • include/poisson_solver_multigrid.hpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.{cpp,hpp,h} : Add explanatory comments for WHY code does something, not WHAT it does; document numerical algorithms and non-obvious optimizations

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • include/poisson_solver_multigrid.hpp
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : For GPU offload, verify `USE_GPU_OFFLOAD` is defined and check `omp_get_num_devices() > 0` at runtime

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: For GPU-related changes, also run `./test_before_ci_gpu.sh` to validate GPU CI test suite locally including physics validation tests, turbulence model validation, and CPU/GPU consistency

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Do not add platform-specific tolerances with `#ifdef __APPLE__` or similar platform checks in tests; fix the root cause instead

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Tests must pass in both Debug and Release builds; if different behavior is needed, investigate the root cause rather than accept build-type dependent behavior

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Add tests for new features; update tests when changing behavior; do not commit broken tests

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Avoid overly strict floating-point comparisons in tests; use appropriate tolerances based on algorithm rather than exact equality checks (e.g., avoid `==` for doubles)

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Do not relax tolerances for different compilers; tests should be numerically robust across all compilers

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : Always call `set_body_force()` for driven flows; initialize velocity field before solving; set turbulence model before first solver step

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to app/**/*.cpp : Use `solve_steady_with_snapshots()` for automatic VTK output; specify `num_snapshots` in config; files are numbered sequentially plus final

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/test_physics_validation*.cpp : Tests must verify the Navier-Stokes solver using comprehensive physics validation including Poiseuille Flow, Divergence-Free Constraint, Momentum Balance, Channel Symmetry, Cross-Model Consistency, and Sanity Checks

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.803Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to include/**/*.{hpp,h} : Every public function needs a documentation comment explaining parameters, return values, side effects, and including usage examples for complex functions

Applied to files:

  • include/poisson_solver_multigrid.hpp
🧬 Code graph analysis (2)
src/poisson_solver_multigrid.cpp (2)
include/gpu_utils.hpp (4)
  • n (174-186)
  • n (174-174)
  • get_device_ptr (249-253)
  • get_device_ptr (249-249)
include/mg_cuda_kernels.hpp (3)
  • level (126-126)
  • level (130-130)
  • level (133-133)
include/poisson_solver_multigrid.hpp (4)
include/gpu_utils.hpp (5)
  • USE_GPU_OFFLOAD (88-95)
  • USE_GPU_OFFLOAD (98-104)
  • USE_GPU_OFFLOAD (107-113)
  • USE_GPU_OFFLOAD (116-123)
  • USE_GPU_OFFLOAD (188-203)
include/poisson_solver.hpp (1)
  • residual_ (84-84)
include/poisson_solver_fft2d.hpp (1)
  • residual_ (49-49)
include/mg_cuda_kernels.hpp (3)
  • level (126-126)
  • level (130-130)
  • level (133-133)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: gpu-tests
  • GitHub Check: build-and-test (ubuntu-latest, Release)
🔇 Additional comments (5)
CMakeLists.txt (1)

516-535: Benchmarks under BUILD_TESTS: verify CI intent + runtime expectations.

You’re building several bench_* targets when BUILD_TESTS=ON but not registering them with add_test() (which is fine). Please verify CI doesn’t accidentally execute them, and that build time impact is acceptable.

If you want benches always available locally but not in CI, consider a separate BUILD_BENCHMARKS toggle.

include/poisson_solver_multigrid.hpp (2)

9-14: Forward-decl of mg_cuda::CudaVCycleGraph looks right (keeps header lightweight).


171-197: Methods are correctly private; no visibility issue exists.

The three methods (compute_residual_and_norms(), has_nullspace(), fix_nullspace()) are already marked private: (line 90 in the header) and are only used internally within the implementation file. Test files do not reference these methods, so there is no risk of compile breaks from keeping them private. These are implementation details of the solver, not part of the public API.

Likely an incorrect or invalid review comment.

src/poisson_solver_multigrid.cpp (2)

2079-2087: Minor: sync_level_from_gpu() has an extra blank line; otherwise OK.


1883-1922: Use manual absolute value to match existing pattern in compute_residual_and_norms.

The code at lines 1901 and 1916 uses std::max and std::abs inside the target region, but the existing compute_residual_and_norms function in the same file explicitly avoids these and uses manual operations instead (double abs_r = (r >= 0.0) ? r : -r; with manual comparison). This inconsistency should be resolved for portability and maintainability. Apply the same manual pattern here:

double aval = (val >= 0.0) ? val : -val;
if (aval > b_inf_local) b_inf_local = aval;
⛔ Skipped due to learnings
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Avoid overly strict floating-point comparisons in tests; use appropriate tolerances based on algorithm rather than exact equality checks (e.g., avoid `==` for doubles)
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to tests/**/*.cpp : Do not add platform-specific tolerances with `#ifdef __APPLE__` or similar platform checks in tests; fix the root cause instead
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.803Z
Learning: Applies to **/*.cpp : Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible

- Guard ompx_get_cuda_stream with __NVCOMPILER for non-NVHPC compilers
- Guard omp.h include behind USE_GPU_OFFLOAD for OpenMP-free CPU builds
- Fix L∞ safety cap bug: skip when tol_rhs==0 to prevent blocking convergence
- Fix boundary write races in bc_3d_kernel (Y/Z faces skip edges)
- Add nullptr check in get_device_ptr()
- Reset vcycle_graph_ after destroy() for cleaner state
- Move bench_mg_cuda_graphs inside USE_GPU_OFFLOAD guard in CMake

Note: Kept map(present, alloc:) syntax - changing to map(present:)
caused 120x performance regression on NVHPC due to implicit tofrom sync.

Co-Authored-By: Claude Opus 4.5 <[email protected]>

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @include/gpu_utils.hpp:
- Around line 285-289: The GPU_PARALLEL_FOR_ASYNC macro is only defined under
USE_GPU_OFFLOAD causing compile failures when that flag is absent; add an #else
fallback that defines GPU_PARALLEL_FOR_ASYNC to a portable CPU implementation
(e.g., a plain for-loop or an OpenMP host parallel for without target/nowait) so
code using GPU_PARALLEL_FOR_ASYNC compiles both with and without offload; update
the header near the existing macro definition and mirror any comments (keep
behavior note about asynchronous kernels and remind callers to use gpu::sync()
where applicable).
🧹 Nitpick comments (2)
src/mg_cuda_kernels.cpp (2)

614-619: Verify eigenvalue bounds are appropriate for your discretization.

The Chebyshev eigenvalue bounds (lambda_min = 0.05, lambda_max = 1.95) are hardcoded here and in multiple other locations (lines 677, 898). While these conservative bounds work for most cases, consider:

  1. Documenting the derivation or reference for these values
  2. Extracting them as named constants to ensure consistency
Example: Extract as constants
// At namespace level or in a header
namespace {
    // Conservative eigenvalue bounds for D^{-1}*A where D = diag(A)
    // For 5/7-point discrete Laplacian, eigenvalues are in (0, 2)
    // These bounds ensure Chebyshev acceleration is stable across grid sizes and BCs
    constexpr double CHEBYSHEV_LAMBDA_MIN = 0.05;
    constexpr double CHEBYSHEV_LAMBDA_MAX = 1.95;
}

819-819: Consider making coarse solver iterations configurable.

The coarse solver iterations are hardcoded to 8 in the fingerprint. While this works, it could limit flexibility for problems requiring more accurate coarse solves.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 65702d4 and 26ed4bd.

📒 Files selected for processing (4)
  • CMakeLists.txt
  • include/gpu_utils.hpp
  • src/mg_cuda_kernels.cpp
  • src/poisson_solver_multigrid.cpp
🧰 Additional context used
📓 Path-based instructions (4)
**/*.{cpp,hpp,h}

📄 CodeRabbit inference engine (.cursorrules)

**/*.{cpp,hpp,h}: Fix all compiler warnings before pushing; do not suppress warnings with flags
Use const for read-only references and mark methods const if they don't modify state in C++ code
Use OmpDeviceBuffer wrapper for GPU buffer management instead of manual memory management
Add explanatory comments for WHY code does something, not WHAT it does; document numerical algorithms and non-obvious optimizations

Files:

  • include/gpu_utils.hpp
  • src/mg_cuda_kernels.cpp
  • src/poisson_solver_multigrid.cpp
include/**/*.{hpp,h}

📄 CodeRabbit inference engine (.cursorrules)

Every public function needs a documentation comment explaining parameters, return values, side effects, and including usage examples for complex functions

Files:

  • include/gpu_utils.hpp
**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

**/*.cpp: Use [[maybe_unused]] attribute for intentionally unused variables in assertions instead of suppressing warnings with compiler flags
Use RAII for resource management with smart pointers or RAII wrappers; avoid manual new/delete or malloc/free
Check return values and use exceptions for error conditions; provide informative error messages
Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible
Always call set_body_force() for driven flows; initialize velocity field before solving; set turbulence model before first solver step
For GPU offload, verify USE_GPU_OFFLOAD is defined and check omp_get_num_devices() > 0 at runtime

Files:

  • src/mg_cuda_kernels.cpp
  • src/poisson_solver_multigrid.cpp
src/**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

src/**/*.cpp: Profile before optimizing and document complexity assumptions; use appropriate data structures
Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in Poisson solver

Files:

  • src/mg_cuda_kernels.cpp
  • src/poisson_solver_multigrid.cpp
🧠 Learnings (15)
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to **/*.{cpp,hpp,h} : Use `OmpDeviceBuffer` wrapper for GPU buffer management instead of manual memory management

Applied to files:

  • include/gpu_utils.hpp
  • src/mg_cuda_kernels.cpp
  • src/poisson_solver_multigrid.cpp
  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to **/*.cpp : For GPU offload, verify `USE_GPU_OFFLOAD` is defined and check `omp_get_num_devices() > 0` at runtime

Applied to files:

  • include/gpu_utils.hpp
  • src/mg_cuda_kernels.cpp
  • src/poisson_solver_multigrid.cpp
  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to **/*.cpp : Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible

Applied to files:

  • include/gpu_utils.hpp
  • src/mg_cuda_kernels.cpp
  • src/poisson_solver_multigrid.cpp
  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to tests/**/*.cpp : Avoid overly strict floating-point comparisons in tests; use appropriate tolerances based on algorithm rather than exact equality checks (e.g., avoid `==` for doubles)

Applied to files:

  • src/mg_cuda_kernels.cpp
  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to tests/**/*.cpp : Do not add platform-specific tolerances with `#ifdef __APPLE__` or similar platform checks in tests; fix the root cause instead

Applied to files:

  • src/mg_cuda_kernels.cpp
  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to src/**/*.cpp : Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in Poisson solver

Applied to files:

  • src/mg_cuda_kernels.cpp
  • src/poisson_solver_multigrid.cpp
  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to **/*.cpp : Check return values and use exceptions for error conditions; provide informative error messages

Applied to files:

  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to tests/**/*.cpp : For iterative solvers in tests, check residual convergence rather than exact iteration count

Applied to files:

  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: For GPU-related changes, also run `./test_before_ci_gpu.sh` to validate GPU CI test suite locally including physics validation tests, turbulence model validation, and CPU/GPU consistency

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to tests/**/*.cpp : Tests must pass in both Debug and Release builds; if different behavior is needed, investigate the root cause rather than accept build-type dependent behavior

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to tests/**/*.cpp : Add tests for new features; update tests when changing behavior; do not commit broken tests

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to tests/**/*.cpp : Do not relax tolerances for different compilers; tests should be numerically robust across all compilers

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to **/*.cpp : Always call `set_body_force()` for driven flows; initialize velocity field before solving; set turbulence model before first solver step

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to app/**/*.cpp : Use `solve_steady_with_snapshots()` for automatic VTK output; specify `num_snapshots` in config; files are numbered sequentially plus final

Applied to files:

  • CMakeLists.txt
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to tests/test_physics_validation*.cpp : Tests must verify the Navier-Stokes solver using comprehensive physics validation including Poiseuille Flow, Divergence-Free Constraint, Momentum Balance, Channel Symmetry, Cross-Model Consistency, and Sanity Checks

Applied to files:

  • CMakeLists.txt
🧬 Code graph analysis (1)
src/mg_cuda_kernels.cpp (2)
include/poisson_solver_multigrid.hpp (15)
  • Nz (120-120)
  • level (87-87)
  • level (88-88)
  • level (156-156)
  • level (157-157)
  • level (158-158)
  • level (161-161)
  • level (162-162)
  • level (163-163)
  • level (169-169)
  • level (174-174)
  • level (176-176)
  • level (197-197)
  • nu1 (150-150)
  • iterations (166-166)
include/mg_cuda_kernels.hpp (16)
  • stream (66-66)
  • stream (87-87)
  • stream (90-90)
  • stream (93-93)
  • stream (96-96)
  • stream (99-99)
  • stream (290-290)
  • stream (320-320)
  • stream (323-323)
  • stream (326-326)
  • graph_exec_ (69-69)
  • graph_exec_ (293-293)
  • launch_copy (167-170)
  • level (126-126)
  • level (130-130)
  • level (133-133)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: gpu-tests
  • GitHub Check: build-and-test (ubuntu-latest, Release)
🔇 Additional comments (17)
include/gpu_utils.hpp (2)

248-255: LGTM! Proper null-check and OpenMP 5.1 device pointer retrieval.

The get_device_ptr template correctly handles nullptr input and uses the standard omp_get_mapped_ptr for device pointer lookup. Documentation clearly explains the behavior.


260-267: LGTM! Sync implementation with proper CPU fallback.

Using #pragma omp taskwait is the correct approach to synchronize deferred target tasks. The CPU fallback as a no-op is appropriate since there's nothing to synchronize without GPU offload.

src/mg_cuda_kernels.cpp (5)

14-23: LGTM! Robust CUDA error handling macro.

The CUDA_CHECK macro provides excellent error reporting with file and line information, wrapped in a do-while(0) for safe macro usage. Throwing std::runtime_error with descriptive messages aligns with the coding guidelines for error handling.


32-69: LGTM! Well-structured Chebyshev smoother kernel.

The 3D Chebyshev smoother correctly implements the weighted Jacobi update with proper ghost cell indexing. The use of __restrict__ pointers and precomputed inverse coefficients is good for performance.


129-280: Solid boundary condition kernel with race avoidance.

The BC kernel correctly uses early returns to skip cells owned by other faces (lines 200, 222, 244, 266), which prevents write races at edges and corners. This matches the fix mentioned in the commit messages for "boundary write-race fixes in bc_3d_kernel."


574-587: LGTM! Proper CUDA Graph resource cleanup.

The destructor correctly calls destroy(), and destroy() properly handles the CUDA Graph resources by destroying graph_exec_ before graph_ and setting both to nullptr. This prevents use-after-free and double-free issues.


761-774: LGTM! Proper V-cycle graph cleanup with reset.

The destructor and destroy() method correctly handle graph resources. Setting graph_exec_ = nullptr after destroy (line 768) addresses the "reset of vcycle_graph_ after destroy()" fix mentioned in the commit messages.

src/poisson_solver_multigrid.cpp (7)

82-88: LGTM! Correct invalidation of V-cycle graph on BC change.

When boundary conditions change, the captured CUDA graph becomes invalid because it was captured with specific BC kernels. Destroying and resetting the graph forces recapture with new BCs on next use.


958-1031: Excellent fused residual + norm computation.

This single-pass implementation computes the residual and both L∞ and L2 norms together, avoiding the overhead of separate compute_residual() + compute_max_residual() calls. The use of OpenMP reductions (reduction(max: max_res) reduction(+: sum_sq)) is correct for GPU offload.
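
For reference, a minimal sketch of the fused pattern (assuming the pointer and stride names used elsewhere in this review, with Ng ghost layers):

double max_res = 0.0, sum_sq = 0.0;
#pragma omp target teams distribute parallel for collapse(3) \
    map(present, alloc: u_ptr[0:n], f_ptr[0:n], r_ptr[0:n]) \
    reduction(max: max_res) reduction(+: sum_sq)
for (int k = Ng; k < Nz + Ng; ++k) {
    for (int j = Ng; j < Ny + Ng; ++j) {
        for (int i = Ng; i < Nx + Ng; ++i) {
            const int idx = k * plane_stride + j * stride + i;
            const double laplacian =
                inv_dx2 * (u_ptr[idx - 1] + u_ptr[idx + 1] - 2.0 * u_ptr[idx]) +
                inv_dy2 * (u_ptr[idx - stride] + u_ptr[idx + stride] - 2.0 * u_ptr[idx]) +
                inv_dz2 * (u_ptr[idx - plane_stride] + u_ptr[idx + plane_stride] - 2.0 * u_ptr[idx]);
            const double r = f_ptr[idx] - laplacian;
            r_ptr[idx] = r;
            const double abs_r = (r >= 0.0) ? r : -r;  // target-friendly |r|
            if (abs_r > max_res) max_res = abs_r;
            sum_sq += r * r;
        }
    }
}
// residual_ = max_res;  residual_l2_ = std::sqrt(sum_sq);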


1401-1417: LGTM! Correct nullspace detection and handling.

The logic correctly identifies that a nullspace exists when no Dirichlet BC is present on any face. The fix_nullspace method properly subtracts the mean and re-applies BCs to ensure ghost cells remain consistent.
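
A minimal sketch of that sequence (interior extents Nx*Ny*Nz with ghost width Ng; apply_bc() is a hypothetical name for the solver's BC routine):

double sum = 0.0;
for (int k = Ng; k < Nz + Ng; ++k)
    for (int j = Ng; j < Ny + Ng; ++j)
        for (int i = Ng; i < Nx + Ng; ++i)
            sum += u_ptr[k * plane_stride + j * stride + i];
const double mean = sum / (static_cast<double>(Nx) * Ny * Nz);
for (int k = Ng; k < Nz + Ng; ++k)
    for (int j = Ng; j < Ny + Ng; ++j)
        for (int i = Ng; i < Nx + Ng; ++i)
            u_ptr[k * plane_stride + j * stride + i] -= mean;
apply_bc(0);  // re-apply BCs so ghost cells stay consistent with the shifted field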


1636-1664: Robust convergence logic with proper edge case handling.

The convergence checking correctly:

  1. Uses + 1e-30 to prevent division by zero when norms are zero
  2. Checks cfg.tol_rhs > 0.0 before applying L∞ safety cap (addresses the bug fix mentioned in commits)
  3. Maintains backward compatibility with legacy cfg.tol

The L∞ safety cap (lines 1651-1659) is a good addition to prevent L2 convergence from hiding localized divergence spikes.
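
Putting the criteria together, a compact sketch (cfg fields as named in this review; any one criterion suffices to converge):

bool converged = false;
if (cfg.tol_abs > 0.0 && r_norm <= cfg.tol_abs) converged = true;
if (cfg.tol_rhs > 0.0 && r_norm <= cfg.tol_rhs * (b_norm + 1e-30)) converged = true;
if (cfg.tol_rel > 0.0 && r_norm <= cfg.tol_rel * (r0_norm + 1e-30)) converged = true;
// L∞ safety cap: active only while the RHS-relative criterion is in play,
// so it can never undo convergence granted by tol_abs or tol_rel alone.
if (converged && cfg.use_l2_norm && cfg.linf_safety_factor > 0.0 && cfg.tol_rhs > 0.0 &&
    residual_ / (b_inf_ + 1e-30) > cfg.tol_rhs * cfg.linf_safety_factor) {
    converged = false;
}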


2166-2181: LGTM! Proper device pointer handling with validation.

The code correctly:

  1. Uses gpu::get_device_ptr() to convert host pointers to device pointers for CUDA kernels
  2. Uses tmp_ptrs_[lvl] directly as it's already a device pointer from omp_target_alloc
  3. Validates all device pointers before graph capture and disables graph mode if any are null

This defensive approach prevents cryptic CUDA errors from invalid pointers.


1785-1817: Correct synchronization for mixed CUDA Graph + OpenMP environment.

The cudaDeviceSynchronize() call before computing b_l2_ is necessary because:

  1. CUDA Graphs and OpenMP target regions may use different streams
  2. The OpenMP target reduction needs all prior GPU work complete

The one-time f-data download for adaptive mode startup is acceptable overhead.


2199-2236: No changes needed—code is correct as implemented.

ompx_get_cuda_stream is properly documented as an NVHPC-provided function in the file comment (line 46) explaining that it's available when compiling with OpenMP target offload support. The function is correctly guarded by #ifdef __NVCOMPILER and #ifdef USE_GPU_OFFLOAD, and the comment documents its purpose and return type. No explicit declaration is required—it's automatically available from the NVHPC runtime.

CMakeLists.txt (3)

251-254: LGTM! Correct conditional compilation of CUDA kernels.

The mg_cuda_kernels.cpp is correctly added only when USE_GPU_OFFLOAD is enabled, matching the #ifdef USE_GPU_OFFLOAD guard in the source file.


517-535: LGTM! Well-organized test and benchmark additions.

The test structure is well-organized:

  • MG benchmarks (lines 518-525) are CPU/GPU agnostic and always built
  • GPU-specific benchmarks (bench_mg_cuda_graphs, test_vcycle_graph_stress) are correctly guarded by USE_GPU_OFFLOAD
  • The gpu label on VCycleGraphStressTest allows filtering GPU tests in CI

341-355: LGTM! Profiling executables properly linked.

The new profiling drivers (profile_comprehensive, profile_large_grid, profile_kernels) are correctly linked against nn_cfd_core, inheriting all necessary compile definitions and dependencies.

sbryngelson and others added 2 commits January 9, 2026 08:32
- Add GPU_PARALLEL_FOR_ASYNC fallback for CPU builds (was missing from #else)
- Extract Chebyshev eigenvalue bounds as named constants:
  - CHEBYSHEV_LAMBDA_MIN = 0.05, CHEBYSHEV_LAMBDA_MAX = 1.95
  - Documents the math: eigenvalues of D^{-1}A for 7-pt Laplacian in (0,2)
  - Ensures consistency across mg_cuda_kernels.cpp and poisson_solver_multigrid.cpp

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Update ghost-layer edge ownership checks to handle Ng >= 1:
- Y-faces: i == 0 || i == Nx+Ng  →  i < Ng || i >= Nx+Ng
- Z-faces: same pattern for both i and j dimensions

Previous code only worked correctly for Ng=1. New code correctly
skips all ghost-layer cells owned by higher-precedence faces.

Co-Authored-By: Claude Opus 4.5 <[email protected]>

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In @include/gpu_utils.hpp:
- Around line 257-263: The comment for inline void sync() is misleading because
omp taskwait only synchronizes OpenMP target tasks, not direct CUDA launches;
update the docstring for sync() to state that it only waits for deferred OpenMP
target tasks (pragma omp taskwait) and does not affect CUDA kernels launched
directly (e.g., from mg_cuda_kernels.cpp), and add guidance that callers must
call cudaStreamSynchronize or cudaDeviceSynchronize when using direct CUDA
launches to ensure GPU work completion (a sketch follows this fix list).

In @src/mg_cuda_kernels.cpp:
- Around line 784-803: The initialize() method must refuse to
destroy/reinitialize a graph while it is currently executing: add an
atomic<bool> flag named is_executing_ to the class, document the thread-safety
requirement on initialize() in the header, and in initialize() check
is_executing_ and throw a std::runtime_error if true; in execute(cudaStream_t)
set is_executing_ = true before launching the graph and clear it only after the
stream has completed (e.g., by synchronizing the stream or using a CUDA event to
detect completion) so initialize()/destroy() cannot run concurrently; also
protect destroy()/initialize() with the same guard (or a mutex) to avoid races
and include informative error messages mentioning
initialize()/execute()/is_executing_.

In @src/poisson_solver_multigrid.cpp:
- Around line 2172-2187: The null-device-pointer check in the loop over configs
(configs[lvl].u/f/r/tmp populated from gpu::get_device_ptr and
tmp_ptrs_/u_ptrs_/f_ptrs_/r_ptrs_) must not silently disable graphs; instead
throw a std::runtime_error containing the level and which pointer was null (and
ideally the pointer values) to surface the GPU mapping failure to the caller, or
if a graceful fallback is required, replace the current
std::cerr+use_vcycle_graph_ assignment with a clear fatal log that documents
"Falling back to non-graphed path" and sets a dedicated warning flag; locate the
check around use_vcycle_graph_ and replace the silent fallback with the
exception or enhanced warning as appropriate.
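
For the sync() item, the corrected docstring might read (a sketch, not the shipped header):

/// Waits only for deferred OpenMP target tasks (via `#pragma omp taskwait`).
/// Does NOT synchronize CUDA kernels launched directly (e.g., from
/// mg_cuda_kernels.cpp); callers on the direct-launch path must use
/// cudaStreamSynchronize() or cudaDeviceSynchronize() instead.
inline void sync() {
#ifdef USE_GPU_OFFLOAD
    #pragma omp taskwait
#endif
}

For the CudaVCycleGraph item, one shape the guard could take (member names follow the fix text; the real initialize() parameters are elided here, and a CUDA event could replace the stream sync):

std::atomic<bool> is_executing_{false};

void CudaVCycleGraph::execute(cudaStream_t stream) {
    is_executing_ = true;
    CUDA_CHECK(cudaGraphLaunch(graph_exec_, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));  // clear only after completion
    is_executing_ = false;
}

void CudaVCycleGraph::initialize() {
    if (is_executing_)
        throw std::runtime_error("CudaVCycleGraph::initialize(): graph launched "
                                 "by execute() is still running");
    // destroy + recapture as before
}

And for the device-pointer item, failing loudly instead of silently disabling graphs:

for (std::size_t lvl = 0; lvl < configs.size(); ++lvl) {
    const auto& c = configs[lvl];
    if (!c.u || !c.f || !c.r || !c.tmp) {
        throw std::runtime_error(
            "[MG] GPU mapping failed: null device pointer at level " +
            std::to_string(lvl) + " (u/f/r/tmp); cannot capture V-cycle graph");
    }
}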
🧹 Nitpick comments (8)
include/gpu_utils.hpp (2)

245-255: Add documentation and consider error reporting for get_device_ptr.

This public API function lacks documentation explaining:

  • OpenMP 5.1 requirement for omp_get_mapped_ptr
  • Return value semantics (nullptr for unmapped vs null input)
  • Usage examples

Additionally, the function silently returns nullptr for unmapped pointers, which can lead to difficult-to-debug issues. Consider adding a debug-mode assertion or optional error reporting.

As per coding guidelines: "Every public function needs a documentation comment explaining parameters, return values, side effects, and including usage examples for complex functions."

📝 Proposed documentation
+/// Get device pointer for an OpenMP-mapped host pointer
+/// @param host_ptr Host pointer that was mapped via omp target enter data
+/// @return Device pointer, or nullptr if host_ptr is null or not mapped
+/// @note Requires OpenMP 5.1+ for omp_get_mapped_ptr support
+/// @warning Returns nullptr for unmapped pointers - ensure pointer is mapped before calling
 template<typename T>
 inline T* get_device_ptr(T* host_ptr) {
     if (host_ptr == nullptr) return nullptr;
     int device = omp_get_default_device();
     void* dev_ptr = omp_get_mapped_ptr(host_ptr, device);
-    // omp_get_mapped_ptr returns nullptr if pointer is not mapped
+    #ifndef NDEBUG
+    if (dev_ptr == nullptr) {
+        // In debug mode, warn about unmapped pointer access
+        std::cerr << "[gpu_utils] WARNING: get_device_ptr called on unmapped pointer\n";
+    }
+    #endif
     return static_cast<T*>(dev_ptr);
 }

285-289: Document that CPU fallback is synchronous, not async.

The comment at line 313 states "async/sync pattern still works" on CPU, but the CPU fallback is a blocking #pragma omp parallel for, not truly asynchronous. This could mislead users expecting async behavior.

📝 Proposed clarification
 // CPU fallback for async - just a regular parallel for (no async semantics on CPU)
-// Note: gpu::sync() is a no-op on CPU, so async/sync pattern still works
+// Note: CPU path is synchronous (blocking). gpu::sync() is a no-op on CPU,
+//       so the async/sync pattern compiles but has no async behavior.
 #define GPU_PARALLEL_FOR_ASYNC(var, start, end) \
     _Pragma("omp parallel for") \
     for (int var = start; var < end; ++var)

Also applies to: 312-316

src/poisson_solver_multigrid.cpp (5)

66-67: Move duplicated constants to shared header.

These constants are duplicated in mg_cuda_kernels.cpp (lines 35-36), requiring manual synchronization. While the comment acknowledges this is "for CPU build isolation," a shared header would be more maintainable.

♻️ Proposed refactor

Create a new header (e.g., include/mg_common.hpp):

#pragma once

namespace nncfd {
namespace mg {
    /// Conservative eigenvalue bounds for D^{-1}*A for Chebyshev smoothing
    constexpr double CHEBYSHEV_LAMBDA_MIN = 0.05;
    constexpr double CHEBYSHEV_LAMBDA_MAX = 1.95;
}
}

Then include in both files:

+#include "mg_common.hpp"
-constexpr double CHEBYSHEV_LAMBDA_MIN = 0.05;
-constexpr double CHEBYSHEV_LAMBDA_MAX = 1.95;
+using mg::CHEBYSHEV_LAMBDA_MIN;
+using mg::CHEBYSHEV_LAMBDA_MAX;

1641-1665: Consider using larger epsilon or std::max for division safety.

Lines 1646, 1649, and 1652 add 1e-30 to prevent division by zero in relative tolerance checks. However, if the norm values are on the order of 1e-10 to 1e-15, adding 1e-30 provides no protection.

Consider using a more robust approach:

♻️ Proposed improvement
+            const double eps = 1e-15;  // Reasonable epsilon for double precision
-            if (cfg.tol_rhs > 0.0 && r_norm <= cfg.tol_rhs * (b_norm + 1e-30)) {
+            if (cfg.tol_rhs > 0.0 && r_norm <= cfg.tol_rhs * std::max(b_norm, eps)) {
                 converged = true;  // RHS-relative tolerance met
             }
-            if (cfg.tol_rel > 0.0 && r_norm <= cfg.tol_rel * (r0_norm + 1e-30)) {
+            if (cfg.tol_rel > 0.0 && r_norm <= cfg.tol_rel * std::max(r0_norm, eps)) {
                 converged = true;  // Initial-residual relative tolerance met
             }

Or check for zero explicitly:

if (cfg.tol_rhs > 0.0 && b_norm > 0.0 && r_norm / b_norm <= cfg.tol_rhs) {
    converged = true;
}

1531-1576: Consider using NaN for uncomputed residuals in fixed-cycle mode.

Lines 1543-1549 set all residual tracking variables to 0.0 when convergence checks are skipped. However, 0.0 might be misinterpreted as "converged perfectly" rather than "not computed."

Consider using NaN to make it explicit these values are invalid:

♻️ Proposed alternative
-        // Set residual to 0 to indicate we didn't compute it
-        residual_ = 0.0;
-        residual_l2_ = 0.0;
-        r0_ = 0.0;
-        r0_l2_ = 0.0;
-        b_inf_ = 0.0;
-        b_l2_ = 0.0;
+        // Set residuals to NaN to indicate they weren't computed
+        const double nan_val = std::numeric_limits<double>::quiet_NaN();
+        residual_ = nan_val;
+        residual_l2_ = nan_val;
+        r0_ = nan_val;
+        r0_l2_ = nan_val;
+        b_inf_ = nan_val;
+        b_l2_ = nan_val;

Or at minimum, document that these values are invalid:

// Note: residual values are not computed in fixed-cycle mode
residual_ = 0.0;  // Invalid - not computed

2214-2233: Consider warning on every null stream occurrence, not just first.

Lines 2225-2230 use a static bool to warn only once about null OpenMP stream. This pattern suppresses subsequent warnings, which could hide repeated failures or changing conditions (e.g., after device reset).

Consider one of these approaches:

♻️ Option 1: Warn every time (recommended for debugging)
-        static bool warned = false;
-        if (!warned) {
             std::cerr << "[MG] WARNING: OpenMP CUDA stream is null - falling back to non-graphed V-cycle\n"
                       << "    This may indicate a runtime issue. Performance will be degraded.\n";
-            warned = true;
-        }
♻️ Option 2: Warn periodically with counter
static int warn_count = 0;
if (warn_count < 10 || warn_count % 100 == 0) {
    std::cerr << "[MG] WARNING (occurrence " << warn_count 
              << "): OpenMP CUDA stream is null...\n";
}
warn_count++;

The current behavior may be appropriate for production, but makes debugging harder.


1796-1823: Consider computing b_l2 on GPU to avoid D→H transfer.

Lines 1810-1822 transfer the entire RHS array from device to host just to compute b_l2 on the CPU. While the comment acknowledges this is "one-time overhead," it's still an expensive D→H transfer that could be avoided.

Use a GPU reduction instead:

♻️ Proposed GPU reduction
-            // Sync f data from device to get b_l2
-            // This is a one-time overhead for adaptive mode startup
-            #pragma omp target update from(f_ptrs_[0][0:f_size])
-
-            // Compute on CPU (f data is now in host buffer)
             double b_sum_sq = 0.0;
-            for (int k = Ng; k < Nz + Ng; ++k) {
-                for (int j = Ng; j < Ny + Ng; ++j) {
-                    for (int i = Ng; i < Nx + Ng; ++i) {
-                        double val = f_ptrs_[0][k * plane_stride + j * stride + i];
-                        b_sum_sq += val * val;
-                    }
+            #pragma omp target teams distribute parallel for collapse(3) \
+                map(present: f_ptrs_[0][0:f_size]) reduction(+: b_sum_sq)
+            for (int k = Ng; k < Nz + Ng; ++k) {
+                for (int j = Ng; j < Ny + Ng; ++j) {
+                    for (int i = Ng; i < Nx + Ng; ++i) {
+                        int idx = k * plane_stride + j * stride + i;
+                        double val = f_ptrs_[0][idx];
+                        b_sum_sq += val * val;
+                    }
                 }
             }
             b_l2_ = std::sqrt(b_sum_sq);

This eliminates the D→H transfer and is more consistent with the GPU-resident philosophy.

Based on learnings: "Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible"

src/mg_cuda_kernels.cpp (1)

83-135: Document Ng=1 assumption in periodic wrap logic.

The periodic boundary wrap logic (lines 103-108) assumes Ng = 1. While this is correct for the current code (all MG levels use Ng=1), the assumption should be documented to prevent future issues if ghost layer count changes.

📝 Proposed documentation
 __global__ void chebyshev_3d_periodic_kernel(
     double* __restrict__ u,
     const double* __restrict__ f,
     double* __restrict__ tmp,
     int Nx, int Ny, int Nz, int Ng,
     double inv_dx2, double inv_dy2, double inv_dz2,
     double inv_coeff, double omega)
 {
+    // NOTE: Wrap logic assumes Ng=1. For Ng>1, wrap logic needs adjustment.
     // Global thread indices (interior points: 1..Nx, 1..Ny, 1..Nz with Ng=1)
     int i = blockIdx.x * blockDim.x + threadIdx.x + Ng;

Or add a runtime assertion in the launch function:

void launch_chebyshev_3d_periodic(...) {
    assert(Ng == 1 && "Periodic kernel only supports Ng=1");
    // ...
}
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 26ed4bd and 0e14060.

📒 Files selected for processing (3)
  • include/gpu_utils.hpp
  • src/mg_cuda_kernels.cpp
  • src/poisson_solver_multigrid.cpp
🧰 Additional context used
📓 Path-based instructions (4)
**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

**/*.cpp: Use [[maybe_unused]] attribute for intentionally unused variables in assertions instead of suppressing warnings with compiler flags
Use RAII for resource management with smart pointers or RAII wrappers; avoid manual new/delete or malloc/free
Check return values and use exceptions for error conditions; provide informative error messages
Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible
Always call set_body_force() for driven flows; initialize velocity field before solving; set turbulence model before first solver step
For GPU offload, verify USE_GPU_OFFLOAD is defined and check omp_get_num_devices() > 0 at runtime

Files:

  • src/poisson_solver_multigrid.cpp
  • src/mg_cuda_kernels.cpp
**/*.{cpp,hpp,h}

📄 CodeRabbit inference engine (.cursorrules)

**/*.{cpp,hpp,h}: Fix all compiler warnings before pushing; do not suppress warnings with flags
Use const for read-only references and mark methods const if they don't modify state in C++ code
Use OmpDeviceBuffer wrapper for GPU buffer management instead of manual memory management
Add explanatory comments for WHY code does something, not WHAT it does; document numerical algorithms and non-obvious optimizations

Files:

  • src/poisson_solver_multigrid.cpp
  • include/gpu_utils.hpp
  • src/mg_cuda_kernels.cpp
src/**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

src/**/*.cpp: Profile before optimizing and document complexity assumptions; use appropriate data structures
Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in Poisson solver

Files:

  • src/poisson_solver_multigrid.cpp
  • src/mg_cuda_kernels.cpp
include/**/*.{hpp,h}

📄 CodeRabbit inference engine (.cursorrules)

Every public function needs a documentation comment explaining parameters, return values, side effects, and including usage examples for complex functions

Files:

  • include/gpu_utils.hpp
🧠 Learnings (10)
📓 Common learnings
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: For GPU-related changes, also run `./test_before_ci_gpu.sh` to validate GPU CI test suite locally including physics validation tests, turbulence model validation, and CPU/GPU consistency
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to **/*.{cpp,hpp,h} : Use `OmpDeviceBuffer` wrapper for GPU buffer management instead of manual memory management
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to **/*.cpp : Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to src/**/*.cpp : Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in Poisson solver

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • src/mg_cuda_kernels.cpp
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to **/*.cpp : Check return values and use exceptions for error conditions; provide informative error messages

Applied to files:

  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to tests/**/*.cpp : For iterative solvers in tests, check residual convergence rather than exact iteration count

Applied to files:

  • src/poisson_solver_multigrid.cpp
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to **/*.{cpp,hpp,h} : Use `OmpDeviceBuffer` wrapper for GPU buffer management instead of manual memory management

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • include/gpu_utils.hpp
  • src/mg_cuda_kernels.cpp
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to **/*.cpp : Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • include/gpu_utils.hpp
  • src/mg_cuda_kernels.cpp
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to **/*.cpp : For GPU offload, verify `USE_GPU_OFFLOAD` is defined and check `omp_get_num_devices() > 0` at runtime

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • include/gpu_utils.hpp
  • src/mg_cuda_kernels.cpp
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: For GPU-related changes, also run `./test_before_ci_gpu.sh` to validate GPU CI test suite locally including physics validation tests, turbulence model validation, and CPU/GPU consistency

Applied to files:

  • src/poisson_solver_multigrid.cpp
  • src/mg_cuda_kernels.cpp
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to tests/**/*.cpp : Avoid overly strict floating-point comparisons in tests; use appropriate tolerances based on algorithm rather than exact equality checks (e.g., avoid `==` for doubles)

Applied to files:

  • src/mg_cuda_kernels.cpp
📚 Learning: 2026-01-07T15:09:29.821Z
Learnt from: CR
Repo: comp-physics/cfd-nn PR: 0
File: .cursorrules:0-0
Timestamp: 2026-01-07T15:09:29.821Z
Learning: Applies to tests/**/*.cpp : Do not add platform-specific tolerances with `#ifdef __APPLE__` or similar platform checks in tests; fix the root cause instead

Applied to files:

  • src/mg_cuda_kernels.cpp
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: gpu-tests
  • GitHub Check: build-and-test (ubuntu-latest, Release)
🔇 Additional comments (5)
src/poisson_solver_multigrid.cpp (2)

1407-1423: LGTM! Well-structured nullspace handling.

The nullspace detection and fixing logic is correct and well-documented. The approach of checking for Dirichlet BCs to determine nullspace presence is standard, and reapplying BCs after mean subtraction ensures ghost cell consistency.


964-1037: No action required - L2 norm calculation is correct.

The unnormalized form (sqrt(sum_sq) without dividing by the number of points) is intentional and mathematically sound. Relative tolerances in lines 1637-1640 compare ||r||_2 / ||b||_2, where the normalization factors cancel out in the ratio. Absolute tolerances compare directly against the unnormalized norm, which serves as a valid absolute threshold. This pattern is applied consistently for both b_l2_ and residual norms across all code paths (lines 1619, 1822, 1935, 1036).
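
One way to see why the ratio is normalization-independent: with N interior points,

$$\frac{\lVert r\rVert_2/\sqrt{N}}{\lVert b\rVert_2/\sqrt{N}} = \frac{\lVert r\rVert_2}{\lVert b\rVert_2},$$

so RMS-normalized and unnormalized norms yield identical relative checks; only the absolute tolerance sees the grid-size-dependent scale.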

src/mg_cuda_kernels.cpp (3)

14-23: LGTM! CUDA error checking macro is well-structured.

The CUDA_CHECK macro provides consistent error handling with informative messages. The do-while(0) pattern is correct for macro safety.

Minor note: The macro assumes exception support. If this code might be compiled with -fno-exceptions, consider adding a fallback:

#ifdef __cpp_exceptions
    throw std::runtime_error(...);
#else
    std::cerr << "CUDA error: " << cudaGetErrorString(err) << "\n";
    std::abort();
#endif

139-290: LGTM! Boundary kernel is correct with good edge handling.

The unified BC kernel correctly handles all 6 faces with proper edge skipping to prevent race conditions (lines 210, 232, 254, 276). The early returns for skipped cells cause some thread divergence, but this is an acceptable trade-off for code clarity.

Optional performance note: For very large grids, separate kernels per face might be slightly faster (no divergence), but the current unified approach is more maintainable and the performance difference is likely negligible compared to smoother cost.


655-658: Clarify if final BC application is needed for periodic case.

Line 658 applies BCs unconditionally with the comment "ALWAYS needed for MG operations outside the smoother." However, when all_periodic is true, the fused kernel at line 644 already handles periodic BCs inline.

Questions:

  1. For the all-periodic case, does the final BC application at line 658 do anything (are ghost cells already correct)?
  2. Is this a defensive programming choice (always apply BCs) or is there a subtle reason ghost cells need re-filling?
  3. Could this be optimized to skip the final BC for all-periodic?

If it's defensive programming, consider adding a comment:

// Final BC application - ALWAYS needed for MG operations outside the smoother
// (Even for all-periodic case, ensures ghost cells are correct for
//  compute_residual/restrict/prolongate which use standard neighbor indexing)
launch_bc_kernel(stream);

If it's truly redundant for periodic, consider:

if (!all_periodic) {
    launch_bc_kernel(stream);
}


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @src/mg_cuda_kernels.cpp:
- Around line 665-672: The function
CudaSmootherGraph::launch_chebyshev_iteration has an unused parameter k causing
a compiler warning; mark the parameter as maybe-unused by changing the function
signature to use the [[maybe_unused]] attribute on k (i.e., [[maybe_unused]] int
k) so the compiler knows the unused parameter is intentional, leaving the rest
of the implementation calling launch_chebyshev_3d unchanged.
- Around line 765-942: Implement the missing method
CudaVCycleGraph::debug_print_pointers() const in this translation unit with the
exact signature declared in include/mg_cuda_kernels.hpp; place it alongside the
other CudaVCycleGraph methods (e.g., before the end of the CudaVCycleGraph
section). The body should be const-correct and simply iterate levels_ and print
or log the device pointer values (cfg.u, cfg.f, cfg.r, cfg.tmp, etc.) along with
level index (using printf or std::cout) so the linker symbol is satisfied and
you can inspect pointers at runtime; ensure you include any needed headers for
printing and that the method is non-intrusive (no state mutation).
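
A minimal const-correct sketch satisfying the declared symbol (field names follow this review's usage; requires <cstdio>):

void CudaVCycleGraph::debug_print_pointers() const {
    for (std::size_t lvl = 0; lvl < levels_.size(); ++lvl) {
        const auto& cfg = levels_[lvl];
        std::printf("[MG] level %zu: u=%p f=%p r=%p tmp=%p\n", lvl,
                    static_cast<const void*>(cfg.u), static_cast<const void*>(cfg.f),
                    static_cast<const void*>(cfg.r), static_cast<const void*>(cfg.tmp));
    }
}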
🧹 Nitpick comments (3)
src/mg_cuda_kernels.cpp (3)

460-460: Document the 2D handling special case.

The zero-out of inv_dz2 for 2D grids (when Nz == 1) disables the z-direction contribution to the Laplacian stencil. While correct, this implicit handling of dimensional reduction deserves a brief explanatory comment for maintainability. A sketch follows these nitpicks.


683-692: Document the hardcoded Dirichlet boundary value assumption.

The Dirichlet boundary value is hardcoded to 0.0, which is appropriate for pressure fields but may not be correct for other field types. Add a comment clarifying that this implementation is specific to the pressure Poisson equation.


827-827: Make coarse solver iterations configurable.

The coarse grid iterations are hardcoded to 8 with a TODO-like comment. Consider making this a configurable parameter in the VCycleLevelConfig or CudaVCycleGraph::initialize() signature for better flexibility and tuning.
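
One possible shape, assuming the iteration count lives naturally in the per-level config (the field name is hypothetical):

struct VCycleLevelConfig {
    // ...existing grid sizes and device pointers...
    int coarse_iterations = 8;  // hypothetical field replacing the hardcoded 8;
                                // the default preserves current behavior
};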

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0e14060 and f80cd30.

📒 Files selected for processing (1)
  • src/mg_cuda_kernels.cpp
🧰 Additional context used
📓 Path-based instructions (3)
**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

**/*.cpp: Use [[maybe_unused]] attribute for intentionally unused variables in assertions instead of suppressing warnings with compiler flags
Use RAII for resource management with smart pointers or RAII wrappers; avoid manual new/delete or malloc/free
Check return values and use exceptions for error conditions; provide informative error messages
Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible
Always call set_body_force() for driven flows; initialize velocity field before solving; set turbulence model before first solver step
For GPU offload, verify USE_GPU_OFFLOAD is defined and check omp_get_num_devices() > 0 at runtime

Files:

  • src/mg_cuda_kernels.cpp
**/*.{cpp,hpp,h}

📄 CodeRabbit inference engine (.cursorrules)

**/*.{cpp,hpp,h}: Fix all compiler warnings before pushing; do not suppress warnings with flags
Use const for read-only references and mark methods const if they don't modify state in C++ code
Use OmpDeviceBuffer wrapper for GPU buffer management instead of manual memory management
Add explanatory comments for WHY code does something, not WHAT it does; document numerical algorithms and non-obvious optimizations

Files:

  • src/mg_cuda_kernels.cpp
src/**/*.cpp

📄 CodeRabbit inference engine (.cursorrules)

src/**/*.cpp: Profile before optimizing and document complexity assumptions; use appropriate data structures
Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in Poisson solver

Files:

  • src/mg_cuda_kernels.cpp
🧠 Learnings
Learnt from .cursorrules (CR, 2026-01-07), applied to src/mg_cuda_kernels.cpp:

  • Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in the Poisson solver
  • For GPU-related changes, also run `./test_before_ci_gpu.sh` to validate the GPU CI test suite locally, including physics validation tests, turbulence model validation, and CPU/GPU consistency
  • For GPU offload, verify `USE_GPU_OFFLOAD` is defined and check `omp_get_num_devices() > 0` at runtime
  • Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible
  • Use `OmpDeviceBuffer` wrapper for GPU buffer management instead of manual memory management
  • Avoid overly strict floating-point comparisons in tests; use tolerances appropriate to the algorithm rather than exact equality checks (e.g., avoid `==` for doubles)
  • Do not add platform-specific tolerances with `#ifdef __APPLE__` or similar platform checks in tests; fix the root cause instead
🧬 Code graph analysis (1)
src/mg_cuda_kernels.cpp (1)
include/mg_cuda_kernels.hpp (25)
  • stream (66-66)
  • stream (87-87)
  • stream (90-90)
  • stream (93-93)
  • stream (96-96)
  • stream (99-99)
  • stream (290-290)
  • stream (320-320)
  • stream (323-323)
  • stream (326-326)
  • config (59-62)
  • launch_copy (167-170)
  • levels (118-123)
  • levels (282-287)
  • smoother_graphs_ (136-136)
  • level (126-126)
  • level (130-130)
  • level (133-133)
  • CudaVCycleGraph (273-273)
  • CudaVCycleGraph (274-274)
  • fingerprint_ (301-301)
  • launch_residual_3d (184-188)
  • launch_restrict_3d (192-197)
  • launch_zero (173-176)
  • launch_prolongate_3d (201-206)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: gpu-tests
🔇 Additional comments (2)
src/mg_cuda_kernels.cpp (2)

639-640: Remove M_PI portability comment—the codebase successfully uses M_PI without safeguards.

M_PI is used throughout the codebase (src/poisson_solver_fft.cpp, src/poisson_solver_multigrid.cpp, etc.) without _USE_MATH_DEFINES or fallback definitions, and the project compiles successfully with C++17. M_PI is a POSIX extension rather than part of ISO C++, but the toolchains this project targets expose it from <cmath> without additional configuration. This is not an issue in this codebase.

Likely an incorrect or invalid review comment.


922-923: This hardcoding is correct by design. The MultigridPoissonSolver is specifically designed for solving the pressure Poisson equation in incompressible flow, where the Dirichlet boundary value must be zero (pressure has an arbitrary reference). Non-zero Dirichlet values would violate the incompressible flow constraint and lack physical meaning. The solver is not intended as a general-purpose Poisson solver. The comment at line 690 documents this: // Dirichlet value (0 for pressure). No changes needed.

Comment on lines +665 to +672
void CudaSmootherGraph::launch_chebyshev_iteration(cudaStream_t stream, int k, double omega) {
launch_chebyshev_3d(
stream,
config_.u, config_.f, config_.tmp,
config_.Nx, config_.Ny, config_.Nz, config_.Ng,
config_.dx2, config_.dy2, config_.dz2, config_.coeff,
omega);
}

⚠️ Potential issue | 🟡 Minor

Fix unused parameter warning.

The parameter k is unused in the function body, which will trigger a compiler warning. Based on the coding guidelines, use the [[maybe_unused]] attribute.

🔧 Proposed fix
-void CudaSmootherGraph::launch_chebyshev_iteration(cudaStream_t stream, int k, double omega) {
+void CudaSmootherGraph::launch_chebyshev_iteration(cudaStream_t stream, [[maybe_unused]] int k, double omega) {
     launch_chebyshev_3d(
         stream,
         config_.u, config_.f, config_.tmp,

sbryngelson and others added 2 commits January 9, 2026 08:45
Null device pointers indicate a bug in GPU buffer initialization,
not a recoverable runtime condition. Silent fallback to non-graphed
path would:
- Waste hours of compute time on a 120× slower path
- Hide the underlying bug from the user
- Bury the error message in batch job logs

Fail fast is correct here - users need to know immediately if
their GPU setup is broken.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
The previous comment "wait for all pending kernels" was misleading.
omp taskwait only synchronizes OpenMP deferred target tasks (nowait),
not direct CUDA kernel launches which use cudaStreamSynchronize.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @src/poisson_solver_multigrid.cpp:
- Around line 1790-1830: Wrap both cudaDeviceSynchronize() calls in conditional
compilation (#ifdef USE_GPU_OFFLOAD / #endif), check their return values, and
throw a std::runtime_error on failure. Replace each bare cudaDeviceSynchronize()
with a helper that does: cudaError_t err = cudaDeviceSynchronize(); if (err !=
cudaSuccess) throw std::runtime_error(std::string("cudaDeviceSynchronize
failed: ") + cudaGetErrorString(err)); (a sketch follows this list). Keep the
sync_if_graphed lambda guarded for graph validity, but call the synchronization
unconditionally before each compute_residual_and_norms() invocation in the
adaptive while loop (i.e., move the sync_if_graphed() call to immediately
precede compute_residual_and_norms()), so prior GPU work always completes
regardless of use_graph. Reference the lambda name sync_if_graphed, the loop
invoking compute_residual_and_norms(), and run_cycles()/vcycle_graph_ when
making these changes.
- Around line 2179-2187: The null-pointer checks are fine, but the GPU buffers
should use the RAII wrappers from gpu_utils.hpp instead of manual
omp_target_alloc/free and raw mapped pointers. Replace the scratch device
allocations and temp pointer management (tmp_ptrs_ / configs[].tmp) with
DeviceOnlyArray, and wrap the mapped per-level arrays (u_ptrs_, f_ptrs_,
r_ptrs_ / configs[].u, configs[].f, configs[].r) with DeviceArray. Update
initialize_gpu_buffers() to allocate those DeviceOnlyArray/DeviceArray
instances and expose device pointers via their accessors to fill configs,
remove the manual omp_target_free calls and raw pointer bookkeeping, and
adjust any null checks to use the wrappers' validity/accessor methods.
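
A minimal sketch of the checked synchronize described in the first item (the helper name is illustrative):

#ifdef USE_GPU_OFFLOAD
#include <cuda_runtime.h>
#include <stdexcept>
#include <string>

// Synchronize the device and surface any error as an exception instead of
// silently continuing with stale results.
inline void checked_device_sync() {
    cudaError_t err = cudaDeviceSynchronize();
    if (err != cudaSuccess) {
        throw std::runtime_error(std::string("cudaDeviceSynchronize failed: ") +
                                 cudaGetErrorString(err));
    }
}
#endif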
🧹 Nitpick comments (3)
src/poisson_solver_multigrid.cpp (3)

59-68: Consider centralizing Chebyshev eigenvalue bounds in a shared header.

The constants CHEBYSHEV_LAMBDA_MIN and CHEBYSHEV_LAMBDA_MAX are duplicated in mg_cuda_kernels.cpp. While the comment acknowledges this, duplication creates a maintenance risk if values need adjustment. Consider moving these to a shared header (e.g., mg_common.hpp) included by both files to ensure consistency.
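
A sketch of such a header (the file name and the numeric values shown are placeholders; use the values currently duplicated in poisson_solver_multigrid.cpp and mg_cuda_kernels.cpp):

// Hypothetical shared header, e.g. include/mg_common.hpp
#pragma once

namespace mg {
// Chebyshev eigenvalue bounds for the smoothed operator, defined once so the
// CPU and CUDA paths cannot drift apart. Values below are placeholders.
inline constexpr double CHEBYSHEV_LAMBDA_MIN = 0.25;
inline constexpr double CHEBYSHEV_LAMBDA_MAX = 2.0;
}  // namespace mg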


1531-1576: Consider computing final residual in fixed-cycle mode for diagnostics.

Fixed-cycle mode sets all residual values to 0 (Lines 1543-1549), which prevents post-solve diagnostics from assessing solution quality. For debugging and validation, consider optionally computing the final residual even when not checking convergence during iteration.

Proposed enhancement for diagnostic support

After the fixed-cycle loop (Line 1541), optionally compute residuals:

         for (int cycle = 0; cycle < num_cycles; ++cycle) {
             vcycle(0, nu1, nu2, degree);
         }
 
-        // Set residual to 0 to indicate we didn't compute it
-        residual_ = 0.0;
-        residual_l2_ = 0.0;
-        r0_ = 0.0;
-        r0_l2_ = 0.0;
-        b_inf_ = 0.0;
-        b_l2_ = 0.0;
+        // Optionally compute final residual for diagnostics
+        if (cfg.compute_final_residual) {
+            compute_residual_and_norms(0, residual_, residual_l2_);
+        } else {
+            residual_ = 0.0;
+            residual_l2_ = 0.0;
+        }
+        r0_ = 0.0;
+        r0_l2_ = 0.0;
+        b_inf_ = 0.0;
+        b_l2_ = 0.0;

2004-2082: Consider using OmpDeviceBuffer wrapper for GPU memory management.

The current implementation uses manual OpenMP target pragmas (target enter data, target exit data) and explicit omp_target_alloc/omp_target_free calls for GPU buffer management. Per coding guidelines, consider refactoring to use the OmpDeviceBuffer wrapper for improved safety and maintainability.

Benefits:

  • RAII semantics ensure cleanup on exceptions
  • Reduced boilerplate for map clauses
  • Consistent buffer lifecycle management

Based on learnings, OmpDeviceBuffer wrapper should be used for GPU buffer management instead of manual memory management.
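
A minimal RAII pattern along these lines, using only standard OpenMP runtime calls (illustrative; the project's OmpDeviceBuffer API may differ):

#include <omp.h>
#include <cstddef>
#include <stdexcept>

// Device-only scratch buffer: allocated on construction, freed on destruction,
// so cleanup happens even when an exception unwinds the stack.
class DeviceScratch {
public:
    DeviceScratch(std::size_t n, int device)
        : ptr_(omp_target_alloc(n * sizeof(double), device)), device_(device) {
        if (!ptr_) throw std::runtime_error("omp_target_alloc failed");
    }
    ~DeviceScratch() { omp_target_free(ptr_, device_); }
    DeviceScratch(const DeviceScratch&) = delete;
    DeviceScratch& operator=(const DeviceScratch&) = delete;
    double* get() const { return static_cast<double*>(ptr_); }
private:
    void* ptr_;
    int device_;
};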

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f80cd30 and 25aab56.

📒 Files selected for processing (1)
  • src/poisson_solver_multigrid.cpp
🧰 Additional context used
🧠 Learnings
Learnt from .cursorrules (CR, 2026-01-07), applied to src/poisson_solver_multigrid.cpp:

  • Ensure boundary conditions are consistent across solver components; handle periodic BCs specially in the Poisson solver
  • Check return values and use exceptions for error conditions; provide informative error messages
  • For iterative solvers in tests, check residual convergence rather than exact iteration count
  • Use `OmpDeviceBuffer` wrapper for GPU buffer management instead of manual memory management
  • Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible
  • For GPU offload, verify `USE_GPU_OFFLOAD` is defined and check `omp_get_num_devices() > 0` at runtime
  • For GPU-related changes, also run `./test_before_ci_gpu.sh` to validate the GPU CI test suite locally, including physics validation tests, turbulence model validation, and CPU/GPU consistency
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: gpu-tests
  • GitHub Check: build-and-test (ubuntu-latest, Release)
🔇 Additional comments (6)
src/poisson_solver_multigrid.cpp (6)

84-118: LGTM! Graph invalidation on BC changes is correct.

The BC setters correctly invalidate the V-cycle CUDA graph when boundary conditions change, ensuring the graph is recaptured with updated parameters. The explicit destroy() call before reset() provides clear cleanup semantics.

Verify that vcycle_graph_->destroy() is exception-safe. If it throws, the unique_ptr remains in a partially destroyed state. Based on learnings, check return values and use exceptions for error conditions.
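
The pattern described is roughly the following (sketch; the setter and member names are assumed from the discussion, not taken from the source):

void MultigridPoissonSolver::set_bc_x(BCType bc) {
    if (bc == bc_x_) return;        // unchanged BCs keep the captured graph valid
    bc_x_ = bc;
    if (vcycle_graph_) {
        vcycle_graph_->destroy();   // explicit cleanup before releasing ownership
        vcycle_graph_.reset();      // forces recapture with the new BC parameters
    }
}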


907-962: LGTM! Residual computation refactored for clarity.

The unified CPU/GPU implementation using raw pointers improves consistency and maintainability. The added NVTX scope aids profiling without changing computation logic.


964-1037: LGTM! Fused residual+norm computation is efficient.

The single-pass computation of residual, L∞ norm, and L2 norm reduces memory bandwidth requirements, which is especially beneficial for GPU execution. The OpenMP reductions correctly handle both CPU and GPU paths.

Verify that the L2 norm semantics match the convergence criteria expectations. The current implementation returns sqrt(sum(r^2)) without normalization by grid size. Ensure this is consistent with the relative tolerance checks in solve() at Lines 1637-1650.
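
A condensed sketch of the fused pass over a flat array, with a hypothetical 1D stencil standing in for the real 3D Laplacian (names and the map clause mirror the PR's conventions, but this is illustrative, not the project's function):

#include <cmath>

#pragma omp declare target
// Hypothetical 1D 3-point stencil; the real code applies the 3D operator.
inline double apply_A(const double* u, int i, double inv_h2) {
    return (u[i - 1] - 2.0 * u[i] + u[i + 1]) * inv_h2;
}
#pragma omp end declare target

// Single pass: residual written into r, with L-inf and L2 norms computed via
// fused reductions so the field is read from memory only once.
double residual_and_norms(const double* u, const double* f, double* r,
                          int n, double inv_h2, double& linf_out) {
    double l2sq = 0.0;
    double linf = 0.0;
    #pragma omp target teams distribute parallel for \
        reduction(+ : l2sq) reduction(max : linf) \
        map(present, alloc : u[0:n], f[0:n], r[0:n])
    for (int i = 1; i < n - 1; ++i) {
        double res = f[i] - apply_A(u, i, inv_h2);
        r[i] = res;
        double a = res < 0.0 ? -res : res;
        if (a > linf) linf = a;
        l2sq += res * res;
    }
    linf_out = linf;
    return std::sqrt(l2sq);  // sqrt(sum r^2): not normalized by grid size
}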


1407-1423: LGTM! Nullspace handling logic is correct.

The nullspace detection correctly identifies that Dirichlet BCs on any face eliminate the constant nullspace, while pure Neumann/Periodic systems require mean subtraction. The re-application of BCs after subtract_mean() is necessary since ghost cells become inconsistent.

Confirm that fix_nullspace() is only called after convergence, not during V-cycle iteration. Calling it mid-iteration would disrupt the multigrid correction process. Based on the code at Lines 1552, 1675, 1867, and 1991, this appears correct, but verify there are no other call sites.


2224-2233: No changes needed; the error handling is correct.

The PR commit "Fix: throw on null device pointers instead of silent fallback" specifically targets null device pointers (buffer initialization failures at lines 2181–2187), not null CUDA streams. These are two distinct error conditions requiring different responses.

Null device pointers indicate a fatal initialization bug → throw immediately to surface the issue (line 2182).

Null CUDA streams indicate a runtime condition where OpenMP's stream is unavailable → graceful fallback is appropriate, as documented in the code comments (lines 2220–2223): "A null stream would launch on CUDA default stream, causing potential race conditions... Instead of crashing in production, fall back to the non-graphed path which is slower but correct."

Both approaches are correctly implemented.

Likely an incorrect or invalid review comment.


1641-1665: Convergence criteria fallback design handles zero-RHS cases through OR logic, not the 1e-30 guard alone.

The division-by-zero guards (+ 1e-30) are intentional. When b_norm approaches zero (as in pure Neumann problems with zero RHS tested in test_error_recovery.cpp), the tol_rhs criterion becomes ineffective (effectively r_norm <= tol_rhs * 1e-30), but the code relies on the fallback criteria: tol_abs (legacy cfg.tol) or tol_rel trigger instead. This is working as designed and tested.

However, the interdependence of these criteria could be clearer in comments. Consider adding a note that the multiple tolerance checks form a fallback chain for robustness with singular/zero-RHS systems.

The L∞ safety cap logic is correct as-is.
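
The fallback chain amounts to something like this (sketch; the tolerance names mirror the config fields discussed above):

// The three criteria form an OR'd fallback chain; the 1e-30 guards avoid
// division by zero when ||b|| or the initial residual is (near) zero, in
// which case the remaining criteria still apply.
bool is_converged(double r_inf, double r_l2, double b_l2, double r0_l2,
                  double tol_abs, double tol_rhs, double tol_rel) {
    return (r_inf <= tol_abs)                      // absolute (legacy cfg.tol)
        || (r_l2  <= tol_rhs * (b_l2  + 1e-30))    // relative to ||b||_2
        || (r_l2  <= tol_rel * (r0_l2 + 1e-30));   // relative to initial residual
}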

- Update POISSON_SOLVER_GUIDE.md with detailed CUDA Graph section
- Add V-cycle CUDA Graph subsection to README GPU Acceleration
- Document performance impact, configuration, requirements, fallbacks
- Add Chebyshev smoothing documentation with eigenvalue bounds

Co-Authored-By: Claude Opus 4.5 <[email protected]>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @docs/POISSON_SOLVER_GUIDE.md:
- Around line 182-193: The fenced code block containing the vcycle(level)
pseudocode is missing a language specifier; add a language tag (e.g., "text"
or "pseudo") after the opening ``` so the block starting with "vcycle(level):"
becomes ```text, enabling proper rendering and satisfying markdownlint (MD040).

In @include/gpu_utils.hpp:
- Around line 245-255: Add a Doxygen-style documentation comment above the
public template function get_device_ptr<T>(T* host_ptr), describing the
host_ptr parameter, the return value (a device pointer, or nullptr when
host_ptr is null or not mapped), and side effects (it queries the default
device via omp_get_default_device() and uses omp_get_mapped_ptr(), so an
OpenMP mapping is required). Mention thread-safety/constness implications and
that it returns a raw device pointer cast to T*. Keep the comment adjacent to
the function template declaration; a sketch follows this list.
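
A sketch of such a comment (wording illustrative):

/// Translate a host pointer into its mapped device pointer.
/// @param host_ptr Host pointer previously mapped to the device via OpenMP.
/// @return Device pointer on the default device, or nullptr if host_ptr is
///         null or not currently mapped.
/// @note Queries omp_get_default_device() and omp_get_mapped_ptr(), so an
///       active OpenMP mapping is required. Returns a raw T* with no
///       ownership; safe to call concurrently since it only reads mappings.
template <typename T>
T* get_device_ptr(T* host_ptr);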
🧹 Nitpick comments (1)
include/gpu_utils.hpp (1)

257-267: Add formal documentation comment for the sync() function.

While the inline comments are helpful, the sync() function should have a proper documentation comment explaining its purpose, behavior, and limitations as a public API function.

📚 Suggested documentation
+/// Synchronize all OpenMP target tasks
+/// Waits for all deferred target regions (launched with nowait) to complete.
+/// Use this before reading GPU results back to host after async operations.
+/// @warning Only synchronizes OpenMP target tasks, NOT direct CUDA kernel launches.
+///          CUDA kernels (e.g., in mg_cuda_kernels.cpp) must use cudaStreamSynchronize.
+/// @note CPU fallback is a no-op (CPU parallel regions are synchronous by default)
 inline void sync() {

As per coding guidelines for include/**/*.{hpp,h}.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 25aab56 and 52a3fd3.

📒 Files selected for processing (3)
  • README.md
  • docs/POISSON_SOLVER_GUIDE.md
  • include/gpu_utils.hpp
🧰 Additional context used
📓 Path-based instructions (3)
README.md

📄 CodeRabbit inference engine (.cursorrules)

Keep command-line options in README up to date; document new features and update examples when behavior changes

Files:

  • README.md
**/*.{cpp,hpp,h}

📄 CodeRabbit inference engine (.cursorrules)

**/*.{cpp,hpp,h}: Fix all compiler warnings before pushing; do not suppress warnings with flags
Use const for read-only references and mark methods const if they don't modify state in C++ code
Use OmpDeviceBuffer wrapper for GPU buffer management instead of manual memory management
Add explanatory comments for WHY code does something, not WHAT it does; document numerical algorithms and non-obvious optimizations

Files:

  • include/gpu_utils.hpp
include/**/*.{hpp,h}

📄 CodeRabbit inference engine (.cursorrules)

Every public function needs a documentation comment explaining parameters, return values, side effects, and including usage examples for complex functions

Files:

  • include/gpu_utils.hpp
🧠 Learnings
Learnt from .cursorrules (CR, 2026-01-07):

  • README.md: Keep command-line options in the README up to date; document new features and update examples when behavior changes
  • README.md: Tests must verify the Navier-Stokes solver using comprehensive physics validation, including Poiseuille flow, divergence-free constraint, momentum balance, channel symmetry, cross-model consistency, and sanity checks
  • README.md: For GPU-related changes, also run `./test_before_ci_gpu.sh` to validate the GPU CI test suite locally, including physics validation tests, turbulence model validation, and CPU/GPU consistency
  • include/gpu_utils.hpp: Use `OmpDeviceBuffer` wrapper for GPU buffer management instead of manual memory management
  • include/gpu_utils.hpp: For GPU offload, verify `USE_GPU_OFFLOAD` is defined and check `omp_get_num_devices() > 0` at runtime
  • include/gpu_utils.hpp: Minimize CPU↔GPU transfers and keep frequently accessed data on GPU; batch operations when possible
🪛 markdownlint-cli2 (0.18.1)
docs/POISSON_SOLVER_GUIDE.md

182-182: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: gpu-tests
  • GitHub Check: build-and-test (ubuntu-latest, Release)
🔇 Additional comments (6)
docs/POISSON_SOLVER_GUIDE.md (3)

195-226: Well-documented CUDA Graph optimization feature.

The V-cycle CUDA Graph section provides comprehensive coverage of the new feature including workflow, performance impact, configuration, requirements, and fallback behavior. The documentation is clear and user-friendly.


227-240: Excellent technical documentation of Chebyshev smoothing.

The Chebyshev smoothing section provides clear technical justification for the algorithm choice, including eigenvalue bounds and GPU performance advantages.


185-192: Remove the incorrect coarsest level claim; V-cycle parameter changes (nu1=3, nu2=1) are already validated.

The coarsest level has always used 8 iterations (via Chebyshev smoothing), not 100. The 100-iteration parameter in solve_coarsest() is a legacy default that was never invoked in the V-cycle code path. The actual parameter changes, nu1=3 (pre-smoothing) and nu2=1 (post-smoothing), are already validated across multiple benchmarks and convergence tests in the codebase (e.g., bench_mg_bc_sweep.cpp, bench_mg_tuning.cpp), which demonstrate a 13% speedup and 10× better divergence reduction versus the previous configuration.

Likely an incorrect or invalid review comment.

include/gpu_utils.hpp (1)

285-289: LGTM - Well-commented async macro.

The GPU_PARALLEL_FOR_ASYNC macro is properly documented with inline comments explaining the async behavior and the requirement to call gpu::sync() before reading results on the host.

README.md (2)

635-667: Well-documented major feature in README.

The V-cycle CUDA Graph Optimization section appropriately highlights this significant performance feature in the README, with proper cross-references to the detailed guide. The documentation follows best practices by documenting new features and providing configuration examples.

Based on learnings: "Keep command-line options in README up to date; document new features and update examples when behavior changes."


635-667: Reminder: Run GPU test suite for this PR.

Since this PR introduces significant GPU-related changes (CUDA Graph acceleration), ensure you run the GPU test suite locally before merging:

./test_before_ci_gpu.sh

This validates GPU CI tests including physics validation, turbulence model validation, and CPU/GPU consistency checks.

Based on learnings: "For GPU-related changes, also run ./test_before_ci_gpu.sh to validate GPU CI test suite locally."

- Add CUDA_CHECK_SYNC macro to check cudaDeviceSynchronize return values
- Throws std::runtime_error with file:line on CUDA failures
- Consistent with fail-fast philosophy for GPU errors

Co-Authored-By: Claude Opus 4.5 <[email protected]>
@comp-physics comp-physics deleted a comment from coderabbitai bot Jan 9, 2026
- Remove profiles/*.txt (Nsight Systems output with system-specific paths)
- Add profiles/ directory to .gitignore

Co-Authored-By: Claude Opus 4.5 <[email protected]>
coderabbitai bot commented Jan 9, 2026

Warning

Rate limit exceeded

@sbryngelson has exceeded the limit for the number of commits that can be reviewed per hour. Please wait 13 minutes and 9 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between 51b27c5 and 7036f5f.

📒 Files selected for processing (28)
  • .gitignore
  • CMakeLists.txt
  • README.md
  • app/main_taylor_green_3d.cpp
  • app/profile_comprehensive.cpp
  • docs/POISSON_SOLVER_GUIDE.md
  • docs/profiling_results_128cubed.md
  • include/config.hpp
  • include/gpu_utils.hpp
  • include/mg_cuda_kernels.hpp
  • include/poisson_solver.hpp
  • include/poisson_solver_multigrid.hpp
  • scripts/ci.sh
  • scripts/run_nsys_profiles.sh
  • src/config.cpp
  • src/gpu_kernels.cpp
  • src/mg_cuda_kernels.cpp
  • src/poisson_solver_fft.cpp
  • src/poisson_solver_fft2d.cpp
  • src/poisson_solver_multigrid.cpp
  • src/solver.cpp
  • tests/bench_fft_vs_mg.cpp
  • tests/bench_mg_bc_sweep.cpp
  • tests/bench_mg_cuda_graphs.cpp
  • tests/bench_mg_tuning.cpp
  • tests/test_mg_physics_match.cpp
  • tests/test_poisson_unified.cpp
  • tests/test_vcycle_graph_stress.cpp

@sbryngelson sbryngelson merged commit 565271b into master Jan 9, 2026
4 checks passed
@sbryngelson sbryngelson deleted the optimization branch January 9, 2026 14:58