@@ -139,38 +139,84 @@ def _get_base_type_and_depth(ir_type):
139139 return cur_type , depth
140140
141141
142- def _deref_to_depth (builder , val , target_depth ):
142+ def _deref_to_depth (func , builder , val , target_depth ):
143143 """Dereference a pointer to a certain depth."""
144144
145145 cur_val = val
146- for _ in range (target_depth ):
146+ cur_type = val .type
147+
148+ for depth in range (target_depth ):
147149 if not isinstance (val .type , ir .PointerType ):
148150 logger .error ("Cannot dereference further, non-pointer type" )
149151 return None
150- cur_val = builder .load (cur_val )
152+
153+ # dereference with null check
154+ pointee_type = cur_type .pointee
155+ null_check_block = builder .block
156+ not_null_block = func .append_basic_block (name = f"deref_not_null_{ depth } " )
157+ merge_block = func .append_basic_block (name = f"deref_merge_{ depth } " )
158+
159+ null_ptr = ir .Constant (cur_type , None )
160+ is_not_null = builder .icmp_signed ("!=" , cur_val , null_ptr )
161+ logger .debug (f"Inserted null check for pointer at depth { depth } " )
162+
163+ builder .cbranch (is_not_null , not_null_block , merge_block )
164+
165+ builder .position_at_end (not_null_block )
166+ dereferenced_val = builder .load (cur_val )
167+ logger .debug (f"Dereferenced to depth { depth - 1 } , type: { pointee_type } " )
168+ builder .branch (merge_block )
169+
170+ builder .position_at_end (merge_block )
171+ phi = builder .phi (pointee_type , name = f"deref_result_{ depth } " )
172+
173+ zero_value = (
174+ ir .Constant (pointee_type , 0 )
175+ if isinstance (pointee_type , ir .IntType )
176+ else ir .Constant (pointee_type , None )
177+ )
178+ phi .add_incoming (zero_value , null_check_block )
179+
180+ phi .add_incoming (dereferenced_val , not_null_block )
181+
182+ # Continue with phi result
183+ cur_val = phi
184+ cur_type = pointee_type
151185 return cur_val
152186
153187
154- def _normalize_types (builder , lhs , rhs ):
188+ def _normalize_types (func , builder , lhs , rhs ):
155189 """Normalize types for comparison."""
156190
191+ logger .info (f"Normalizing types: { lhs .type } vs { rhs .type } " )
157192 if isinstance (lhs .type , ir .IntType ) and isinstance (rhs .type , ir .IntType ):
158193 if lhs .type .width < rhs .type .width :
159194 lhs = builder .sext (lhs , rhs .type )
160195 else :
161196 rhs = builder .sext (rhs , lhs .type )
162197 return lhs , rhs
163-
164- logger .error (f"Type mismatch: { lhs .type } vs { rhs .type } " )
165- return None , None
198+ elif not isinstance (lhs .type , ir .PointerType ) and not isinstance (
199+ rhs .type , ir .PointerType
200+ ):
201+ logger .error (f"Type mismatch: { lhs .type } vs { rhs .type } " )
202+ return None , None
203+ else :
204+ lhs_base , lhs_depth = _get_base_type_and_depth (lhs .type )
205+ rhs_base , rhs_depth = _get_base_type_and_depth (rhs .type )
206+ if lhs_base == rhs_base :
207+ if lhs_depth < rhs_depth :
208+ rhs = _deref_to_depth (func , builder , rhs , rhs_depth - lhs_depth )
209+ elif rhs_depth < lhs_depth :
210+ lhs = _deref_to_depth (func , builder , lhs , lhs_depth - rhs_depth )
211+ return _normalize_types (func , builder , lhs , rhs )
166212
167213
168- def _handle_comparator (builder , op , lhs , rhs ):
214+ def _handle_comparator (func , builder , op , lhs , rhs ):
169215 """Handle comparison operations."""
170216
171217 # NOTE: For now assume same types
172218 if lhs .type != rhs .type :
173- lhs , rhs = _normalize_types (builder , lhs , rhs )
219+ lhs , rhs = _normalize_types (func , builder , lhs , rhs )
174220
175221 if lhs is None or rhs is None :
176222 return None
@@ -227,7 +273,7 @@ def _handle_compare(
227273
228274 lhs , _ = lhs
229275 rhs , _ = rhs
230- return _handle_comparator (builder , cond .ops [0 ], lhs , rhs )
276+ return _handle_comparator (func , builder , cond .ops [0 ], lhs , rhs )
231277
232278
233279def convert_to_bool (builder , val ):
0 commit comments