#include "bpf_jit.h"

#define REG_TCC		LOONGARCH_GPR_A6
- #define TCC_SAVED	LOONGARCH_GPR_S5
-
- #define SAVE_RA	BIT(0)
- #define SAVE_TCC	BIT(1)
+ #define BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack)	(round_up(stack, 16) - 80)
+
static const int regmap[] = {
	/* return value from in-kernel function, and exit value for eBPF program */
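The new macro computes, relative to the stack pointer, the location of the tcc_ptr save slot. As the frame layout later in this patch shows, the prologue now saves ten 8-byte registers downward from the frame top (ra, fp, s0-s5, reg_tcc, reg_tcc_ptr), so the tcc_ptr slot is the lowest one, 80 bytes below the 16-byte-aligned top of the save area. A standalone sanity check of the arithmetic (a sketch; this round_up is a plain stand-in for the kernel macro):

#include <assert.h>

/* Stand-in for the kernel's round_up(); fine for positive values. */
#define round_up(x, a)	((((x) + (a) - 1) / (a)) * (a))
#define BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack)	(round_up(stack, 16) - 80)

int main(void)
{
	/*
	 * Example frame: 80 bytes of register saves plus a 64-byte BPF
	 * stack gives stack_size = 144; the tcc_ptr slot then sits at
	 * SP + 64, exactly 80 bytes below the frame top.
	 */
	assert(BPF_TAIL_CALL_CNT_PTR_STACK_OFF(144) == 64);
	return 0;
}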
@@ -32,32 +31,37 @@ static const int regmap[] = {
	[BPF_REG_AX] = LOONGARCH_GPR_T0,
};

- static void mark_call(struct jit_ctx *ctx)
+ static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx, int *store_offset)
{
-	ctx->flags |= SAVE_RA;
- }
+	const struct bpf_prog *prog = ctx->prog;
+	const bool is_main_prog = !bpf_is_subprog(prog);

- static void mark_tail_call(struct jit_ctx *ctx)
- {
-	ctx->flags |= SAVE_TCC;
- }
+	if (is_main_prog) {
+		emit_insn(ctx, addid, LOONGARCH_GPR_T3, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
+		*store_offset -= sizeof(long);

- static bool seen_call(struct jit_ctx *ctx)
- {
-	return (ctx->flags & SAVE_RA);
- }
+		emit_tailcall_jmp(ctx, BPF_JGT, REG_TCC, LOONGARCH_GPR_T3, 4);

- static bool seen_tail_call(struct jit_ctx *ctx)
- {
-	return (ctx->flags & SAVE_TCC);
- }
+		/* If REG_TCC <= MAX_TAIL_CALL_CNT, push REG_TCC onto the stack */
+		emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);

- static u8 tail_call_reg(struct jit_ctx *ctx)
- {
-	if (seen_call(ctx))
-		return TCC_SAVED;
+		/* Calculate the pointer to REG_TCC in the stack and assign it to REG_TCC */
+		emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
+
+		emit_uncond_jmp(ctx, 2);
+
+		emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);

-	return REG_TCC;
+		*store_offset -= sizeof(long);
+		emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
+
+	} else {
+		*store_offset -= sizeof(long);
+		emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
+
+		*store_offset -= sizeof(long);
+		emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
+	}
}
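The sequence emitted above gives REG_TCC a double role: a main program entered normally carries the count itself (0 at entry), while entry via tail call, or into a subprogram, carries a pointer to the count; a kernel stack address is always greater than MAX_TAIL_CALL_CNT, which is what the BPF_JGT dispatch relies on. At run time the main-program path behaves roughly like this C sketch (slot and parameter names are illustrative, not taken from the patch):

#include <stddef.h>

#define MAX_TAIL_CALL_CNT 33	/* as in include/linux/bpf.h */

/*
 * Run-time effect of the main-program sequence above; sp is the new
 * frame's stack pointer, tcc_off/tcc_ptr_off the two save slots.
 */
static unsigned long prepare_tcc_sketch(char *sp, ptrdiff_t tcc_off,
					ptrdiff_t tcc_ptr_off,
					unsigned long reg_tcc)
{
	if (reg_tcc <= MAX_TAIL_CALL_CNT) {
		/* Entered with a plain count: spill it ... */
		*(unsigned long *)(sp + tcc_off) = reg_tcc;
		/* ... and carry on with its address instead. */
		reg_tcc = (unsigned long)(sp + tcc_off);
	} else {
		/* Entered via tail call: REG_TCC already holds a pointer. */
		*(unsigned long *)(sp + tcc_off) = reg_tcc;
	}
	/* Either way, the tcc_ptr slot points at the live counter. */
	*(unsigned long *)(sp + tcc_ptr_off) = reg_tcc;
	return reg_tcc;
}

A subprogram always receives a pointer, so its branch simply stores REG_TCC into both slots unchanged.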
/*
@@ -80,6 +84,10 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
 * | $s4 |
 * +-------------------------+
 * | $s5 |
+ * +-------------------------+
+ * | reg_tcc |
+ * +-------------------------+
+ * | reg_tcc_ptr |
 * +-------------------------+ <--BPF_REG_FP
 * | prog->aux->stack_depth |
 * | (optional) |
@@ -89,21 +97,24 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
static void build_prologue(struct jit_ctx *ctx)
{
	int stack_adjust = 0, store_offset, bpf_stack_adjust;
+	const struct bpf_prog *prog = ctx->prog;
+	const bool is_main_prog = !bpf_is_subprog(prog);

	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);

-	/* To store ra, fp, s0, s1, s2, s3, s4 and s5. */
-	stack_adjust += sizeof(long) * 8;
+	/* To store ra, fp, s0, s1, s2, s3, s4, s5, reg_tcc and reg_tcc_ptr */
+	stack_adjust += sizeof(long) * 10;

	stack_adjust = round_up(stack_adjust, 16);
	stack_adjust += bpf_stack_adjust;

	/*
-	 * First instruction initializes the tail call count (TCC).
-	 * On tail call we skip this instruction, and the TCC is
+	 * First instruction initializes the tail call count (TCC) register
+	 * to zero. On tail call we skip this instruction, and the TCC is
	 * passed in REG_TCC from the caller.
	 */
-	emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
+	if (is_main_prog)
+		emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, 0);

	emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust);
@@ -131,20 +142,13 @@ static void build_prologue(struct jit_ctx *ctx)
	store_offset -= sizeof(long);
	emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset);

+	prepare_bpf_tail_call_cnt(ctx, &store_offset);
+
	emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust);

	if (bpf_stack_adjust)
		emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust);

-	/*
-	 * Program contains calls and tail calls, so REG_TCC need
-	 * to be saved across calls.
-	 */
-	if (seen_tail_call(ctx) && seen_call(ctx))
-		move_reg(ctx, TCC_SAVED, REG_TCC);
-	else
-		emit_insn(ctx, nop);
-
	ctx->stack_size = stack_adjust;
}
@@ -177,6 +181,17 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
	load_offset -= sizeof(long);
	emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset);

+	/*
+	 * When pushing onto the stack, the order is tcc then tcc_ptr.
+	 * When popping from the stack, pop tcc_ptr first, followed by tcc.
+	 */
+	load_offset -= 2 * sizeof(long);
+	emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
+
+	/* pop tcc to REG_TCC */
+	load_offset += sizeof(long);
+	emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
+
	emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust);

	if (!is_tail_call) {
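Net effect of the two pops: REG_TCC leaves the epilogue holding the value of the tcc slot, i.e. the live count if this frame re-homed it, or the caller's pointer if one was passed through, so on a tail call the target's prologue can again tell the two cases apart by magnitude. A sketch of the access pattern (offsets relative to the $s5 slot, per the layout above; names illustrative):

#include <stddef.h>

/*
 * tcc was pushed first (higher address), tcc_ptr below it.  The first
 * load mirrors the emitted ldd and is immediately overwritten.
 */
static unsigned long epilogue_pop_sketch(char *sp, ptrdiff_t s5_off)
{
	unsigned long reg_tcc;

	reg_tcc = *(unsigned long *)(sp + s5_off - 2 * sizeof(long));	/* tcc_ptr */
	reg_tcc = *(unsigned long *)(sp + s5_off - sizeof(long));	/* tcc */
	return reg_tcc;	/* the live count, or the caller's pointer */
}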
@@ -211,7 +226,7 @@ bool bpf_jit_supports_far_kfunc_call(void)
static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
{
	int off;
-	u8 tcc = tail_call_reg(ctx);
+	int tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
	u8 a1 = LOONGARCH_GPR_A1;
	u8 a2 = LOONGARCH_GPR_A2;
	u8 t1 = LOONGARCH_GPR_T1;
@@ -240,11 +255,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
		goto toofar;

	/*
-	 * if (--TCC < 0)
-	 *	goto out;
+	 * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT)
+	 *	goto out;
	 */
-	emit_insn(ctx, addid, REG_TCC, tcc, -1);
-	if (emit_tailcall_jmp(ctx, BPF_JSLT, REG_TCC, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
+	emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
+	emit_insn(ctx, ldd, t3, REG_TCC, 0);
+	emit_insn(ctx, addid, t3, t3, 1);
+	emit_insn(ctx, std, t3, REG_TCC, 0);
+	emit_insn(ctx, addid, t2, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
+	if (emit_tailcall_jmp(ctx, BPF_JSGT, t3, t2, jmp_offset) < 0)
		goto toofar;
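The five instructions replacing the old in-register decrement implement the comment's pseudocode through the saved pointer. Roughly, in C (a sketch, not kernel code; MAX_TAIL_CALL_CNT as defined in include/linux/bpf.h):

#include <stdbool.h>

#define MAX_TAIL_CALL_CNT 33

/* Returns true when the budget is exhausted and we must "goto out". */
static bool tail_call_limit_hit(char *sp, long tcc_ptr_off)
{
	long *tcc_ptr = *(long **)(sp + tcc_ptr_off);	/* ldd REG_TCC, sp, tcc_ptr_off */
	long t3 = *tcc_ptr + 1;				/* ldd t3, REG_TCC, 0; addid t3, t3, 1 */

	*tcc_ptr = t3;					/* std t3, REG_TCC, 0 */
	return t3 > MAX_TAIL_CALL_CNT;			/* addid t2, zero, MAX; BPF_JSGT t3, t2 */
}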

	/*
@@ -465,6 +484,7 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
	const s16 off = insn->off;
	const s32 imm = insn->imm;
	const bool is32 = BPF_CLASS(insn->code) == BPF_ALU || BPF_CLASS(insn->code) == BPF_JMP32;
+	int tcc_ptr_off;

	switch (code) {
	/* dst = src */
@@ -891,12 +911,17 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext

	/* function call */
	case BPF_JMP | BPF_CALL:
-		mark_call(ctx);
		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
					    &func_addr, &func_addr_fixed);
		if (ret < 0)
			return ret;

+		if (insn->src_reg == BPF_PSEUDO_CALL) {
+			tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
+			emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
+		}
+
		move_addr(ctx, t1, func_addr);
		emit_insn(ctx, jirl, LOONGARCH_GPR_RA, t1, 0);
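REG_TCC lives in $a6, a caller-saved argument register, so its contents do not survive a call. Before a BPF-to-BPF (BPF_PSEUDO_CALL) call the JIT therefore reloads the saved tcc_ptr into REG_TCC; the callee's prologue stores that same pointer again, so every program in the call chain bumps one shared counter. A hypothetical libbpf-style program exercising the mix this supports (illustrative only, not from the patch):

#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

struct {
	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
	__uint(max_entries, 1);
	__uint(key_size, sizeof(__u32));
	__uint(value_size, sizeof(__u32));
} jmp_table SEC(".maps");

static __noinline int do_jump(struct __sk_buff *skb)
{
	/* The JIT reloads REG_TCC with tcc_ptr right before calling this
	 * subprogram, so the tail call below bumps the caller's counter. */
	bpf_tail_call(skb, &jmp_table, 0);
	return 0;
}

SEC("tc")
int entry(struct __sk_buff *skb)
{
	return do_jump(skb);
}

char _license[] SEC("license") = "GPL";

Here do_jump() is a true subprogram thanks to __noinline, and its bpf_tail_call() draws on entry()'s tail-call budget via the propagated pointer.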
@@ -907,7 +932,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext

	/* tail call */
	case BPF_JMP | BPF_TAIL_CALL:
-		mark_tail_call(ctx);
		if (emit_bpf_tail_call(ctx, i) < 0)
			return -EINVAL;
		break;