Skip to content

Commit 57ac500

Browse files
committed
Fix opaque pointer core dump
1 parent b5398b0 commit 57ac500

File tree

6 files changed

+20
-7
lines changed

6 files changed

+20
-7
lines changed

Cargo.lock

+1
Original file line numberDiff line numberDiff line change
@@ -3897,6 +3897,7 @@ dependencies = [
38973897
"rustc_middle",
38983898
"rustc_session",
38993899
"rustc_span",
3900+
"rustc_symbol_mangling",
39003901
"rustc_target",
39013902
"serde",
39023903
"serde_json",

compiler/rustc_codegen_llvm/src/back/write.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use llvm::{
2323
LLVMCountParams, LLVMCountStructElementTypes, LLVMCreateBuilderInContext, LLVMDeleteFunction,
2424
LLVMDisposeBuilder, LLVMGetBasicBlockTerminator, LLVMGetElementType, LLVMGetModuleContext,
2525
LLVMGetParams, LLVMGetReturnType, LLVMPositionBuilderAtEnd, LLVMSetValueName2, LLVMTypeOf,
26-
LLVMVoidTypeInContext,
26+
LLVMVoidTypeInContext, LLVMGlobalGetValueType,
2727
};
2828
//use llvm::LLVMRustGetNamedValue;
2929
use rustc_codegen_ssa::back::link::ensure_removed;
@@ -694,8 +694,11 @@ pub(crate) unsafe fn enzyme_ad(
694694
),
695695
_ => unreachable!(),
696696
};
697-
let f_type = LLVMTypeOf(res);
698-
let f_return_type = LLVMGetReturnType(LLVMGetElementType(f_type));
697+
//let f_type = LLVMTypeOf(res);
698+
699+
let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res));
700+
701+
//let f_return_type = LLVMGetReturnType(LLVMGetElementType(f_type));
699702
let void_type = LLVMVoidTypeInContext(llcx);
700703
if item.attrs.mode == DiffMode::Reverse && f_return_type != void_type {
701704
//dbg!("Reverse Mode sanitizer");

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

+4
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
11331133
1, // vector mode width
11341134
1, // free memory
11351135
Option::None,
1136+
0, // do not force anonymous tape
11361137
dummy_type, // additional_arg, type info (return + args)
11371138
args_uncacheable.as_ptr(),
11381139
args_uncacheable.len(), // uncacheable arguments
@@ -1168,7 +1169,9 @@ extern "C" {
11681169
pub fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>;
11691170
pub fn LLVMGetNextFunction(V: &Value) -> Option<&Value>;
11701171
pub fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>;
1172+
pub fn LLVMGlobalGetValueType(val: &Value) -> &Type;
11711173

1174+
pub fn LLVMRustGetFunctionType(fnc: &Value) -> &Type;
11721175
pub fn LLVMRustInstallFatalErrorHandler();
11731176
pub fn LLVMRustDisableSystemDialogsOnCrash();
11741177

@@ -2813,6 +2816,7 @@ extern "C" {
28132816
width: ::std::os::raw::c_uint,
28142817
freeMemory: u8,
28152818
additionalArg: Option<&Type>,
2819+
forceAnonymousTape: u8,
28162820
typeInfo: CFnTypeInfo,
28172821
_uncacheable_args: *const u8,
28182822
uncacheable_args_size: size_t,

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ extern "C" char *LLVMRustGetLastError(void) {
9090
return Ret;
9191
}
9292

93+
extern "C" LLVMTypeRef LLVMRustGetFunctionType(LLVMValueRef Fn) {
94+
auto Ftype = unwrap<Function>(Fn)->getFunctionType();
95+
return wrap(Ftype);
96+
}
97+
9398
// Enzyme
9499
// extern "C" bool LLVMRustIsNull(LLVMValueRef V) {
95100
// Value *Val = unwrap(V);

library/autodiff/examples/sin.rs

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#![feature(bench_black_box)]
21
use autodiff::autodiff;
32

43
#[autodiff(cos_inplace, Reverse, Const)]

library/autodiff/examples/vec.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
#![feature(bench_black_box)]
22
use autodiff::autodiff;
33

4-
#[autodiff(d_sum, Reverse, Active)]
4+
#[autodiff(d_sum, Forward, Duplicated)]
55
fn sum(#[dup] x: &Vec<Vec<f32>>) -> f32 {
66
x.into_iter().map(|x| x.into_iter().map(|x| x.sqrt())).flatten().sum()
77
}
88

99
fn main() {
1010
let a = vec![vec![1.0, 2.0, 4.0, 8.0]];
11-
let mut b = vec![vec![0.0, 0.0, 0.0, 0.0]];
11+
//let mut b = vec![vec![0.0, 0.0, 0.0, 0.0]];
12+
let b = vec![vec![1.0, 0.0, 0.0, 0.0]];
1213

13-
d_sum(&a, &mut b, 1.0);
14+
dbg!(&d_sum(&a, &b));
1415

1516
dbg!(&b);
1617
}

0 commit comments

Comments
 (0)