Skip to content

Commit 0e2ea17

Browse files
Haoran Jiang authored and Kernel Patches Daemon committed
LoongArch: BPF: Fix tailcall hierarchy
In specific use cases combining tailcalls and BPF-to-BPF calls, MAX_TAIL_CALL_CNT won't work because tail_call_cnt is not back-propagated from callee to caller. This patch fixes this tailcall issue, which is caused by abusing tailcalls in bpf2bpf calls, on LoongArch in the same way as "bpf, x64: Fix tailcall hierarchy". Push tail_call_cnt_ptr and tail_call_cnt onto the stack; tail_call_cnt_ptr is passed between tailcalls and bpf2bpf calls, and is used to increment tail_call_cnt. Signed-off-by: Haoran Jiang <[email protected]>
1 parent cc40b36 commit 0e2ea17

File tree

1 file changed

+71
-41
lines changed

1 file changed

+71
-41
lines changed

arch/loongarch/net/bpf_jit.c

Lines changed: 71 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
#define SAVE_RA BIT(0)
1313
#define SAVE_TCC BIT(1)
1414

15+
#define BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack) (round_up(stack, 16) - 80)
16+
17+
1518
static const int regmap[] = {
1619
/* return value from in-kernel function, and exit value for eBPF program */
1720
[BPF_REG_0] = LOONGARCH_GPR_A5,
@@ -32,32 +35,37 @@ static const int regmap[] = {
3235
[BPF_REG_AX] = LOONGARCH_GPR_T0,
3336
};
3437

35-
static void mark_call(struct jit_ctx *ctx)
38+
static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx, int *store_offset)
3639
{
37-
ctx->flags |= SAVE_RA;
38-
}
40+
const struct bpf_prog *prog = ctx->prog;
41+
const bool is_main_prog = !bpf_is_subprog(prog);
3942

40-
static void mark_tail_call(struct jit_ctx *ctx)
41-
{
42-
ctx->flags |= SAVE_TCC;
43-
}
43+
if (is_main_prog) {
44+
emit_insn(ctx, addid, LOONGARCH_GPR_T3, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
45+
*store_offset -= sizeof(long);
4446

45-
static bool seen_call(struct jit_ctx *ctx)
46-
{
47-
return (ctx->flags & SAVE_RA);
48-
}
47+
emit_tailcall_jmp(ctx, BPF_JGT, REG_TCC, LOONGARCH_GPR_T3, 4);
4948

50-
static bool seen_tail_call(struct jit_ctx *ctx)
51-
{
52-
return (ctx->flags & SAVE_TCC);
53-
}
49+
/* If REG_TCC < MAX_TAIL_CALL_CNT, push REG_TCC into stack */
50+
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
5451

55-
static u8 tail_call_reg(struct jit_ctx *ctx)
56-
{
57-
if (seen_call(ctx))
58-
return TCC_SAVED;
52+
/* Calculate the pointer to REG_TCC in the stack and assign it to REG_TCC */
53+
emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
54+
55+
emit_uncond_jmp(ctx, 2);
56+
57+
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
5958

60-
return REG_TCC;
59+
*store_offset -= sizeof(long);
60+
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
61+
62+
} else {
63+
*store_offset -= sizeof(long);
64+
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
65+
66+
*store_offset -= sizeof(long);
67+
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
68+
}
6169
}
6270

6371
/*
@@ -80,6 +88,10 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
8088
* | $s4 |
8189
* +-------------------------+
8290
* | $s5 |
91+
* +-------------------------+
92+
* | reg_tcc |
93+
* +-------------------------+
94+
* | reg_tcc_ptr |
8395
* +-------------------------+ <--BPF_REG_FP
8496
* | prog->aux->stack_depth |
8597
* | (optional) |
@@ -89,21 +101,24 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
89101
static void build_prologue(struct jit_ctx *ctx)
90102
{
91103
int stack_adjust = 0, store_offset, bpf_stack_adjust;
104+
const struct bpf_prog *prog = ctx->prog;
105+
const bool is_main_prog = !bpf_is_subprog(prog);
92106

93107
bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
94108

95-
/* To store ra, fp, s0, s1, s2, s3, s4 and s5. */
96-
stack_adjust += sizeof(long) * 8;
109+
/* To store ra, fp, s0, s1, s2, s3, s4, s5, reg_tcc and reg_tcc_ptr */
110+
stack_adjust += sizeof(long) * 10;
97111

98112
stack_adjust = round_up(stack_adjust, 16);
99113
stack_adjust += bpf_stack_adjust;
100114

101115
/*
102-
* First instruction initializes the tail call count (TCC).
103-
* On tail call we skip this instruction, and the TCC is
116+
* First instruction initializes the tail call count (TCC) register
117+
* to zero. On tail call we skip this instruction, and the TCC is
104118
* passed in REG_TCC from the caller.
105119
*/
106-
emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
120+
if (is_main_prog)
121+
emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, 0);
107122

108123
emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust);
109124

@@ -131,20 +146,13 @@ static void build_prologue(struct jit_ctx *ctx)
131146
store_offset -= sizeof(long);
132147
emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset);
133148

149+
prepare_bpf_tail_call_cnt(ctx, &store_offset);
150+
134151
emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust);
135152

136153
if (bpf_stack_adjust)
137154
emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust);
138155

139-
/*
140-
* Program contains calls and tail calls, so REG_TCC need
141-
* to be saved across calls.
142-
*/
143-
if (seen_tail_call(ctx) && seen_call(ctx))
144-
move_reg(ctx, TCC_SAVED, REG_TCC);
145-
else
146-
emit_insn(ctx, nop);
147-
148156
ctx->stack_size = stack_adjust;
149157
}
150158

@@ -177,6 +185,17 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
177185
load_offset -= sizeof(long);
178186
emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset);
179187

188+
/*
189+
* When push into the stack, follow the order of tcc then tcc_ptr.
190+
* When pop from the stack, first pop tcc_ptr followed by tcc
191+
*/
192+
load_offset -= 2*sizeof(long);
193+
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
194+
195+
/* pop tcc_ptr to REG_TCC */
196+
load_offset += sizeof(long);
197+
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
198+
180199
emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust);
181200

182201
if (!is_tail_call) {
@@ -211,7 +230,8 @@ bool bpf_jit_supports_far_kfunc_call(void)
211230
static int emit_bpf_tail_call(int insn, struct jit_ctx *ctx)
212231
{
213232
int off;
214-
u8 tcc = tail_call_reg(ctx);
233+
int tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
234+
215235
u8 a1 = LOONGARCH_GPR_A1;
216236
u8 a2 = LOONGARCH_GPR_A2;
217237
u8 t1 = LOONGARCH_GPR_T1;
@@ -239,12 +259,17 @@ static int emit_bpf_tail_call(int insn, struct jit_ctx *ctx)
239259
goto toofar;
240260

241261
/*
242-
* if (--TCC < 0)
243-
* goto out;
262+
* if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT)
263+
* goto out;
244264
*/
245-
emit_insn(ctx, addid, REG_TCC, tcc, -1);
265+
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
266+
emit_insn(ctx, ldd, t3, REG_TCC, 0);
267+
emit_insn(ctx, addid, t3, t3, 1);
268+
emit_insn(ctx, std, t3, REG_TCC, 0);
269+
emit_insn(ctx, addid, t2, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
270+
246271
jmp_offset = tc_ninsn - (ctx->idx - idx0);
247-
if (emit_tailcall_jmp(ctx, BPF_JSLT, REG_TCC, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
272+
if (emit_tailcall_jmp(ctx, BPF_JSGT, t3, t2, jmp_offset) < 0)
248273
goto toofar;
249274

250275
/*
@@ -464,6 +489,7 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
464489
const s16 off = insn->off;
465490
const s32 imm = insn->imm;
466491
const bool is32 = BPF_CLASS(insn->code) == BPF_ALU || BPF_CLASS(insn->code) == BPF_JMP32;
492+
int tcc_ptr_off;
467493

468494
switch (code) {
469495
/* dst = src */
@@ -890,12 +916,17 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
890916

891917
/* function call */
892918
case BPF_JMP | BPF_CALL:
893-
mark_call(ctx);
894919
ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
895920
&func_addr, &func_addr_fixed);
896921
if (ret < 0)
897922
return ret;
898923

924+
if (insn->src_reg == BPF_PSEUDO_CALL) {
925+
tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
926+
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
927+
}
928+
929+
899930
move_addr(ctx, t1, func_addr);
900931
emit_insn(ctx, jirl, LOONGARCH_GPR_RA, t1, 0);
901932

@@ -906,7 +937,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
906937

907938
/* tail call */
908939
case BPF_JMP | BPF_TAIL_CALL:
909-
mark_tail_call(ctx);
910940
if (emit_bpf_tail_call(i, ctx) < 0)
911941
return -EINVAL;
912942
break;

0 commit comments

Comments
 (0)