diff --git a/src/ebpf_c_codegen.ml b/src/ebpf_c_codegen.ml index 3f85ff0..ab3148e 100644 --- a/src/ebpf_c_codegen.ml +++ b/src/ebpf_c_codegen.ml @@ -163,6 +163,16 @@ let create_c_context () = { dynptr_backed_pointers = Hashtbl.create 32; } +(** Get the appropriate fallback return value when bpf_tail_call() fails. + bpf_tail_call() is not guaranteed to succeed; when it fails execution + continues past the call site. Every arm that uses a tail call must have + an explicit return so the eBPF verifier can confirm all paths exit. *) +let get_tail_call_fallback_return ctx = + match ctx.current_function_context_type with + | Some "xdp" -> "XDP_PASS" + | Some "tc" -> "TC_ACT_OK" + | _ -> "0" + (** Helper functions for code generation *) (** Calculate the size of a type for dynptr field assignment operations. @@ -2200,13 +2210,19 @@ and generate_c_instruction ctx ir_instr = let args_str = String.concat ", " (List.map (generate_c_value ctx) args) in emit_line ctx (sprintf "/* Tail call to %s */" func_name); emit_line ctx (sprintf "bpf_tail_call(ctx, &prog_array, 0); /* %s(%s) */" func_name args_str); - emit_line ctx "/* If tail call fails, continue execution */" + (* Fallback return: bpf_tail_call() may fail; verifier requires all + branches to have an explicit return. *) + let fallback = get_tail_call_fallback_return ctx in + emit_line ctx (sprintf "return %s; /* tail call fallback */" fallback) | IRReturnTailCall (func_name, args, index) -> (* Generate explicit tail call *) let args_str = String.concat ", " (List.map (generate_c_value ctx) args) in emit_line ctx (sprintf "/* Tail call to %s (index %d) */" func_name index); emit_line ctx (sprintf "bpf_tail_call(ctx, &prog_array, %d); /* %s(%s) */" index func_name args_str); - emit_line ctx "/* If tail call fails, continue execution */"); + (* Fallback return: bpf_tail_call() may fail; verifier requires all + branches to have an explicit return. *) + let fallback = get_tail_call_fallback_return ctx in + emit_line ctx (sprintf "return %s; /* tail call fallback */" fallback)); decrease_indent ctx | IRDefaultPattern -> @@ -2223,13 +2239,19 @@ and generate_c_instruction ctx ir_instr = let args_str = String.concat ", " (List.map (generate_c_value ctx) args) in emit_line ctx (sprintf "/* Tail call to %s */" func_name); emit_line ctx (sprintf "bpf_tail_call(ctx, &prog_array, 0); /* %s(%s) */" func_name args_str); - emit_line ctx "/* If tail call fails, continue execution */" + (* Fallback return: bpf_tail_call() may fail; verifier requires all + branches to have an explicit return. *) + let fallback = get_tail_call_fallback_return ctx in + emit_line ctx (sprintf "return %s; /* tail call fallback */" fallback) | IRReturnTailCall (func_name, args, index) -> (* Generate explicit tail call *) let args_str = String.concat ", " (List.map (generate_c_value ctx) args) in emit_line ctx (sprintf "/* Tail call to %s (index %d) */" func_name index); emit_line ctx (sprintf "bpf_tail_call(ctx, &prog_array, %d); /* %s(%s) */" index func_name args_str); - emit_line ctx "/* If tail call fails, continue execution */"); + (* Fallback return: bpf_tail_call() may fail; verifier requires all + branches to have an explicit return. *) + let fallback = get_tail_call_fallback_return ctx in + emit_line ctx (sprintf "return %s; /* tail call fallback */" fallback)); decrease_indent ctx; emit_line ctx "}" diff --git a/tests/test_ebpf_c_codegen.ml b/tests/test_ebpf_c_codegen.ml index bb2e66c..d4df041 100644 --- a/tests/test_ebpf_c_codegen.ml +++ b/tests/test_ebpf_c_codegen.ml @@ -1146,6 +1146,120 @@ let test_global_map_redefinition_fix () = () +(** Tests for tail call fallback return fix. + These tests verify that every match arm containing a tail call emits an + explicit fallback return statement so the eBPF verifier can confirm that + all code paths terminate, even when bpf_tail_call() fails at runtime. *) + +(** Unit test: IRReturnTailCall in a constant match arm generates a fallback + return statement with XDP_PASS when the function context is "xdp". *) +let test_tail_call_fallback_constant_arm_xdp () = + let ctx = create_c_context () in + ctx.current_function_context_type <- Some "xdp"; + + let matched_val = make_ir_value (IRVariable "protocol") IRU32 test_pos in + let ctx_arg = make_ir_value (IRVariable "ctx") + (IRPointer (IRStruct ("xdp_md", []), make_bounds_info ())) test_pos in + let arms = [ + { match_pattern = IRConstantPattern + (make_ir_value (IRLiteral (IntLit (Signed64 6L, None))) IRU32 test_pos); + return_action = IRReturnTailCall ("tcp_handler", [ctx_arg], 1); + arm_pos = test_pos }; + { match_pattern = IRDefaultPattern; + return_action = IRReturnValue + (make_ir_value (IRLiteral (IntLit (Signed64 0L, None))) IRU32 test_pos); + arm_pos = test_pos }; + ] in + let instr = make_ir_instruction (IRMatchReturn (matched_val, arms)) test_pos in + generate_c_instruction ctx instr; + + let output = String.concat "\n" ctx.output_lines in + check bool "bpf_tail_call emitted for constant arm" + true (contains_substr output "bpf_tail_call(ctx, &prog_array, 1)"); + check bool "XDP_PASS fallback return present" + true (contains_substr output "return XDP_PASS; /* tail call fallback */"); + check bool "old continue-execution comment absent" + false (contains_substr output "If tail call fails, continue execution"); + () + +(** Unit test: IRReturnTailCall in a default match arm also generates a + fallback return, and TC context gives TC_ACT_OK. *) +let test_tail_call_fallback_default_arm_tc () = + let ctx = create_c_context () in + ctx.current_function_context_type <- Some "tc"; + + let matched_val = make_ir_value (IRVariable "proto") IRU32 test_pos in + let ctx_arg = make_ir_value (IRVariable "ctx") + (IRPointer (IRStruct ("__sk_buff", []), make_bounds_info ())) test_pos in + let arms = [ + { match_pattern = IRConstantPattern + (make_ir_value (IRLiteral (IntLit (Signed64 6L, None))) IRU32 test_pos); + return_action = IRReturnValue + (make_ir_value (IRLiteral (IntLit (Signed64 0L, None))) IRU32 test_pos); + arm_pos = test_pos }; + { match_pattern = IRDefaultPattern; + return_action = IRReturnTailCall ("default_tc_handler", [ctx_arg], 2); + arm_pos = test_pos }; + ] in + let instr = make_ir_instruction (IRMatchReturn (matched_val, arms)) test_pos in + generate_c_instruction ctx instr; + + let output = String.concat "\n" ctx.output_lines in + check bool "bpf_tail_call emitted for default arm" + true (contains_substr output "bpf_tail_call(ctx, &prog_array, 2)"); + check bool "TC_ACT_OK fallback return present for TC context" + true (contains_substr output "return TC_ACT_OK; /* tail call fallback */"); + check bool "XDP_PASS fallback NOT present in TC context" + false (contains_substr output "return XDP_PASS;"); + check bool "old continue-execution comment absent" + false (contains_substr output "If tail call fails, continue execution"); + () + +(** Unit test: IRReturnCall (implicit tail call with index 0) in both constant + and default arms generates fallback returns. Generic context (None) uses + "return 0" as the fallback. *) +let test_return_call_fallback_generic_context () = + let ctx = create_c_context () in + (* current_function_context_type left as None -> generic fallback "0" *) + + let matched_val = make_ir_value (IRVariable "key") IRU32 test_pos in + let ctx_arg = make_ir_value (IRVariable "ctx") + (IRPointer (IRStruct ("generic_ctx", []), make_bounds_info ())) test_pos in + let arms = [ + { match_pattern = IRConstantPattern + (make_ir_value (IRLiteral (IntLit (Signed64 1L, None))) IRU32 test_pos); + return_action = IRReturnCall ("handler_one", [ctx_arg]); + arm_pos = test_pos }; + { match_pattern = IRDefaultPattern; + return_action = IRReturnCall ("handler_default", [ctx_arg]); + arm_pos = test_pos }; + ] in + let instr = make_ir_instruction (IRMatchReturn (matched_val, arms)) test_pos in + generate_c_instruction ctx instr; + + let output = String.concat "\n" ctx.output_lines in + (* Both arms use IRReturnCall which maps to index 0 *) + let bpf_calls = ref 0 in + let search_start = ref 0 in + (try while true do + let pos = Str.search_forward (Str.regexp_string "bpf_tail_call(ctx, &prog_array, 0)") output !search_start in + incr bpf_calls; + search_start := pos + 1 + done with Not_found -> ()); + check bool "bpf_tail_call emitted in constant arm (IRReturnCall)" + true (!bpf_calls >= 1); + check bool "bpf_tail_call emitted in default arm (IRReturnCall)" + true (!bpf_calls >= 2); + check bool "generic fallback return 0 present" + true (contains_substr output "return 0; /* tail call fallback */"); + check bool "XDP_PASS fallback NOT present for generic context" + false (contains_substr output "return XDP_PASS;"); + check bool "TC_ACT_OK fallback NOT present for generic context" + false (contains_substr output "return TC_ACT_OK;"); + check bool "old continue-execution comment absent" + false (contains_substr output "If tail call fails, continue execution"); + () + (** Test suite definition *) let suite = [ @@ -1191,6 +1305,10 @@ let suite = ("eBPF function generation bug fix", `Quick, test_ebpf_function_generation_bug_fix); (* Test to prevent global variable map redefinition regression *) ("Global map redefinition fix", `Quick, test_global_map_redefinition_fix); + (* Tail call fallback return fix - verifier requires explicit return after bpf_tail_call() *) + ("Tail call fallback: constant arm XDP context", `Quick, test_tail_call_fallback_constant_arm_xdp); + ("Tail call fallback: default arm TC context", `Quick, test_tail_call_fallback_default_arm_tc); + ("Tail call fallback: IRReturnCall generic context", `Quick, test_return_call_fallback_generic_context); ] (** Run all tests *)