Skip to content

Commit 27c9f55

Browse files
committed
updates mock tokenizer
1 parent fede3e7 commit 27c9f55

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

tests/unit/test_data_process.py

Lines changed: 22 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,9 @@ def _mock_apply_chat_template(
6767
messages: t.List[Message],
6868
tokenize: bool = True,
6969
add_special_tokens: bool = True,
70-
) -> t.Union[str, t.List[int]]:
70+
return_dict: bool = False,
71+
**kwargs,
72+
) -> t.Union[str, t.List[int], t.Dict[str, t.Any]]:
7173
"""Mock implementation of apply_chat_template."""
7274
template_tokens = []
7375

@@ -91,10 +93,14 @@ def _mock_apply_chat_template(
9193
]
9294
template_tokens.extend(reasoning_tokens)
9395

94-
if tokenize:
95-
return template_tokens
96-
else:
97-
return " ".join([f"token_{t}" for t in template_tokens])
96+
result = (
97+
template_tokens
98+
if tokenize
99+
else " ".join([f"token_{t}" for t in template_tokens])
100+
)
101+
if return_dict:
102+
return {"input_ids": result}
103+
return result
98104

99105
def test_single_turn_assistant_only_content(self):
100106
"""Test basic single-turn conversation with assistant content only."""
@@ -555,7 +561,9 @@ def _mock_apply_chat_template(
555561
messages: t.List[Message],
556562
tokenize: bool = True,
557563
add_special_tokens: bool = True,
558-
) -> t.Union[str, t.List[int]]:
564+
return_dict: bool = False,
565+
**kwargs,
566+
) -> t.Union[str, t.List[int], t.Dict[str, t.Any]]:
559567
"""Mock implementation of apply_chat_template."""
560568
template_str = ""
561569
for msg in messages:
@@ -566,10 +574,14 @@ def _mock_apply_chat_template(
566574
template_str += msg["reasoning_content"]
567575
template_str += "\n"
568576

569-
if tokenize:
570-
return [hash(template_str) % 1000 for _ in range(len(template_str.split()))]
571-
else:
572-
return template_str
577+
result = (
578+
[hash(template_str) % 1000 for _ in range(len(template_str.split()))]
579+
if tokenize
580+
else template_str
581+
)
582+
if return_dict:
583+
return {"input_ids": result}
584+
return result
573585

574586
def test_wrap_masked_messages_with_reasoning_content(self):
575587
"""Test that wrap_masked_messages correctly wraps both content and reasoning_content."""

0 commit comments

Comments
 (0)