Skip to content

Autodiff flags #139700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -584,12 +584,10 @@ fn thin_lto(
}
}

fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) {
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
for &val in ad {
// We intentionally don't use a wildcard, to not forget handling anything new.
match val {
config::AutoDiff::PrintModBefore => {
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
}
config::AutoDiff::PrintPerf => {
llvm::set_print_perf(true);
}
Expand All @@ -603,17 +601,23 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
llvm::set_inline(true);
}
config::AutoDiff::LooseTypes => {
llvm::set_loose_types(false);
llvm::set_loose_types(true);
}
config::AutoDiff::PrintSteps => {
llvm::set_print(true);
}
// We handle this below
// We handle this in the PassWrapper.cpp
config::AutoDiff::PrintPasses => {}
// We handle this in the PassWrapper.cpp
config::AutoDiff::PrintModBefore => {}
// We handle this in the PassWrapper.cpp
config::AutoDiff::PrintModAfter => {}
// We handle this below
// We handle this in the PassWrapper.cpp
config::AutoDiff::PrintModFinal => {}
// This is required and already checked
config::AutoDiff::Enable => {}
// We handle this below
config::AutoDiff::NoPostopt => {}
}
}
// This helps with handling enums for now.
Expand Down Expand Up @@ -647,27 +651,27 @@ pub(crate) fn run_pass_manager(
// We then run the llvm_optimize function a second time, to optimize the code which we generated
// in the enzyme differentiation pass.
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
let stage =
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
let stage = if thin {
write::AutodiffStage::PreAD
} else {
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }
};

if enable_ad {
enable_autodiff_settings(&config.autodiff, module);
enable_autodiff_settings(&config.autodiff);
}

unsafe {
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
}

if cfg!(llvm_enzyme) && enable_ad {
// This is the post-autodiff IR, mainly used for testing and educational purposes.
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
}

if cfg!(llvm_enzyme) && enable_ad && !thin {
let opt_stage = llvm::OptStage::FatLTO;
let stage = write::AutodiffStage::PostAD;
unsafe {
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {
unsafe {
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
}
}

// This is the final IR, so people should be able to inspect the optimized autodiff output,
Expand Down
16 changes: 15 additions & 1 deletion compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,20 +565,31 @@ pub(crate) unsafe fn llvm_optimize(

let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
let run_enzyme = autodiff_stage == AutodiffStage::DuringAD;
let print_before_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModBefore);
let print_after_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModAfter);
let print_passes = config.autodiff.contains(&config::AutoDiff::PrintPasses);
let merge_functions;
let unroll_loops;
let vectorize_slp;
let vectorize_loop;

// When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
// optimizations until after differentiation. Our pipeline is thus: (opt + enzyme), (full opt).
// We therefore have two calls to llvm_optimize, if autodiff is used.
//
// We also must disable merge_functions, since autodiff placeholder/dummy bodies tend to be
// identical. We run opts before AD, so there is a chance that LLVM will merge our dummies.
// In that case, we lack some dummy bodies and can't replace them with the real AD code anymore.
// We then would need to abort compilation. This was especially common in test cases.
if consider_ad && autodiff_stage != AutodiffStage::PostAD {
merge_functions = false;
unroll_loops = false;
vectorize_slp = false;
vectorize_loop = false;
} else {
unroll_loops =
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
merge_functions = config.merge_functions;
vectorize_slp = config.vectorize_slp;
vectorize_loop = config.vectorize_loop;
}
Expand Down Expand Up @@ -656,13 +667,16 @@ pub(crate) unsafe fn llvm_optimize(
thin_lto_buffer,
config.emit_thin_lto,
config.emit_thin_lto_summary,
config.merge_functions,
merge_functions,
unroll_loops,
vectorize_slp,
vectorize_loop,
config.no_builtins,
config.emit_lifetime_markers,
run_enzyme,
print_before_enzyme,
print_after_enzyme,
print_passes,
sanitizer_options.as_ref(),
pgo_gen_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
pgo_use_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ pub(crate) fn differentiate<'ll>(
return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable));
}

// Before dumping the module, we want all the TypeTrees to become part of the module.
// Here we replace the placeholder code with the actual autodiff code, which calls Enzyme.
for item in diff_items.iter() {
let name = item.source.clone();
let fn_def: Option<&llvm::Value> = cx.get_function(&name);
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2454,6 +2454,9 @@ unsafe extern "C" {
DisableSimplifyLibCalls: bool,
EmitLifetimeMarkers: bool,
RunEnzyme: bool,
PrintBeforeEnzyme: bool,
PrintAfterEnzyme: bool,
PrintPasses: bool,
SanitizerOptions: Option<&SanitizerOptions>,
PGOGenPath: *const c_char,
PGOUsePath: *const c_char,
Expand Down
30 changes: 28 additions & 2 deletions compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRPrinter/IRPrintingPasses.h"
#include "llvm/LTO/LTO.h"
#include "llvm/MC/MCSubtargetInfo.h"
#include "llvm/MC/TargetRegistry.h"
Expand Down Expand Up @@ -703,7 +704,8 @@ extern "C" LLVMRustResult LLVMRustOptimize(
bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO,
bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops,
bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
bool EmitLifetimeMarkers, bool RunEnzyme,
bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme,
bool PrintAfterEnzyme, bool PrintPasses,
LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
const char *PGOUsePath, bool InstrumentCoverage,
const char *InstrProfileOutput, const char *PGOSampleUsePath,
Expand Down Expand Up @@ -1048,14 +1050,38 @@ extern "C" LLVMRustResult LLVMRustOptimize(
// now load "-enzyme" pass:
#ifdef ENZYME
if (RunEnzyme) {
registerEnzymeAndPassPipeline(PB, true);

if (PrintBeforeEnzyme) {
// Handle the Rust flag `-Zautodiff=PrintModBefore`.
std::string Banner = "Module before EnzymeNewPM";
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
}

registerEnzymeAndPassPipeline(PB, false);
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
std::string ErrMsg = toString(std::move(Err));
LLVMRustSetLastError(ErrMsg.c_str());
return LLVMRustResult::Failure;
}

if (PrintAfterEnzyme) {
// Handle the Rust flag `-Zautodiff=PrintModAfter`.
std::string Banner = "Module after EnzymeNewPM";
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
}
}
#endif
if (PrintPasses) {
// Print all passes from the PM:
std::string Pipeline;
raw_string_ostream SOS(Pipeline);
MPM.printPipeline(SOS, [&PIC](StringRef ClassName) {
auto PassName = PIC.getPassNameForClassName(ClassName);
return PassName.empty() ? ClassName : PassName;
});
outs() << Pipeline;
outs() << "\n";
}

// Upgrade all calls to old intrinsics first.
for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_session/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,10 @@ pub enum AutoDiff {
/// Print the module after running autodiff and optimizations.
PrintModFinal,

/// Print all passes scheduled by LLVM
PrintPasses,
/// Disable extra opt run after running autodiff
NoPostopt,
/// Enzyme's loose type debug helper (can cause incorrect gradients!!)
/// Usable in cases where Enzyme errors with `can not deduce type of X`.
LooseTypes,
Expand Down
6 changes: 5 additions & 1 deletion compiler/rustc_session/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ mod desc {
pub(crate) const parse_list: &str = "a space-separated list of strings";
pub(crate) const parse_list_with_polarity: &str =
"a comma-separated list of strings, with elements beginning with + or -";
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`";
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
pub(crate) const parse_number: &str = "a number";
Expand Down Expand Up @@ -1360,6 +1360,8 @@ pub mod parse {
"PrintModBefore" => AutoDiff::PrintModBefore,
"PrintModAfter" => AutoDiff::PrintModAfter,
"PrintModFinal" => AutoDiff::PrintModFinal,
"NoPostopt" => AutoDiff::NoPostopt,
"PrintPasses" => AutoDiff::PrintPasses,
"LooseTypes" => AutoDiff::LooseTypes,
"Inline" => AutoDiff::Inline,
_ => {
Expand Down Expand Up @@ -2095,6 +2097,8 @@ options! {
`=PrintModBefore`
`=PrintModAfter`
`=PrintModFinal`
`=PrintPasses`,
`=NoPostopt`
`=LooseTypes`
`=Inline`
Multiple options can be combined with commas."),
Expand Down
45 changes: 45 additions & 0 deletions tests/codegen/autodiff/identical_fnc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
//
// Each autodiff invocation creates a new placeholder function, which we will replace on llvm-ir
// level. If a user tries to differentiate two identical functions within the same compilation unit,
// then LLVM might merge them in release mode before AD. In that case we can't rewrite one of the
// merged placeholder function anymore, and compilation would fail. We prevent this by disabling
// LLVM's merge_function pass before AD. Here we implicetely test that our solution keeps working.
// We also explicetly test that we keep running merge_function after AD, by checking for two
// identical function calls in the LLVM-IR, while having two different calls in the Rust code.
#![feature(autodiff)]

use std::autodiff::autodiff;

#[autodiff(d_square, Reverse, Duplicated, Active)]
fn square(x: &f64) -> f64 {
x * x
}

#[autodiff(d_square2, Reverse, Duplicated, Active)]
fn square2(x: &f64) -> f64 {
x * x
}

// CHECK:; identical_fnc::main
// CHECK-NEXT:; Function Attrs:
// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17hf4dbc69c8d2f9130E()
// CHECK-NEXT:start:
// CHECK-NOT:br
// CHECK-NOT:ret
// CHECK:; call identical_fnc::d_square
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1)
// CHECK-NEXT:; call identical_fnc::d_square
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2)

fn main() {
let x = std::hint::black_box(3.0);
let mut dx1 = std::hint::black_box(1.0);
let mut dx2 = std::hint::black_box(1.0);
let _ = d_square(&x, &mut dx1, 1.0);
let _ = d_square2(&x, &mut dx2, 1.0);
assert_eq!(dx1, 6.0);
assert_eq!(dx2, 6.0);
}
Loading