Skip to content

Commit 4a58686

Browse files
Haoran Jiang authored and Kernel Patches Daemon committed
LoongArch: BPF: Fix tailcall hierarchy
In specific use cases combining tailcalls and BPF-to-BPF calls, MAX_TAIL_CALL_CNT won't work because of missing tail_call_cnt back-propagation from callee to caller. This patch fixes this tailcall issue, caused by abusing the tailcall in bpf2bpf feature, on LoongArch in the same way as "bpf, x64: Fix tailcall hierarchy". It pushes tail_call_cnt_ptr and tail_call_cnt onto the stack; tail_call_cnt_ptr is passed between tailcall and bpf2bpf, and is used to increment tail_call_cnt. Fixes: bb035ef ("LoongArch: BPF: Support mixing bpf2bpf and tailcalls") Fixes: 5dc6155 ("LoongArch: Add BPF JIT support") Signed-off-by: Haoran Jiang <[email protected]>
1 parent 1363106 commit 4a58686

File tree

1 file changed

+68
-44
lines changed

1 file changed

+68
-44
lines changed

arch/loongarch/net/bpf_jit.c

Lines changed: 68 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,9 @@
77
#include "bpf_jit.h"
88

99
#define REG_TCC LOONGARCH_GPR_A6
10-
#define TCC_SAVED LOONGARCH_GPR_S5
1110

12-
#define SAVE_RA BIT(0)
13-
#define SAVE_TCC BIT(1)
11+
#define BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack) (round_up(stack, 16) - 80)
12+
1413

1514
static const int regmap[] = {
1615
/* return value from in-kernel function, and exit value for eBPF program */
@@ -32,32 +31,37 @@ static const int regmap[] = {
3231
[BPF_REG_AX] = LOONGARCH_GPR_T0,
3332
};
3433

35-
static void mark_call(struct jit_ctx *ctx)
34+
static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx, int *store_offset)
3635
{
37-
ctx->flags |= SAVE_RA;
38-
}
36+
const struct bpf_prog *prog = ctx->prog;
37+
const bool is_main_prog = !bpf_is_subprog(prog);
3938

40-
static void mark_tail_call(struct jit_ctx *ctx)
41-
{
42-
ctx->flags |= SAVE_TCC;
43-
}
39+
if (is_main_prog) {
40+
emit_insn(ctx, addid, LOONGARCH_GPR_T3, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
41+
*store_offset -= sizeof(long);
4442

45-
static bool seen_call(struct jit_ctx *ctx)
46-
{
47-
return (ctx->flags & SAVE_RA);
48-
}
43+
emit_tailcall_jmp(ctx, BPF_JGT, REG_TCC, LOONGARCH_GPR_T3, 4);
4944

50-
static bool seen_tail_call(struct jit_ctx *ctx)
51-
{
52-
return (ctx->flags & SAVE_TCC);
53-
}
45+
/* If REG_TCC < MAX_TAIL_CALL_CNT, push REG_TCC into stack */
46+
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
5447

55-
static u8 tail_call_reg(struct jit_ctx *ctx)
56-
{
57-
if (seen_call(ctx))
58-
return TCC_SAVED;
48+
/* Calculate the pointer to REG_TCC in the stack and assign it to REG_TCC */
49+
emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
50+
51+
emit_uncond_jmp(ctx, 2);
52+
53+
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
5954

60-
return REG_TCC;
55+
*store_offset -= sizeof(long);
56+
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
57+
58+
} else {
59+
*store_offset -= sizeof(long);
60+
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
61+
62+
*store_offset -= sizeof(long);
63+
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
64+
}
6165
}
6266

6367
/*
@@ -80,6 +84,10 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
8084
* | $s4 |
8185
* +-------------------------+
8286
* | $s5 |
87+
* +-------------------------+
88+
* | reg_tcc |
89+
* +-------------------------+
90+
* | reg_tcc_ptr |
8391
* +-------------------------+ <--BPF_REG_FP
8492
* | prog->aux->stack_depth |
8593
* | (optional) |
@@ -89,21 +97,24 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
8997
static void build_prologue(struct jit_ctx *ctx)
9098
{
9199
int stack_adjust = 0, store_offset, bpf_stack_adjust;
100+
const struct bpf_prog *prog = ctx->prog;
101+
const bool is_main_prog = !bpf_is_subprog(prog);
92102

93103
bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
94104

95-
/* To store ra, fp, s0, s1, s2, s3, s4 and s5. */
96-
stack_adjust += sizeof(long) * 8;
105+
/* To store ra, fp, s0, s1, s2, s3, s4, s5, reg_tcc and reg_tcc_ptr */
106+
stack_adjust += sizeof(long) * 10;
97107

98108
stack_adjust = round_up(stack_adjust, 16);
99109
stack_adjust += bpf_stack_adjust;
100110

101111
/*
102-
* First instruction initializes the tail call count (TCC).
103-
* On tail call we skip this instruction, and the TCC is
112+
* First instruction initializes the tail call count (TCC) register
113+
* to zero. On tail call we skip this instruction, and the TCC is
104114
* passed in REG_TCC from the caller.
105115
*/
106-
emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
116+
if (is_main_prog)
117+
emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, 0);
107118

108119
emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust);
109120

@@ -131,20 +142,13 @@ static void build_prologue(struct jit_ctx *ctx)
131142
store_offset -= sizeof(long);
132143
emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset);
133144

145+
prepare_bpf_tail_call_cnt(ctx, &store_offset);
146+
134147
emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust);
135148

136149
if (bpf_stack_adjust)
137150
emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust);
138151

139-
/*
140-
* Program contains calls and tail calls, so REG_TCC need
141-
* to be saved across calls.
142-
*/
143-
if (seen_tail_call(ctx) && seen_call(ctx))
144-
move_reg(ctx, TCC_SAVED, REG_TCC);
145-
else
146-
emit_insn(ctx, nop);
147-
148152
ctx->stack_size = stack_adjust;
149153
}
150154

@@ -177,6 +181,17 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
177181
load_offset -= sizeof(long);
178182
emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset);
179183

184+
/*
185+
* When push into the stack, follow the order of tcc then tcc_ptr.
186+
* When pop from the stack, first pop tcc_ptr followed by tcc
187+
*/
188+
load_offset -= 2*sizeof(long);
189+
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
190+
191+
/* pop tcc_ptr to REG_TCC */
192+
load_offset += sizeof(long);
193+
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
194+
180195
emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust);
181196

182197
if (!is_tail_call) {
@@ -211,7 +226,7 @@ bool bpf_jit_supports_far_kfunc_call(void)
211226
static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
212227
{
213228
int off;
214-
u8 tcc = tail_call_reg(ctx);
229+
int tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
215230
u8 a1 = LOONGARCH_GPR_A1;
216231
u8 a2 = LOONGARCH_GPR_A2;
217232
u8 t1 = LOONGARCH_GPR_T1;
@@ -240,11 +255,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
240255
goto toofar;
241256

242257
/*
243-
* if (--TCC < 0)
244-
* goto out;
258+
* if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT)
259+
* goto out;
245260
*/
246-
emit_insn(ctx, addid, REG_TCC, tcc, -1);
247-
if (emit_tailcall_jmp(ctx, BPF_JSLT, REG_TCC, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
261+
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
262+
emit_insn(ctx, ldd, t3, REG_TCC, 0);
263+
emit_insn(ctx, addid, t3, t3, 1);
264+
emit_insn(ctx, std, t3, REG_TCC, 0);
265+
emit_insn(ctx, addid, t2, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
266+
if (emit_tailcall_jmp(ctx, BPF_JSGT, t3, t2, jmp_offset) < 0)
248267
goto toofar;
249268

250269
/*
@@ -465,6 +484,7 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
465484
const s16 off = insn->off;
466485
const s32 imm = insn->imm;
467486
const bool is32 = BPF_CLASS(insn->code) == BPF_ALU || BPF_CLASS(insn->code) == BPF_JMP32;
487+
int tcc_ptr_off;
468488

469489
switch (code) {
470490
/* dst = src */
@@ -891,12 +911,17 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
891911

892912
/* function call */
893913
case BPF_JMP | BPF_CALL:
894-
mark_call(ctx);
895914
ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
896915
&func_addr, &func_addr_fixed);
897916
if (ret < 0)
898917
return ret;
899918

919+
if (insn->src_reg == BPF_PSEUDO_CALL) {
920+
tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
921+
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
922+
}
923+
924+
900925
move_addr(ctx, t1, func_addr);
901926
emit_insn(ctx, jirl, LOONGARCH_GPR_RA, t1, 0);
902927

@@ -907,7 +932,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
907932

908933
/* tail call */
909934
case BPF_JMP | BPF_TAIL_CALL:
910-
mark_tail_call(ctx);
911935
if (emit_bpf_tail_call(ctx, i) < 0)
912936
return -EINVAL;
913937
break;

0 commit comments

Comments
 (0)