Skip to content

Commit 6b59980

Browse files
committed
Add null checks for pointer derefs to avoid map_value_or_null verifier errors
1 parent 3f9604a commit 6b59980

File tree

1 file changed

+56
-10
lines changed

1 file changed

+56
-10
lines changed

pythonbpf/expr_pass.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -139,38 +139,84 @@ def _get_base_type_and_depth(ir_type):
139139
return cur_type, depth
140140

141141

142-
def _deref_to_depth(builder, val, target_depth):
142+
def _deref_to_depth(func, builder, val, target_depth):
143143
"""Dereference a pointer to a certain depth."""
144144

145145
cur_val = val
146-
for _ in range(target_depth):
146+
cur_type = val.type
147+
148+
for depth in range(target_depth):
147149
if not isinstance(val.type, ir.PointerType):
148150
logger.error("Cannot dereference further, non-pointer type")
149151
return None
150-
cur_val = builder.load(cur_val)
152+
153+
# dereference with null check
154+
pointee_type = cur_type.pointee
155+
null_check_block = builder.block
156+
not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
157+
merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
158+
159+
null_ptr = ir.Constant(cur_type, None)
160+
is_not_null = builder.icmp_signed("!=", cur_val, null_ptr)
161+
logger.debug(f"Inserted null check for pointer at depth {depth}")
162+
163+
builder.cbranch(is_not_null, not_null_block, merge_block)
164+
165+
builder.position_at_end(not_null_block)
166+
dereferenced_val = builder.load(cur_val)
167+
logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
168+
builder.branch(merge_block)
169+
170+
builder.position_at_end(merge_block)
171+
phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
172+
173+
zero_value = (
174+
ir.Constant(pointee_type, 0)
175+
if isinstance(pointee_type, ir.IntType)
176+
else ir.Constant(pointee_type, None)
177+
)
178+
phi.add_incoming(zero_value, null_check_block)
179+
180+
phi.add_incoming(dereferenced_val, not_null_block)
181+
182+
# Continue with phi result
183+
cur_val = phi
184+
cur_type = pointee_type
151185
return cur_val
152186

153187

154-
def _normalize_types(builder, lhs, rhs):
188+
def _normalize_types(func, builder, lhs, rhs):
155189
"""Normalize types for comparison."""
156190

191+
logger.info(f"Normalizing types: {lhs.type} vs {rhs.type}")
157192
if isinstance(lhs.type, ir.IntType) and isinstance(rhs.type, ir.IntType):
158193
if lhs.type.width < rhs.type.width:
159194
lhs = builder.sext(lhs, rhs.type)
160195
else:
161196
rhs = builder.sext(rhs, lhs.type)
162197
return lhs, rhs
163-
164-
logger.error(f"Type mismatch: {lhs.type} vs {rhs.type}")
165-
return None, None
198+
elif not isinstance(lhs.type, ir.PointerType) and not isinstance(
199+
rhs.type, ir.PointerType
200+
):
201+
logger.error(f"Type mismatch: {lhs.type} vs {rhs.type}")
202+
return None, None
203+
else:
204+
lhs_base, lhs_depth = _get_base_type_and_depth(lhs.type)
205+
rhs_base, rhs_depth = _get_base_type_and_depth(rhs.type)
206+
if lhs_base == rhs_base:
207+
if lhs_depth < rhs_depth:
208+
rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
209+
elif rhs_depth < lhs_depth:
210+
lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
211+
return _normalize_types(func, builder, lhs, rhs)
166212

167213

168-
def _handle_comparator(builder, op, lhs, rhs):
214+
def _handle_comparator(func, builder, op, lhs, rhs):
169215
"""Handle comparison operations."""
170216

171217
# NOTE: For now assume same types
172218
if lhs.type != rhs.type:
173-
lhs, rhs = _normalize_types(builder, lhs, rhs)
219+
lhs, rhs = _normalize_types(func, builder, lhs, rhs)
174220

175221
if lhs is None or rhs is None:
176222
return None
@@ -227,7 +273,7 @@ def _handle_compare(
227273

228274
lhs, _ = lhs
229275
rhs, _ = rhs
230-
return _handle_comparator(builder, cond.ops[0], lhs, rhs)
276+
return _handle_comparator(func, builder, cond.ops[0], lhs, rhs)
231277

232278

233279
def convert_to_bool(builder, val):

0 commit comments

Comments
 (0)