Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Как использовать:

1. Скачайте данный репозиторий на ваш пк
2. Установите зафисимости с помощью
2. Установите зависимости с помощью
```
pip install -r requirements.txt
```
Expand All @@ -18,3 +18,4 @@ python parser.py --tg_history_path /path/to/history/file.json --output_path /pat
1. raw.csv - файл с неочищенными данными
2. train.jsonl и test.jsonl - данные готовые для дальнейшей обработки
```
P.S. Добавлены фильтр контента и сортировка диалога.
182 changes: 110 additions & 72 deletions parser.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,133 @@
import json
import re
import sys
from pathlib import Path
from typing import Dict, Any, List, Optional, Set
from export import export_data
import typer
from typing import Dict, Any, List, Optional

import orjson
import pandas as pd
import typer
from datasets import load_dataset

Message = Dict[str, Any]
Context = List[Optional[Message]]

app = typer.Typer()

BAD_WORDS = {"нежелательное_слово1", "нежелательное_слово2", "нежелательное_слово3"}

EMOJI_RE = re.compile("[\U0001F600-\U0001F64F"
"\U0001F300-\U0001F5FF"
"\U0001F680-\U0001F6FF"
"\U0001F1E0-\U0001F1FF]+", flags=re.UNICODE)


def contains_emoji(text: str) -> bool:
return bool(EMOJI_RE.search(text))


def contains_bad_words(text: str) -> bool:
lowered = text.lower()
return any(bad in lowered for bad in BAD_WORDS)


def is_valid(text: Optional[str], min_len: int = 5, max_len: int = 500) -> bool:
if not text:
return False
if not (min_len <= len(text.strip()) <= max_len):
return False
if contains_bad_words(text):
return False
if contains_emoji(text):
return False
return True


def export_data(path: Path):
print(f"Loading CSV dataset from {path / 'raw.csv'}")
data = load_dataset('csv', data_files={'train': str(path / "raw.csv")})
data = data['train'].train_test_split(test_size=0.2)

def is_pair_clean(sample):
return all(is_valid(sample.get(field)) for field in ['context_1', 'response'])

print("Filtering dataset...")
data = data.filter(is_pair_clean)

print("Saving train.jsonl and test.jsonl")
with open(path / 'train.jsonl', 'wb') as train_file:
for item in data['train']:
train_file.write(orjson.dumps(item, option=orjson.OPT_APPEND_NEWLINE))
with open(path / 'test.jsonl', 'wb') as test_file:
for item in data['test']:
test_file.write(orjson.dumps(item, option=orjson.OPT_APPEND_NEWLINE))

print("Export finished")


@app.command()
def prepare_messages(
tg_history_path: Path = typer.Option(..., help='Path to telegram history json file'),
output_path: Path = typer.Option(..., help='Path to output file'),
output_path: Path = typer.Option(..., help='Path to output directory'),
):
with tg_history_path.open() as messages_file:
messages = json.load(messages_file)['messages']
print(f"Loading telegram history from {tg_history_path}")
with tg_history_path.open(encoding='utfФ-8') as f:
messages = json.load(f).get("messages", [])

print(f"Loaded {len(messages)} messages")

contexts = _create_contexts(messages)
contexts = _transform_contexts(contexts)
transformed = _transform_contexts(contexts)

print(f"Prepared {len(transformed)} contexts")

contexts_df = pd.DataFrame.from_records(contexts)
contexts_df.drop_duplicates(inplace=True)
contexts_df.to_csv(output_path + '/raw.csv', index=False)
output_path.mkdir(parents=True, exist_ok=True)
df = pd.DataFrame.from_records(transformed)
df.drop_duplicates(inplace=True)
csv_path = output_path / "raw.csv"
df.to_csv(csv_path, index=False)

print(f"Saved to {csv_path}")
export_data(output_path)


def _create_contexts(messages: List[Message]) -> List[Context]:
replies_threads = {}
id_to_message = {}
for message in messages:
id_to_message[message['id']] = message
if 'reply_to_message_id' in message:
replies_threads[message['reply_to_message_id']] = message['id']

contexts = []
cur_context = _create_default_list()
visited_replies = set()

for message in messages:
if (
message['type'] != 'message' or
not message['text'] or
not isinstance(message['text'], str) or
message['id'] in visited_replies
):
continue
contexts: List[Context] = []
current_context: List[Optional[Message]] = []

if 'forwarded_from' in message and cur_context:
contexts.append(cur_context)
cur_context = _create_default_list()
continue
last_author = None

if message['id'] in replies_threads:
contexts.append(cur_context)
cur_context = _create_default_list()
_resolve_thread(contexts, replies_threads, visited_replies, id_to_message, message)
for msg in messages:
if msg.get("type") != "message":
continue

if cur_context[-1] and message['from_id'] == cur_context[-1]['from_id']:
contexts[-1][-1]['text'] += '\n' + message["text"]
text = msg.get("text")
if not text or not isinstance(text, (str, list)):
continue

cur_context.pop(0)
cur_context.append(message)
contexts.append(cur_context.copy())

return contexts

# Объединяем сообщения одного автора
if last_author == msg.get("from_id") and current_context:
if isinstance(current_context[-1]['text'], list):
if isinstance(text, list):
current_context[-1]['text'].extend(text)
else:
current_context[-1]['text'].append(text)
else:
if isinstance(text, list):
combined = ''.join(t['text'] if isinstance(t, dict) else t for t in text)
else:
combined = text
current_context[-1]['text'] += '\n' + combined
continue

def _resolve_thread(
contexts: List[Context],
replies_threads: Dict[int, int],
visited_replies: Set[int],
id_to_message: Dict[int, Message],
message: Message,
) -> None:
cur_context = _create_default_list()
cur_id = message['id']
# Новый автор — добавляем как новый шаг контекста
current_context.append(msg)
last_author = msg.get("from_id")

while cur_id:
cur_context.pop(0)
cur_context.append(id_to_message[cur_id])
contexts.append(cur_context.copy())
if len(current_context) == 4:
contexts.append(current_context.copy())
current_context.pop(0)

visited_replies.add(cur_id)
cur_id = replies_threads.get(cur_id)
return contexts


def _transform_contexts(contexts: List[Context]) -> List[Dict[str, Optional[str]]]:
Expand All @@ -104,17 +146,13 @@ def _transform_context(context: Context) -> Dict[str, Optional[str]]:
def _transform_message(message: Optional[Message]) -> Optional[str]:
if not message:
return None

if isinstance(message['text'], list):
texts = [text['text'] if isinstance(text, dict) else text for text in message['text']]
message['text'] = ''.join(texts)

return message['text']


def _create_default_list(message: Optional[Message] = None) -> List[Optional[Message]]:
return [None, None, None, message]
text = message.get("text")
if isinstance(text, list):
return ''.join(t["text"] if isinstance(t, dict) else t for t in text)
return text


if __name__ == '__main__':
if __name__ == "__main__":
if len(sys.argv) > 1 and sys.argv[1] == "prepare-messages":
sys.argv = sys.argv[1:]
app()