
JAXBench: Add back jaxbench profiling and block size tuning scripts #52

Open
charleshong3 wants to merge 2 commits into AI-Hypercomputer:main from charleshong3:add-jaxbench-device-profiler



@charleshong3


These scripts used to be in JAXBench but appear to have been removed during the migration to accelerator-agents. I'm not sure whether we want them back.

charleshong3 and others added 2 commits June 16, 2026 11:23
Adds benchmark/profile_workload.py, a self-contained device-side timing
harness for JAXBench Pallas kernels. It JIT-compiles a workload variant
(baseline / optimized / custom), runs 5 warmup + 50 profiled iterations
under jax.profiler.trace(), parses the Perfetto JSON for per-iteration
on-device kernel times, and reports both device-profiler median_ms and a
wall_clock_median_ms cross-check, plus FLOPs/TFLOPs/utilization.

This is the authoritative device-side measurement method (vs host
wall-clock), suitable for apples-to-apples comparison against the
hand-tuned Pallas references. Only stdlib + jax/numpy deps.

Usage:
    PJRT_DEVICE=TPU python3 benchmark/profile_workload.py benchmark/8p_GEMM baseline
    PJRT_DEVICE=TPU python3 benchmark/profile_workload.py benchmark/8p_GEMM optimized
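For readers unfamiliar with the approach, the trace-parsing and summary step described above can be sketched roughly as follows. This is not the code in profile_workload.py. It assumes the exported trace follows the Chrome trace-event JSON layout (a top-level `traceEvents` list whose complete events have `ph == "X"` and a `dur` in microseconds), and the helper names `kernel_times_ms` and `summarize`, the kernel name `pallas_gemm`, and the synthetic trace are all hypothetical.

```python
import statistics


def kernel_times_ms(trace, name_substr):
    """Collect durations (ms) of complete events whose name contains name_substr.

    Assumes Chrome trace-event JSON: 'ph' == 'X' marks a complete event and
    'dur' is given in microseconds.
    """
    return [
        ev["dur"] / 1000.0
        for ev in trace.get("traceEvents", [])
        if ev.get("ph") == "X"
        and "dur" in ev
        and name_substr in ev.get("name", "")
    ]


def summarize(times_ms, flops, warmup=5, peak_tflops=None):
    """Drop warmup iterations, then report median time and achieved TFLOP/s."""
    profiled = times_ms[warmup:]
    median_ms = statistics.median(profiled)
    tflops = flops / (median_ms * 1e-3) / 1e12
    summary = {"median_ms": median_ms, "tflops": tflops}
    if peak_tflops is not None:
        # Utilization is achieved throughput over the device's peak throughput.
        summary["utilization"] = tflops / peak_tflops
    return summary


# Synthetic trace standing in for a real profiler export (hypothetical data):
# eight kernel launches, one unrelated event, and one metadata event.
trace = {
    "traceEvents": [
        {"ph": "X", "name": "pallas_gemm", "dur": d}
        for d in [900, 800, 700, 600, 550, 500, 510, 490]
    ]
    + [
        {"ph": "X", "name": "copy", "dur": 10},
        {"ph": "M", "name": "pallas_gemm"},
    ]
}

times = kernel_times_ms(trace, "pallas_gemm")
# A 1024x1024x1024 GEMM performs 2*M*N*K floating-point operations.
result = summarize(times, flops=2 * 1024**3, warmup=5)
print(result)
```

Taking the median of the post-warmup iterations keeps one-off outliers (compilation, first-touch allocation) from skewing the reported time.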

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Adds benchmark/tune_pallas.py, the companion grid-search tuner that
sweeps block-size configurations for each Pallas-optimized kernel
(e.g. GEMM block_shape/block_k, Paged pages_per_compute_block, Ragged
Paged Attention num_kv_pages_per_block/num_queries_per_block, Flash
block_q/block_k_major/block_k, Sparse block_q/block_kv), profiles each
config device-side, and reports the best. This is the tool that produced
the TUNED_PARAMS baked into the hand-tuned optimized.py references.

Self-contained (stdlib + jax/numpy), independent of profile_workload.py.

Usage:
    PJRT_DEVICE=TPU python3 benchmark/tune_pallas.py [workload_name]
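As a rough illustration of the grid-search idea (not the code in tune_pallas.py), the sweep can be sketched as below. The function `grid_search`, the stand-in `fake_measure_ms`, its made-up cost model, and the pretend resource limit are all hypothetical; a real tuner would compile the kernel with each configuration and time it on device instead.

```python
import itertools


def grid_search(param_grid, measure_ms):
    """Try every combination in param_grid and return (best_config, best_ms).

    measure_ms is any callable that takes a config dict and returns a time in
    milliseconds; configs for which it raises are skipped.
    """
    names = sorted(param_grid)
    best_cfg, best_ms = None, float("inf")
    for values in itertools.product(*(param_grid[n] for n in names)):
        cfg = dict(zip(names, values))
        try:
            ms = measure_ms(cfg)
        except Exception:
            continue  # e.g. a block size that fails to compile or run
        if ms < best_ms:
            best_cfg, best_ms = cfg, ms
    return best_cfg, best_ms


def fake_measure_ms(cfg):
    """Hypothetical stand-in for device-side profiling of one configuration."""
    # Pretend resource limit: reject tiles that are too large.
    if cfg["block_q"] * cfg["block_k"] > 512 * 512:
        raise ValueError("tile too large for this pretend device")
    # Made-up cost model with its minimum at block_q=256, block_k=512.
    return 1.0 + abs(cfg["block_q"] - 256) / 256 + abs(cfg["block_k"] - 512) / 512


grid = {"block_q": [128, 256, 512], "block_k": [128, 256, 512, 1024]}
best_cfg, best_ms = grid_search(grid, fake_measure_ms)
print(best_cfg, best_ms)
```

Skipping failing configurations rather than aborting lets the sweep cover the full grid even when some block sizes are invalid for a given kernel.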

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@google-cla

google-cla Bot commented Jun 16, 2026


Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@charleshong3 changed the title from "Add back jaxbench profiling and block size tuning scripts" to "JAXBench: Add back jaxbench profiling and block size tuning scripts" on Jun 16, 2026
