Skip to content

Commit 1184458

Browse files
sunfishcode and aheejin authored
[WebAssembly] Protect memory.fill and memory.copy from zero-length ranges. (llvm#112617)
WebAssembly's `memory.fill` and `memory.copy` instructions trap if the pointers are out of bounds, even if the length is zero. This is different from LLVM, which expects that it can call `memcpy` on arbitrary invalid pointers if the length is zero. To avoid spurious traps, branch around `memory.fill` and `memory.copy` when the length is zero. --------- Co-authored-by: Heejin Ahn <[email protected]>
1 parent 4c87793 commit 1184458

File tree

6 files changed

+404
-80
lines changed

6 files changed

+404
-80
lines changed

llvm/lib/Target/WebAssembly/WebAssemblyISD.def

+7-2
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ HANDLE_NODETYPE(PROMOTE_LOW)
4242
HANDLE_NODETYPE(TRUNC_SAT_ZERO_S)
4343
HANDLE_NODETYPE(TRUNC_SAT_ZERO_U)
4444
HANDLE_NODETYPE(DEMOTE_ZERO)
45-
HANDLE_NODETYPE(MEMORY_COPY)
46-
HANDLE_NODETYPE(MEMORY_FILL)
4745
HANDLE_NODETYPE(I64_ADD128)
4846
HANDLE_NODETYPE(I64_SUB128)
4947
HANDLE_NODETYPE(I64_MUL_WIDE_S)
@@ -54,3 +52,10 @@ HANDLE_MEM_NODETYPE(GLOBAL_GET)
5452
HANDLE_MEM_NODETYPE(GLOBAL_SET)
5553
HANDLE_MEM_NODETYPE(TABLE_GET)
5654
HANDLE_MEM_NODETYPE(TABLE_SET)
55+
56+
// Bulk memory instructions. These follow LLVM's expected semantics of
57+
// supporting out-of-bounds pointers if the length is zero, by inserting
58+
// a branch around Wasm's `memory.copy` and `memory.fill`, which would
59+
// otherwise trap.
60+
HANDLE_NODETYPE(MEMCPY)
61+
HANDLE_NODETYPE(MEMSET)

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

+140
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,138 @@ static MachineBasicBlock *LowerFPToInt(MachineInstr &MI, DebugLoc DL,
568568
return DoneMBB;
569569
}
570570

571+
// Lower a `MEMCPY` instruction into a CFG triangle around a `MEMORY_COPY`
572+
// instruction to handle the zero-length case.
573+
static MachineBasicBlock *LowerMemcpy(MachineInstr &MI, DebugLoc DL,
574+
MachineBasicBlock *BB,
575+
const TargetInstrInfo &TII, bool Int64) {
576+
MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
577+
578+
MachineOperand DstMem = MI.getOperand(0);
579+
MachineOperand SrcMem = MI.getOperand(1);
580+
MachineOperand Dst = MI.getOperand(2);
581+
MachineOperand Src = MI.getOperand(3);
582+
MachineOperand Len = MI.getOperand(4);
583+
584+
// We're going to add an extra use to `Len` to test if it's zero; that
585+
// use shouldn't be a kill, even if the original use is.
586+
MachineOperand NoKillLen = Len;
587+
NoKillLen.setIsKill(false);
588+
589+
// Decide on which `MachineInstr` opcode we're going to use.
590+
unsigned Eqz = Int64 ? WebAssembly::EQZ_I64 : WebAssembly::EQZ_I32;
591+
unsigned MemoryCopy =
592+
Int64 ? WebAssembly::MEMORY_COPY_A64 : WebAssembly::MEMORY_COPY_A32;
593+
594+
// Create two new basic blocks; one for the new `memory.copy` that we can
595+
// branch over, and one for the rest of the instructions after the original
596+
// `memory.copy`.
597+
const BasicBlock *LLVMBB = BB->getBasicBlock();
598+
MachineFunction *F = BB->getParent();
599+
MachineBasicBlock *TrueMBB = F->CreateMachineBasicBlock(LLVMBB);
600+
MachineBasicBlock *DoneMBB = F->CreateMachineBasicBlock(LLVMBB);
601+
602+
MachineFunction::iterator It = ++BB->getIterator();
603+
F->insert(It, TrueMBB);
604+
F->insert(It, DoneMBB);
605+
606+
// Transfer the remainder of BB and its successor edges to DoneMBB.
607+
DoneMBB->splice(DoneMBB->begin(), BB, std::next(MI.getIterator()), BB->end());
608+
DoneMBB->transferSuccessorsAndUpdatePHIs(BB);
609+
610+
// Connect the CFG edges.
611+
BB->addSuccessor(TrueMBB);
612+
BB->addSuccessor(DoneMBB);
613+
TrueMBB->addSuccessor(DoneMBB);
614+
615+
// Create a virtual register for the `Eqz` result.
616+
unsigned EqzReg;
617+
EqzReg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);
618+
619+
// Erase the original `memory.copy`.
620+
MI.eraseFromParent();
621+
622+
// Test if `Len` is zero.
623+
BuildMI(BB, DL, TII.get(Eqz), EqzReg).add(NoKillLen);
624+
625+
// Insert a new `memory.copy`.
626+
BuildMI(TrueMBB, DL, TII.get(MemoryCopy))
627+
.add(DstMem)
628+
.add(SrcMem)
629+
.add(Dst)
630+
.add(Src)
631+
.add(Len);
632+
633+
// Create the CFG triangle.
634+
BuildMI(BB, DL, TII.get(WebAssembly::BR_IF)).addMBB(DoneMBB).addReg(EqzReg);
635+
BuildMI(TrueMBB, DL, TII.get(WebAssembly::BR)).addMBB(DoneMBB);
636+
637+
return DoneMBB;
638+
}
639+
640+
// Lower a `MEMSET` instruction into a CFG triangle around a `MEMORY_FILL`
641+
// instruction to handle the zero-length case.
642+
static MachineBasicBlock *LowerMemset(MachineInstr &MI, DebugLoc DL,
643+
MachineBasicBlock *BB,
644+
const TargetInstrInfo &TII, bool Int64) {
645+
MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
646+
647+
MachineOperand Mem = MI.getOperand(0);
648+
MachineOperand Dst = MI.getOperand(1);
649+
MachineOperand Val = MI.getOperand(2);
650+
MachineOperand Len = MI.getOperand(3);
651+
652+
// We're going to add an extra use to `Len` to test if it's zero; that
653+
// use shouldn't be a kill, even if the original use is.
654+
MachineOperand NoKillLen = Len;
655+
NoKillLen.setIsKill(false);
656+
657+
// Decide on which `MachineInstr` opcode we're going to use.
658+
unsigned Eqz = Int64 ? WebAssembly::EQZ_I64 : WebAssembly::EQZ_I32;
659+
unsigned MemoryFill =
660+
Int64 ? WebAssembly::MEMORY_FILL_A64 : WebAssembly::MEMORY_FILL_A32;
661+
662+
// Create two new basic blocks; one for the new `memory.fill` that we can
663+
// branch over, and one for the rest of the instructions after the original
664+
// `memory.fill`.
665+
const BasicBlock *LLVMBB = BB->getBasicBlock();
666+
MachineFunction *F = BB->getParent();
667+
MachineBasicBlock *TrueMBB = F->CreateMachineBasicBlock(LLVMBB);
668+
MachineBasicBlock *DoneMBB = F->CreateMachineBasicBlock(LLVMBB);
669+
670+
MachineFunction::iterator It = ++BB->getIterator();
671+
F->insert(It, TrueMBB);
672+
F->insert(It, DoneMBB);
673+
674+
// Transfer the remainder of BB and its successor edges to DoneMBB.
675+
DoneMBB->splice(DoneMBB->begin(), BB, std::next(MI.getIterator()), BB->end());
676+
DoneMBB->transferSuccessorsAndUpdatePHIs(BB);
677+
678+
// Connect the CFG edges.
679+
BB->addSuccessor(TrueMBB);
680+
BB->addSuccessor(DoneMBB);
681+
TrueMBB->addSuccessor(DoneMBB);
682+
683+
// Create a virtual register for the `Eqz` result.
684+
unsigned EqzReg;
685+
EqzReg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);
686+
687+
// Erase the original `memory.fill`.
688+
MI.eraseFromParent();
689+
690+
// Test if `Len` is zero.
691+
BuildMI(BB, DL, TII.get(Eqz), EqzReg).add(NoKillLen);
692+
693+
// Insert a new `memory.fill`.
694+
BuildMI(TrueMBB, DL, TII.get(MemoryFill)).add(Mem).add(Dst).add(Val).add(Len);
695+
696+
// Create the CFG triangle.
697+
BuildMI(BB, DL, TII.get(WebAssembly::BR_IF)).addMBB(DoneMBB).addReg(EqzReg);
698+
BuildMI(TrueMBB, DL, TII.get(WebAssembly::BR)).addMBB(DoneMBB);
699+
700+
return DoneMBB;
701+
}
702+
571703
static MachineBasicBlock *
572704
LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
573705
const WebAssemblySubtarget *Subtarget,
@@ -725,6 +857,14 @@ MachineBasicBlock *WebAssemblyTargetLowering::EmitInstrWithCustomInserter(
725857
case WebAssembly::FP_TO_UINT_I64_F64:
726858
return LowerFPToInt(MI, DL, BB, TII, true, true, true,
727859
WebAssembly::I64_TRUNC_U_F64);
860+
case WebAssembly::MEMCPY_A32:
861+
return LowerMemcpy(MI, DL, BB, TII, false);
862+
case WebAssembly::MEMCPY_A64:
863+
return LowerMemcpy(MI, DL, BB, TII, true);
864+
case WebAssembly::MEMSET_A32:
865+
return LowerMemset(MI, DL, BB, TII, false);
866+
case WebAssembly::MEMSET_A64:
867+
return LowerMemset(MI, DL, BB, TII, true);
728868
case WebAssembly::CALL_RESULTS:
729869
case WebAssembly::RET_CALL_RESULTS:
730870
return LowerCallResults(MI, DL, BB, Subtarget, TII);

llvm/lib/Target/WebAssembly/WebAssemblyInstrBulkMemory.td

+54-19
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,31 @@ multiclass BULK_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
2121
}
2222

2323
// Bespoke types and nodes for bulk memory ops
24+
2425
def wasm_memcpy_t : SDTypeProfile<0, 5,
2526
[SDTCisInt<0>, SDTCisInt<1>, SDTCisPtrTy<2>, SDTCisPtrTy<3>, SDTCisInt<4>]
2627
>;
27-
def wasm_memcpy : SDNode<"WebAssemblyISD::MEMORY_COPY", wasm_memcpy_t,
28-
[SDNPHasChain, SDNPMayLoad, SDNPMayStore]>;
29-
3028
def wasm_memset_t : SDTypeProfile<0, 4,
3129
[SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>, SDTCisInt<3>]
3230
>;
33-
def wasm_memset : SDNode<"WebAssemblyISD::MEMORY_FILL", wasm_memset_t,
31+
32+
// memory.copy with a branch to avoid trapping in the case of out-of-bounds
33+
// pointers with empty ranges.
34+
def wasm_memcpy : SDNode<"WebAssemblyISD::MEMCPY", wasm_memcpy_t,
35+
[SDNPHasChain, SDNPMayLoad, SDNPMayStore]>;
36+
37+
// memory.fill with a branch to avoid trapping in the case of out-of-bounds
38+
// pointers with empty ranges.
39+
def wasm_memset : SDNode<"WebAssemblyISD::MEMSET", wasm_memset_t,
3440
[SDNPHasChain, SDNPMayStore]>;
3541

42+
// A multiclass for defining Wasm's raw bulk-memory `memory.*` instructions.
43+
// `memory.copy` and `memory.fill` have Wasm's behavior rather than
44+
// `memcpy`/`memset` behavior.
3645
multiclass BulkMemoryOps<WebAssemblyRegClass rc, string B> {
3746

3847
let mayStore = 1, hasSideEffects = 1 in
39-
defm MEMORY_INIT_A#B :
48+
defm INIT_A#B :
4049
BULK_I<(outs),
4150
(ins i32imm_op:$seg, i32imm_op:$idx, rc:$dest,
4251
I32:$offset, I32:$size),
@@ -45,31 +54,57 @@ defm MEMORY_INIT_A#B :
4554
"memory.init\t$seg, $idx, $dest, $offset, $size",
4655
"memory.init\t$seg, $idx", 0x08>;
4756

48-
let hasSideEffects = 1 in
49-
defm DATA_DROP :
50-
BULK_I<(outs), (ins i32imm_op:$seg), (outs), (ins i32imm_op:$seg),
51-
[],
52-
"data.drop\t$seg", "data.drop\t$seg", 0x09>;
53-
5457
let mayLoad = 1, mayStore = 1 in
55-
defm MEMORY_COPY_A#B :
58+
defm COPY_A#B :
5659
BULK_I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
5760
rc:$dst, rc:$src, rc:$len),
5861
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
59-
[(wasm_memcpy (i32 imm:$src_idx), (i32 imm:$dst_idx),
60-
rc:$dst, rc:$src, rc:$len
61-
)],
62+
[],
6263
"memory.copy\t$src_idx, $dst_idx, $dst, $src, $len",
6364
"memory.copy\t$src_idx, $dst_idx", 0x0a>;
6465

6566
let mayStore = 1 in
66-
defm MEMORY_FILL_A#B :
67+
defm FILL_A#B :
6768
BULK_I<(outs), (ins i32imm_op:$idx, rc:$dst, I32:$value, rc:$size),
6869
(outs), (ins i32imm_op:$idx),
69-
[(wasm_memset (i32 imm:$idx), rc:$dst, I32:$value, rc:$size)],
70+
[],
7071
"memory.fill\t$idx, $dst, $value, $size",
7172
"memory.fill\t$idx", 0x0b>;
7273
}
7374

74-
defm : BulkMemoryOps<I32, "32">;
75-
defm : BulkMemoryOps<I64, "64">;
75+
defm MEMORY_ : BulkMemoryOps<I32, "32">;
76+
defm MEMORY_ : BulkMemoryOps<I64, "64">;
77+
78+
// A multiclass for defining `memcpy`/`memset` pseudo instructions. These have
79+
// the behavior the rest of LLVM CodeGen expects, and we lower them into code
80+
// sequences that include the Wasm `memory.fill` and `memory.copy` instructions
81+
// using custom inserters, because they introduce new control flow.
82+
multiclass BulkMemOps<WebAssemblyRegClass rc, string B> {
83+
84+
let usesCustomInserter = 1, isCodeGenOnly = 1, mayLoad = 1, mayStore = 1 in
85+
defm CPY_A#B : I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
86+
rc:$dst, rc:$src, rc:$len),
87+
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
88+
[(wasm_memcpy (i32 imm:$src_idx), (i32 imm:$dst_idx),
89+
rc:$dst, rc:$src, rc:$len
90+
)],
91+
"", "", 0>,
92+
Requires<[HasBulkMemory]>;
93+
94+
let usesCustomInserter = 1, isCodeGenOnly = 1, mayStore = 1 in
95+
defm SET_A#B : I<(outs), (ins i32imm_op:$idx, rc:$dst, I32:$value, rc:$size),
96+
(outs), (ins i32imm_op:$idx),
97+
[(wasm_memset (i32 imm:$idx), rc:$dst, I32:$value, rc:$size)],
98+
"", "", 0>,
99+
Requires<[HasBulkMemory]>;
100+
101+
}
102+
103+
defm MEM : BulkMemOps<I32, "32">;
104+
defm MEM : BulkMemOps<I64, "64">;
105+
106+
let hasSideEffects = 1 in
107+
defm DATA_DROP :
108+
BULK_I<(outs), (ins i32imm_op:$seg), (outs), (ins i32imm_op:$seg),
109+
[],
110+
"data.drop\t$seg", "data.drop\t$seg", 0x09>;

llvm/lib/Target/WebAssembly/WebAssemblySelectionDAGInfo.cpp

+14-5
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,13 @@ SDValue WebAssemblySelectionDAGInfo::EmitTargetCodeForMemcpy(
2828

2929
SDValue MemIdx = DAG.getConstant(0, DL, MVT::i32);
3030
auto LenMVT = ST.hasAddr64() ? MVT::i64 : MVT::i32;
31-
return DAG.getNode(WebAssemblyISD::MEMORY_COPY, DL, MVT::Other,
32-
{Chain, MemIdx, MemIdx, Dst, Src,
33-
DAG.getZExtOrTrunc(Size, DL, LenMVT)});
31+
32+
// Use `MEMCPY` here instead of `MEMORY_COPY` because `memory.copy` traps
33+
// if the pointers are invalid even if the length is zero. `MEMCPY` gets
34+
// extra code to handle this in the way that LLVM IR expects.
35+
return DAG.getNode(
36+
WebAssemblyISD::MEMCPY, DL, MVT::Other,
37+
{Chain, MemIdx, MemIdx, Dst, Src, DAG.getZExtOrTrunc(Size, DL, LenMVT)});
3438
}
3539

3640
SDValue WebAssemblySelectionDAGInfo::EmitTargetCodeForMemmove(
@@ -52,8 +56,13 @@ SDValue WebAssemblySelectionDAGInfo::EmitTargetCodeForMemset(
5256

5357
SDValue MemIdx = DAG.getConstant(0, DL, MVT::i32);
5458
auto LenMVT = ST.hasAddr64() ? MVT::i64 : MVT::i32;
59+
60+
// Use `MEMSET` here instead of `MEMORY_FILL` because `memory.fill` traps
61+
// if the pointers are invalid even if the length is zero. `MEMSET` gets
62+
// extra code to handle this in the way that LLVM IR expects.
63+
//
5564
// Only low byte matters for val argument, so anyext the i8
56-
return DAG.getNode(WebAssemblyISD::MEMORY_FILL, DL, MVT::Other, Chain, MemIdx,
57-
Dst, DAG.getAnyExtOrTrunc(Val, DL, MVT::i32),
65+
return DAG.getNode(WebAssemblyISD::MEMSET, DL, MVT::Other, Chain, MemIdx, Dst,
66+
DAG.getAnyExtOrTrunc(Val, DL, MVT::i32),
5867
DAG.getZExtOrTrunc(Size, DL, LenMVT));
5968
}

0 commit comments

Comments
 (0)