diff --git a/benchmarks/benchmark/tasks/v1_0/recommendation/utils.py b/benchmarks/benchmark/tasks/v1_0/recommendation/utils.py index eb12122..42cf2a3 100644 --- a/benchmarks/benchmark/tasks/v1_0/recommendation/utils.py +++ b/benchmarks/benchmark/tasks/v1_0/recommendation/utils.py @@ -7,26 +7,22 @@ from typing import Set, Dict, List, Any -def extract_ids_from_answer(answer: str) -> Set[str]: - """ - Extract all SIDs from answer field - - Args: - answer: String containing multiple <|sid_begin|>...<|sid_end|> patterns - - Returns: - Set of extracted SIDs - - Examples: - >>> extract_ids_from_answer("<|sid_begin|>123<|sid_end|><|sid_begin|>456<|sid_end|>") - {'123', '456'} +def extract_ids_from_answer(answer: str) -> list[str]: + """Extract all SIDs from answer field, preserving original order. + + Returns a deduplicated list that keeps the first occurrence order. + + >>> extract_ids_from_answer("<|sid_begin|>123<|sid_end|><|sid_begin|>456<|sid_end|>") + ['123', '456'] """ - correct_answers = set() + seen: set[str] = set() + correct_answers: list[str] = [] for part in answer.split('<|sid_begin|>'): if '<|sid_end|>' in part: sid = part.split('<|sid_end|>')[0].strip() - if sid: - correct_answers.add(sid) + if sid and sid not in seen: + correct_answers.append(sid) + seen.add(sid) return correct_answers diff --git a/benchmarks/benchmark/tasks/v1_0/recommendation/utils_by_pid.py b/benchmarks/benchmark/tasks/v1_0/recommendation/utils_by_pid.py index f3c56f1..bc06671 100644 --- a/benchmarks/benchmark/tasks/v1_0/recommendation/utils_by_pid.py +++ b/benchmarks/benchmark/tasks/v1_0/recommendation/utils_by_pid.py @@ -129,15 +129,21 @@ def apply_sid_to_pid_strategy(pid_info_list: List[Dict[str, int]], strategy: str raise ValueError(f"Unknown strategy: {strategy}. Must be 'most_popular_originally', 'most_popular_after_downsampling', or 'random'") -def extract_ids_from_answer(answer: List[int]) -> Set[int]: - """ - Extract all PIDs from answer field (metadata["answer_pid"]) or (metadata["answer_iid"]) +def extract_ids_from_answer(answer: list[int]) -> list[int]: + """Extract all PIDs from answer field, preserving original order. - Examples: - >>> extract_ids_from_answer([123, 456, 789]) - {123, 456, 789} + Returns a deduplicated list that keeps the first occurrence order. + + >>> extract_ids_from_answer([123, 456, 123, 789]) + [123, 456, 789] """ - return set([pid for pid in answer if pid != 0]) + seen: set[int] = set() + correct_answers: list[int] = [] + for pid in answer: + if pid != 0 and pid not in seen: + correct_answers.append(pid) + seen.add(pid) + return correct_answers def extract_first_id_from_answer(answer: List[int]) -> int: