diff --git a/README.md b/README.md index 47533e7..e283c4f 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ # Как использовать: 1. Скачайте данный репозиторий на ваш пк -2. Установите зафисимости с помощью +2. Установите зависимости с помощью ``` pip install -r requirements.txt ``` @@ -18,3 +18,4 @@ python parser.py --tg_history_path /path/to/history/file.json --output_path /pat 1. raw.csv - файл с неочищенными данными 2. train.jsonl и test.jsonl - данные готовые для дальнейшей обработки ``` +P.S добавлен фильтр контента и сортировка диалога diff --git a/parser.py b/parser.py index 9c86e82..d7027e1 100644 --- a/parser.py +++ b/parser.py @@ -1,91 +1,133 @@ import json +import re +import sys from pathlib import Path -from typing import Dict, Any, List, Optional, Set -from export import export_data -import typer +from typing import Dict, Any, List, Optional + +import orjson import pandas as pd +import typer +from datasets import load_dataset Message = Dict[str, Any] Context = List[Optional[Message]] + app = typer.Typer() +BAD_WORDS = {"нежелательное_слово1", "нежелательное_слово2", "нежелательное_слово3"} + +EMOJI_RE = re.compile("[\U0001F600-\U0001F64F" + "\U0001F300-\U0001F5FF" + "\U0001F680-\U0001F6FF" + "\U0001F1E0-\U0001F1FF]+", flags=re.UNICODE) + + +def contains_emoji(text: str) -> bool: + return bool(EMOJI_RE.search(text)) + + +def contains_bad_words(text: str) -> bool: + lowered = text.lower() + return any(bad in lowered for bad in BAD_WORDS) + + +def is_valid(text: Optional[str], min_len: int = 5, max_len: int = 500) -> bool: + if not text: + return False + if not (min_len <= len(text.strip()) <= max_len): + return False + if contains_bad_words(text): + return False + if contains_emoji(text): + return False + return True + + +def export_data(path: Path): + print(f"Loading CSV dataset from {path / 'raw.csv'}") + data = load_dataset('csv', data_files={'train': str(path / "raw.csv")}) + data = data['train'].train_test_split(test_size=0.2) + + def 
is_pair_clean(sample): + return all(is_valid(sample.get(field)) for field in ['context_1', 'response']) + + print("Filtering dataset...") + data = data.filter(is_pair_clean) + + print("Saving train.jsonl and test.jsonl") + with open(path / 'train.jsonl', 'wb') as train_file: + for item in data['train']: + train_file.write(orjson.dumps(item, option=orjson.OPT_APPEND_NEWLINE)) + with open(path / 'test.jsonl', 'wb') as test_file: + for item in data['test']: + test_file.write(orjson.dumps(item, option=orjson.OPT_APPEND_NEWLINE)) + + print("Export finished") + @app.command() def prepare_messages( tg_history_path: Path = typer.Option(..., help='Path to telegram history json file'), - output_path: Path = typer.Option(..., help='Path to output file'), + output_path: Path = typer.Option(..., help='Path to output directory'), ): - with tg_history_path.open() as messages_file: - messages = json.load(messages_file)['messages'] + print(f"Loading telegram history from {tg_history_path}") + with tg_history_path.open(encoding='utf-8') as f: + messages = json.load(f).get("messages", []) + + print(f"Loaded {len(messages)} messages") contexts = _create_contexts(messages) - contexts = _transform_contexts(contexts) + transformed = _transform_contexts(contexts) + + print(f"Prepared {len(transformed)} contexts") - contexts_df = pd.DataFrame.from_records(contexts) - contexts_df.drop_duplicates(inplace=True) - contexts_df.to_csv(output_path + '/raw.csv', index=False) + output_path.mkdir(parents=True, exist_ok=True) + df = pd.DataFrame.from_records(transformed) + df.drop_duplicates(inplace=True) + csv_path = output_path / "raw.csv" + df.to_csv(csv_path, index=False) + + print(f"Saved to {csv_path}") export_data(output_path) + def _create_contexts(messages: List[Message]) -> List[Context]: - replies_threads = {} - id_to_message = {} - for message in messages: - id_to_message[message['id']] = message - if 'reply_to_message_id' in message: - replies_threads[message['reply_to_message_id']] = 
message['id'] - - contexts = [] - cur_context = _create_default_list() - visited_replies = set() - - for message in messages: - if ( - message['type'] != 'message' or - not message['text'] or - not isinstance(message['text'], str) or - message['id'] in visited_replies - ): - continue + contexts: List[Context] = [] + current_context: List[Optional[Message]] = [] - if 'forwarded_from' in message and cur_context: - contexts.append(cur_context) - cur_context = _create_default_list() - continue + last_author = None - if message['id'] in replies_threads: - contexts.append(cur_context) - cur_context = _create_default_list() - _resolve_thread(contexts, replies_threads, visited_replies, id_to_message, message) + for msg in messages: + if msg.get("type") != "message": continue - - if cur_context[-1] and message['from_id'] == cur_context[-1]['from_id']: - contexts[-1][-1]['text'] += '\n' + message["text"] + text = msg.get("text") + if not text or not isinstance(text, (str, list)): continue - cur_context.pop(0) - cur_context.append(message) - contexts.append(cur_context.copy()) - - return contexts - + # Объединяем сообщения одного автора + if last_author == msg.get("from_id") and current_context: + if isinstance(current_context[-1]['text'], list): + if isinstance(text, list): + current_context[-1]['text'].extend(text) + else: + current_context[-1]['text'].append(text) + else: + if isinstance(text, list): + combined = ''.join(t['text'] if isinstance(t, dict) else t for t in text) + else: + combined = text + current_context[-1]['text'] += '\n' + combined + continue -def _resolve_thread( - contexts: List[Context], - replies_threads: Dict[int, int], - visited_replies: Set[int], - id_to_message: Dict[int, Message], - message: Message, -) -> None: - cur_context = _create_default_list() - cur_id = message['id'] + # Новый автор — добавляем как новый шаг контекста + current_context.append(msg) + last_author = msg.get("from_id") - while cur_id: - cur_context.pop(0) - 
cur_context.append(id_to_message[cur_id]) - contexts.append(cur_context.copy()) + if len(current_context) == 4: + contexts.append(current_context.copy()) + current_context.pop(0) - visited_replies.add(cur_id) - cur_id = replies_threads.get(cur_id) + return contexts def _transform_contexts(contexts: List[Context]) -> List[Dict[str, Optional[str]]]: @@ -104,17 +146,13 @@ def _transform_context(context: Context) -> Dict[str, Optional[str]]: def _transform_message(message: Optional[Message]) -> Optional[str]: if not message: return None - - if isinstance(message['text'], list): - texts = [text['text'] if isinstance(text, dict) else text for text in message['text']] - message['text'] = ''.join(texts) - - return message['text'] - - -def _create_default_list(message: Optional[Message] = None) -> List[Optional[Message]]: - return [None, None, None, message] + text = message.get("text") + if isinstance(text, list): + return ''.join(t["text"] if isinstance(t, dict) else t for t in text) + return text -if __name__ == '__main__': +if __name__ == "__main__": + if len(sys.argv) > 1 and sys.argv[1] == "prepare-messages": + sys.argv = sys.argv[1:] app()