Skip to content

Commit dcd3512

Browse files
committed
linker/inline: added test for cascade inlining.
1 parent 6e4d191 commit dcd3512

File tree

2 files changed

+118
-3
lines changed

2 files changed

+118
-3
lines changed

Diff for: crates/rustc_codegen_spirv/src/linker/inline.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -386,10 +386,11 @@ impl Inliner<'_, '_> {
386386
self.add_clone_id_rules(&mut rewrite_rules, &inlined_blocks);
387387
// If any of the OpReturns were invalid, return will also be invalid.
388388
for value in &return_values {
389-
if self.invalid_args.contains(value) {
389+
let value_rewritten = *rewrite_rules.get(value).unwrap_or(value);
390+
// value_rewritten might be originally a function argument
391+
if self.invalid_args.contains(value) || self.invalid_args.contains(&value_rewritten) {
390392
self.invalid_args.insert(call_result_id);
391-
self.invalid_args
392-
.insert(*rewrite_rules.get(value).unwrap_or(value));
393+
self.invalid_args.insert(value_rewritten);
393394
}
394395
}
395396
apply_rewrite_rules(&rewrite_rules, &mut inlined_blocks);

Diff for: crates/rustc_codegen_spirv/src/linker/test.rs

+114
Original file line numberDiff line numberDiff line change
@@ -505,3 +505,117 @@ fn names_and_decorations() {
505505

506506
without_header_eq(result, expect);
507507
}
508+
509+
#[test]
510+
fn cascade_inlining_of_ptr_args() {
511+
let a = assemble_spirv(
512+
r#"OpCapability Linkage
513+
OpDecorate %1 LinkageAttributes "foo" Export
514+
%2 = OpTypeInt 32 0
515+
%8 = OpConstant %2 0
516+
%3 = OpTypeStruct %2 %2
517+
%4 = OpTypePointer Function %2
518+
%5 = OpTypePointer Function %3
519+
%6 = OpTypeFunction %4 %5
520+
%1 = OpFunction %4 Const %6
521+
%7 = OpFunctionParameter %5
522+
%10 = OpLabel
523+
%9 = OpAccessChain %4 %7 %8
524+
OpReturnValue %9
525+
OpFunctionEnd
526+
"#,
527+
);
528+
529+
let b = assemble_spirv(
530+
r#"OpCapability Linkage
531+
OpDecorate %1 LinkageAttributes "bar" Export
532+
%2 = OpTypeInt 32 0
533+
%4 = OpTypePointer Function %2
534+
%6 = OpTypeFunction %2 %4
535+
%1 = OpFunction %2 None %6
536+
%7 = OpFunctionParameter %4
537+
%10 = OpLabel
538+
%8 = OpLoad %2 %7
539+
OpReturnValue %8
540+
OpFunctionEnd
541+
"#,
542+
);
543+
544+
let c = assemble_spirv(
545+
r#"OpCapability Linkage
546+
OpDecorate %1 LinkageAttributes "baz" Export
547+
%2 = OpTypeInt 32 0
548+
%4 = OpTypePointer Function %2
549+
%6 = OpTypeFunction %4 %4
550+
%1 = OpFunction %4 None %6
551+
%7 = OpFunctionParameter %4
552+
%10 = OpLabel
553+
OpReturnValue %7
554+
OpFunctionEnd
555+
"#,
556+
);
557+
558+
// In here, inlining foo should mark its return result as a not-fit-for-function-consumption
559+
// pointer and inline "baz" as well. That would lead to inlining "bar" too.
560+
let d = assemble_spirv(
561+
r#"OpCapability Linkage
562+
OpDecorate %10 LinkageAttributes "foo" Import
563+
OpDecorate %12 LinkageAttributes "bar" Import
564+
OpDecorate %14 LinkageAttributes "baz" Import
565+
OpName %1 "main"
566+
%2 = OpTypeInt 32 0
567+
%3 = OpTypeStruct %2 %2
568+
%4 = OpTypePointer Function %2
569+
%5 = OpTypePointer Function %3
570+
%6 = OpTypeFunction %4 %5
571+
%7 = OpTypeFunction %2 %4
572+
%8 = OpTypeFunction %4 %4
573+
%10 = OpFunction %4 Const %6
574+
%11 = OpFunctionParameter %5
575+
OpFunctionEnd
576+
%12 = OpFunction %2 None %7
577+
%13 = OpFunctionParameter %4
578+
OpFunctionEnd
579+
%14 = OpFunction %4 None %8
580+
%15 = OpFunctionParameter %4
581+
OpFunctionEnd
582+
%21 = OpTypeFunction %2 %5
583+
%1 = OpFunction %2 None %14
584+
%22 = OpFunctionParameter %5
585+
%23 = OpLabel
586+
%24 = OpFunctionCall %4 %10 %22
587+
%25 = OpFunctionCall %4 %14 %24
588+
%26 = OpFunctionCall %2 %12 %25
589+
OpReturnValue %26
590+
OpFunctionEnd
591+
"#,
592+
);
593+
594+
let result = assemble_and_link(&[&a, &b, &c, &d]).unwrap();
595+
let expect = r#"OpName %1 "main"
596+
%2 = OpTypeInt 32 0
597+
%3 = OpConstant %2 0
598+
%4 = OpTypeStruct %2 %2
599+
%5 = OpTypePointer Function %2
600+
%6 = OpTypePointer Function %4
601+
%7 = OpTypeFunction %5 %6
602+
%8 = OpTypeFunction %2 %5
603+
%9 = OpTypeFunction %5 %5
604+
%10 = OpTypeFunction %2 %6
605+
%11 = OpTypePointer Function %5
606+
%12 = OpFunction %2 None %8
607+
%13 = OpFunctionParameter %5
608+
%14 = OpLabel
609+
%15 = OpLoad %2 %13
610+
OpReturnValue %15
611+
OpFunctionEnd
612+
%1 = OpFunction %2 None %16
613+
%17 = OpFunctionParameter %6
614+
%18 = OpLabel
615+
%19 = OpAccessChain %5 %17 %3
616+
%20 = OpLoad %2 %19
617+
OpReturnValue %20
618+
OpFunctionEnd"#;
619+
620+
without_header_eq(result, expect);
621+
}

0 commit comments

Comments
 (0)