Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions benchmarks/benchmark/tasks/v1_0/recommendation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,22 @@
from typing import Set, Dict, List, Any


def extract_ids_from_answer(answer: str) -> Set[str]:
"""
Extract all SIDs from answer field

Args:
answer: String containing multiple <|sid_begin|>...<|sid_end|> patterns

Returns:
Set of extracted SIDs

Examples:
>>> extract_ids_from_answer("<|sid_begin|>123<|sid_end|><|sid_begin|>456<|sid_end|>")
{'123', '456'}
def extract_ids_from_answer(answer: str) -> list[str]:
"""Extract all SIDs from answer field, preserving original order.

Returns a deduplicated list that keeps the first occurrence order.

>>> extract_ids_from_answer("<|sid_begin|>123<|sid_end|><|sid_begin|>456<|sid_end|>")
['123', '456']
"""
correct_answers = set()
seen: set[str] = set()
correct_answers: list[str] = []
for part in answer.split('<|sid_begin|>'):
if '<|sid_end|>' in part:
sid = part.split('<|sid_end|>')[0].strip()
if sid:
correct_answers.add(sid)
if sid and sid not in seen:
correct_answers.append(sid)
seen.add(sid)
return correct_answers


Expand Down
20 changes: 13 additions & 7 deletions benchmarks/benchmark/tasks/v1_0/recommendation/utils_by_pid.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,21 @@ def apply_sid_to_pid_strategy(pid_info_list: List[Dict[str, int]], strategy: str
raise ValueError(f"Unknown strategy: {strategy}. Must be 'most_popular_originally', 'most_popular_after_downsampling', or 'random'")


def extract_ids_from_answer(answer: List[int]) -> Set[int]:
"""
Extract all PIDs from answer field (metadata["answer_pid"]) or (metadata["answer_iid"])
def extract_ids_from_answer(answer: list[int]) -> list[int]:
"""Extract all PIDs from answer field, preserving original order.

Examples:
>>> extract_ids_from_answer([123, 456, 789])
{123, 456, 789}
Returns a deduplicated list that keeps the first occurrence order.

>>> extract_ids_from_answer([123, 456, 123, 789])
[123, 456, 789]
"""
return set([pid for pid in answer if pid != 0])
seen: set[int] = set()
correct_answers: list[int] = []
for pid in answer:
if pid != 0 and pid not in seen:
correct_answers.append(pid)
seen.add(pid)
return correct_answers


def extract_first_id_from_answer(answer: List[int]) -> int:
Expand Down