Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/prompts/templates/line_localization.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
I have access to a python code repository in the directory {{ instance.repo_path }}.

You will be provided with an issue description and you will need to act as a code search agent to find the relevant lines in the code files where we need to make changes so that the issue is resolved.
NOTE: You do not need to solve the issue, all you need to do is find the relevant lines of code. Your output will be used to guide another agent to solve the issue.
Use only grep to search the repository, and output your answer as just a list of the relevant line ranges in the code files that are directly related to the issue description, in the format <answer>[(file1, (start_line, end_line)), (file2, (start_line, end_line)), (file2, (start_line, end_line)), ...]</answer>, using the message parameter of the line_result tool.


Issue description:
{{ instance.problem_statement }}
73 changes: 73 additions & 0 deletions src/rewards/line_localization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import ast

def iou_score(predicted_lines, true_lines):
    """Return the line-level intersection-over-union of two span lists.

    Both arguments have the shape
    ``[(file_path, (start_line, end_line)), ...]`` and every span is counted
    as an inclusive line range. Spans belonging to the same file are merged
    before counting, so overlapping or duplicated spans are never counted
    twice.

    Args:
        predicted_lines: Spans proposed by the model.
        true_lines: Ground-truth spans.

    Returns:
        1.0 when both inputs are empty; otherwise the number of lines covered
        by both sides divided by the number of lines covered by either side,
        which is always within ``[0.0, 1.0]``.
    """
    if not predicted_lines and not true_lines:
        return 1.0

    def to_per_file_intervals(spans):
        # Group spans by file path.
        per_file = {}
        for file_path, (start, end) in spans:
            # A reversed span would otherwise contribute a negative length,
            # shrinking the union and letting the score exceed 1.0.
            if start > end:
                start, end = end, start
            per_file.setdefault(file_path, []).append((start, end))
        return per_file

    def merge_intervals(intervals):
        # Collapse overlapping intervals into a sorted, disjoint list.
        if not intervals:
            return []
        intervals = sorted(intervals)
        merged = [intervals[0]]
        for start, end in intervals[1:]:
            last_start, last_end = merged[-1]
            if start <= last_end:
                merged[-1] = (last_start, max(last_end, end))
            else:
                merged.append((start, end))
        return merged

    def total_length(intervals):
        # Inclusive ranges: (3, 3) covers exactly one line.
        return sum(end - start + 1 for start, end in intervals)

    def intersection_length(a, b):
        # Two-pointer sweep over two sorted, disjoint interval lists.
        i = 0
        j = 0
        total = 0
        while i < len(a) and j < len(b):
            s1, e1 = a[i]
            s2, e2 = b[j]
            start = max(s1, s2)
            end = min(e1, e2)
            if start <= end:
                total += end - start + 1
            # Advance whichever interval finishes first; the other one may
            # still overlap the next interval on the opposite side.
            if e1 < e2:
                i += 1
            else:
                j += 1
        return total

    pred_per_file = to_per_file_intervals(predicted_lines)
    true_per_file = to_per_file_intervals(true_lines)

    pred_total = 0
    true_total = 0
    inter_total = 0

    # Files present on only one side add to the union but never to the
    # intersection.
    for file_path in set(pred_per_file) | set(true_per_file):
        pred_intervals = merge_intervals(pred_per_file.get(file_path, []))
        true_intervals = merge_intervals(true_per_file.get(file_path, []))
        pred_total += total_length(pred_intervals)
        true_total += total_length(true_intervals)
        if pred_intervals and true_intervals:
            inter_total += intersection_length(pred_intervals, true_intervals)

    union_total = pred_total + true_total - inter_total
    return inter_total / union_total if union_total > 0 else 0.0

def _parse_predicted_spans(final_message):
    """Return the validated span list found in ``final_message``, or None.

    None means the message has no ``<answer>`` tag, the tagged text is not a
    Python literal, or the literal is not a sequence of
    ``(file_path, (start_line, end_line))`` entries.
    """
    if not isinstance(final_message, str):
        return None

    _, opening_tag, remainder = final_message.partition("<answer>")
    if not opening_tag:
        return None
    # A missing closing tag is tolerated: everything after the opening tag
    # is treated as the answer.
    payload = remainder.partition("</answer>")[0].strip()

    try:
        spans = ast.literal_eval(payload)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return None

    if not isinstance(spans, (list, tuple)):
        return None
    for entry in spans:
        if not isinstance(entry, (list, tuple)) or len(entry) != 2:
            return None
        file_path, line_range = entry
        if not isinstance(file_path, str):
            return None
        if not isinstance(line_range, (list, tuple)) or len(line_range) != 2:
            return None
        for bound in line_range:
            # bool is a subclass of int, but True/False are not line numbers.
            if not isinstance(bound, int) or isinstance(bound, bool):
                return None
    return spans


def line_localization_reward(final_message, instance):
    """Score a model answer against the ground-truth spans of ``instance``.

    Expected answer format:
    <answer>[(file1, (start_line, end_line)), (file2, (start_line, end_line)), (file2, (start_line, end_line)), ...]</answer>

    Args:
        final_message: Final text produced by the model.
        instance: Mapping whose ``"target"`` entry is the literal text of the
            ground-truth span list.

    Returns:
        The value of ``iou_score`` for predicted versus true spans, or 0.0
        when the model answer is missing or malformed. Malformed output is a
        wrong answer, not a reason to abort scoring.

    Raises:
        KeyError, ValueError, SyntaxError: If the ground-truth target is
            missing or unparsable; that is a data problem and is not hidden.
    """
    pred = _parse_predicted_spans(final_message)
    if pred is None:
        return 0.0
    true = ast.literal_eval(instance["target"])
    return iou_score(pred, true)
13 changes: 13 additions & 0 deletions src/tools/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,16 @@ def result(file_paths: list[str]) -> str:
Success message
"""
return "Success"


def line_result(line_paths: list[tuple[str, tuple[int, int]]]) -> str:
    """
    Accept the final list of relevant line ranges.

    Args:
        line_paths: Entries of the form (file_path, (start_line, end_line)),
            one per relevant line range.

    Returns:
        The fixed acknowledgement string "Success".
    """
    # The argument is not inspected here; the call only acknowledges receipt.
    return "Success"
2 changes: 1 addition & 1 deletion src/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def flush_hunk():
if current_file is None or hunk_old_start is None:
return
count = hunk_old_count if hunk_old_count is not None else 1
results.setdefault(current_file, []).append([hunk_old_start, count])
results.setdefault(current_file, []).append([hunk_old_start, hunk_old_start + count])
# Reset hunk state
in_hunk = False
hunk_old_start = None
Expand Down