-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
261 lines (232 loc) · 11.2 KB
/
main.py
File metadata and controls
261 lines (232 loc) · 11.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
import logging
import argparse
from utils.helpers import save_results
from utils.beir_loader import load_beir_datasets
from evaluation.retrieval_quality_evaluator_clean import evaluate_retrieval_quality, plot_retrieval_quality_metrics
from evaluation.chunk_size_evaluator import evaluate_chunk_sizes
# from evaluation.scoring_system import calculate_scores, plot_scores
from chunking_methods import (
PercentileChunker, StdDeviationChunker, InterquartileChunker,
GradientChunker, StructuralChunker, FixedLenChunker
)
from langchain_huggingface import HuggingFaceEmbeddings
import json
import os
from typing import Dict, List, Tuple
# -------------------------------
# Utilities
# -------------------------------
# Maps each CLI method name (lower-case) to a tuple of:
#   (canonical class name used as the results key,
#    lower-case directory name under data/chunks/,
#    chunker class to instantiate).
# NOTE: the CLI name 'recursive' selects the FixedLen chunker.
METHOD_ALIASES = {
    cli_name: (class_name, dir_name, chunker_cls)
    for cli_name, class_name, dir_name, chunker_cls in (
        ('percentile', 'Percentile', 'percentile', PercentileChunker),
        ('stddeviation', 'StdDeviation', 'stddeviation', StdDeviationChunker),
        ('interquartile', 'Interquartile', 'interquartile', InterquartileChunker),
        ('recursive', 'FixedLen', 'fixedlen', FixedLenChunker),
        ('gradient', 'Gradient', 'gradient', GradientChunker),
        ('structural', 'Structural', 'structural', StructuralChunker),
    )
}
def normalize_methods(methods: List[str]) -> List[Tuple[str, str, object]]:
    """Resolve CLI method names into (ClassName, dir_name, ChunkerClass) tuples.

    Lookup is case-insensitive. The output preserves the input order.

    Raises:
        ValueError: if any name is not a key of METHOD_ALIASES.
    """
    resolved = []
    for name in methods:
        entry = METHOD_ALIASES.get(name.lower())
        # Every registered value is a tuple, so None means "not registered".
        if entry is None:
            raise ValueError(f"Unknown method: {name}")
        resolved.append(entry)
    return resolved
def make_chunk_dir(dir_name: str, domain: str) -> str:
    """Return the chunk directory path ``data/chunks/<dir_name>/<domain>``."""
    parts = ('data', 'chunks', dir_name, domain)
    return os.path.join(*parts)
def chunks_exist(datasets: Dict, methods: List[str]) -> bool:
    """Return True only if every selected method/dataset pair has chunks on disk.

    A pair counts as present when its chunk directory exists and contains at
    least one file ending in ``_chunks.json``. Only the keys of ``datasets``
    are used.
    """
    def _has_chunk_file(folder: str) -> bool:
        if not os.path.isdir(folder):
            return False
        return any(entry.endswith('_chunks.json') for entry in os.listdir(folder))

    # Methods form the outer loop and domains the inner loop; all()
    # short-circuits at the first missing pair.
    return all(
        _has_chunk_file(make_chunk_dir(dir_name, domain))
        for _, dir_name, _ in normalize_methods(methods)
        for domain in datasets
    )
def _normalize_chunk(chunk) -> dict:
    """Coerce one raw chunk into a dict that has a ``content`` key.

    - A dict that already has ``content`` is returned unchanged.
    - Any other dict is replaced by ``{'content': chunk.get('text', '')}``;
      its other keys are dropped.
    - Anything else becomes ``{'content': str(chunk)}``.
    """
    if isinstance(chunk, dict):
        return chunk if 'content' in chunk else {'content': chunk.get('text', '')}
    return {'content': str(chunk)}


def load_existing_chunks(datasets: Dict, methods: List[str]) -> Dict:
    """Load previously saved chunks for the selected methods and datasets.

    Args:
        datasets: Mapping whose keys are domain (dataset) names; only the keys
            are used.
        methods: CLI method names, resolved via ``normalize_methods``.

    Returns:
        ``{class_name: {domain: [chunk_dict, ...]}}``. Every selected
        method/domain pair gets an entry, which is empty when its chunk
        directory is missing.

    Unreadable files, unparsable files, and files whose top-level JSON value is
    not a list are reported with a ``[WARN]`` line and skipped.

    Raises:
        ValueError: propagated from ``normalize_methods`` for unknown methods.
    """
    results = {}
    valid_domains = list(datasets.keys())
    normalized = normalize_methods(methods)
    for class_name, dir_name, _ in normalized:
        results[class_name] = {}
        for domain in valid_domains:
            path = make_chunk_dir(dir_name, domain)
            results[class_name][domain] = []
            if not os.path.isdir(path):
                print(f"[WARN] Chunk path missing: {path}")
                continue
            # os.listdir order is arbitrary. Sort so the loaded chunk order,
            # and therefore downstream evaluation, is reproducible.
            chunk_files = sorted(f for f in os.listdir(path) if f.endswith('_chunks.json'))
            print(f"Loading {len(chunk_files)} files from {path}")
            for cf in chunk_files:
                fpath = os.path.join(path, cf)
                try:
                    with open(fpath, 'r', encoding='utf-8') as f:
                        chunks = json.load(f)
                except (OSError, ValueError) as e:
                    # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                    print(f"[WARN] Failed to load {cf}: {e}")
                    continue
                if not isinstance(chunks, list):
                    # Iterating a dict or string would silently yield one
                    # bogus chunk per key or character.
                    print(f"[WARN] Failed to load {cf}: expected a JSON list, "
                          f"got {type(chunks).__name__}")
                    continue
                results[class_name][domain].extend(_normalize_chunk(ch) for ch in chunks)
    # Summary
    for class_name, per_domain in results.items():
        for domain, items in per_domain.items():
            print(f" {class_name} - {domain}: {len(items)} chunks loaded")
    return results
def save_chunks(dir_name: str, domain: str, doc_id: str, chunks: List[dict]):
    """Write one document's chunks to ``<chunk dir>/<sanitized doc_id>_chunks.json``.

    - The chunk directory comes from ``make_chunk_dir(dir_name, domain)`` and is
      created if needed.
    - Each chunk is stored as a dict with a ``content`` key. A dict lacking
      ``content`` is replaced by ``{'content': <its 'text' or ''>}``, and a
      non-dict becomes ``{'content': str(chunk)}``.
    - The output is indented JSON.
    """
    target_dir = make_chunk_dir(dir_name, domain)
    os.makedirs(target_dir, exist_ok=True)

    # Make the document id usable as a file name: drop angle brackets and
    # turn ':' and path separators into underscores (applied in this order).
    file_stem = doc_id
    for needle, replacement in (('<', ''), ('>', ''), (':', '_'), ('/', '_'), ('\\', '_')):
        file_stem = file_stem.replace(needle, replacement)

    def _as_record(item):
        if not isinstance(item, dict):
            return {'content': str(item)}
        return item if 'content' in item else {'content': item.get('text', '')}

    payload = [_as_record(item) for item in chunks]
    with open(os.path.join(target_dir, f'{file_stem}_chunks.json'), 'w') as handle:
        json.dump(payload, handle, indent=2)
# -------------------------------
# Main
# -------------------------------
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the benchmark."""
    parser = argparse.ArgumentParser(description='BEIR Chunking Benchmark')
    parser.add_argument('--mode', choices=['chunk', 'eval'], default='chunk',
                        help='Mode: "chunk" for chunking+evaluation, "eval" for evaluation only')
    parser.add_argument('--datasets', nargs='+',
                        default=['nfcorpus', 'scifact', 'arguana', 'scidocs', 'fiqa'],
                        help='BEIR datasets to use (split=test)')
    parser.add_argument('--method', nargs='+',
                        choices=list(METHOD_ALIASES.keys()),
                        default=list(METHOD_ALIASES.keys()),
                        help='Chunking methods to use')
    parser.add_argument('--nlp', choices=['kiwi', 'spacy'], default='spacy',
                        help='Tokenizer/NLP backend to use inside chunkers')
    return parser


def _configure_logging() -> logging.Logger:
    """Append INFO-level logs to ``logs/main.log`` and return this module's logger."""
    log_file = 'logs/main.log'
    # basicConfig opens the log file immediately. Without the directory it
    # raises FileNotFoundError before the benchmark even starts.
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    logging.basicConfig(
        filename=log_file,
        filemode='a',
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO
    )
    return logging.getLogger(__name__)


def _has_any_chunks(chunks_by_method: Dict) -> bool:
    """Return True if any ``{class_name: {domain: [chunks]}}`` entry holds a chunk."""
    return any(
        domain_chunks
        for per_domain in chunks_by_method.values()
        for domain_chunks in per_domain.values()
    )


def _generate_chunks(datasets: Dict, normalized_methods, hf_embeddings,
                     nlp_backend: str, logger: logging.Logger) -> Dict:
    """Run every selected chunker over every document, saving each document's chunks.

    ``datasets`` is iterated as ``{domain: {doc_id: text}}``.

    Returns:
        ``{class_name: {domain: [chunk, ...]}}``.
    """
    all_chunks = {}
    for class_name, dir_name, Chunker in normalized_methods:
        logger.info(f"Applying chunking method: {class_name}")
        print(f"Applying chunking method: {class_name}")
        chunker = Chunker(embeddings=hf_embeddings, nlp_backend=nlp_backend)
        all_chunks[class_name] = {}
        for domain, documents in datasets.items():
            print(f" Processing domain: {domain} ({len(documents)} docs)")
            domain_chunks = all_chunks[class_name][domain] = []
            for i, (doc_id, text) in enumerate(documents.items(), 1):
                try:
                    chunks = chunker.split_text(text, source=domain, file_name=doc_id)
                except Exception as e:
                    # Best effort: one bad document must not abort the whole
                    # run. It is recorded (and saved) with no chunks.
                    logger.warning(f"Chunking failed ({domain}/{doc_id}): {e}")
                    print(f"[WARN] Chunking failed ({domain}/{doc_id}): {e}")
                    chunks = []
                domain_chunks.extend(chunks)
                # Persist per document so later runs can reuse the chunks.
                save_chunks(dir_name, domain, doc_id, chunks)
                if i % 100 == 0:
                    print(f" Processed {i}/{len(documents)}...")
        logger.info(f"Completed chunking for method: {class_name}")
        print(f"Completed chunking for method: {class_name}")
    return all_chunks


def main():
    """Run the BEIR chunking benchmark.

    Steps:
      1. Load the BEIR datasets and persist their ground truths.
      2. Chunk the documents ('chunk' mode, reusing existing chunks when every
         method/dataset pair has some) or load saved chunks ('eval' mode).
      3. Evaluate chunk sizes and retrieval quality.
      4. Save the results and plot the retrieval metrics.

    Any exception is logged with its traceback, printed, and re-raised.
    """
    args = _build_arg_parser().parse_args()
    logger = _configure_logging()
    print("=== BEIR Chunking Benchmark ===")
    print(f"Mode: {args.mode}")
    print(f"Datasets (requested): {args.datasets}")
    print(f"Methods: {args.method}")
    try:
        # Load the datasets
        print("Loading BEIR datasets...")
        datasets, ground_truths = load_beir_datasets(
            data_dir='data',
            datasets=args.datasets,
            split='test'
        )
        if not datasets:
            logger.warning("No BEIR datasets were loaded. Exiting.")
            print("[ERROR] No BEIR datasets were loaded. Please check availability under ./data/")
            return
        # Summarize what was loaded
        loaded_domains = list(datasets.keys())
        print(f"Loaded BEIR datasets: {loaded_domains}")
        print(f"Ground truths available for: {list(ground_truths.keys())}")
        # Save the ground truths so later evaluation modules can reuse them
        from utils.beir_loader import BEIRDataLoader
        beir_loader = BEIRDataLoader('data')
        os.makedirs('config', exist_ok=True)
        beir_loader.save_ground_truths(ground_truths, 'config/beir_ground_truths.json')
        # Initialize the embeddings
        hf_embeddings = HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            encode_kwargs={
                # NOTE(review): 5012*2 = 10024 looks like a typo (512*2?).
                # A batch this large may exhaust GPU memory; confirm the
                # intended value.
                "batch_size": 5012*2,
                "normalize_embeddings": True,
            }
        )
        print("HuggingFace Embeddings initialized.")
        # Resolve the CLI method names
        normalized_methods = normalize_methods(args.method)
        print("Selected chunking methods:",
              [cn for (cn, _, _) in normalized_methods])
        # Reuse existing chunks or generate new ones
        if args.mode == 'chunk':
            # Chunk mode: reuse chunks if they exist, otherwise generate them
            if chunks_exist(datasets, args.method):
                print("Chunks already exist for selected methods/datasets. Loading...")
                existing_chunks = load_existing_chunks(datasets, args.method)
            else:
                existing_chunks = _generate_chunks(
                    datasets, normalized_methods, hf_embeddings, args.nlp, logger
                )
        else:
            print("Evaluation-only mode: loading existing chunks...")
            existing_chunks = load_existing_chunks(datasets, args.method)
            # Each value is a {domain: [chunks]} dict that exists even when
            # every list is empty, so inspect the chunk lists themselves.
            if not _has_any_chunks(existing_chunks):
                print("[ERROR] No existing chunks found. Run with --mode chunk first.")
                logger.error("No existing chunks found for evaluation.")
                return
        # Evaluate chunk sizes
        print("Evaluating chunk sizes...")
        # NOTE(review): this passes the *requested* datasets, not loaded_domains.
        # Confirm evaluate_chunk_sizes tolerates datasets that failed to load.
        chunk_size_metrics = evaluate_chunk_sizes(methods=args.method, datasets=args.datasets)
        print("Chunk size evaluation completed.")
        # Evaluate retrieval quality
        print("Evaluating retrieval quality...")
        retrieval_metrics = evaluate_retrieval_quality(existing_chunks, ground_truths, embeddings=hf_embeddings)
        print("Retrieval quality evaluation completed.")
        # Save the results
        print("Saving results...")
        save_results(chunk_size_metrics, retrieval_metrics)
        print("Results saved successfully.")
        # Plot the results
        print("Plotting retrieval quality metrics...")
        plot_retrieval_quality_metrics('results/retrieval_quality_metrics.json')
        # TODO: score calculation/plotting (evaluation.scoring_system) is
        # currently disabled; re-enable it together with its import.
        logger.info("Completed BEIR Chunking Benchmark successfully.")
        print("Completed BEIR Chunking Benchmark successfully.")
    except Exception as e:
        # Top-level boundary: record the full traceback, then re-raise.
        logger.exception(f"An error occurred during benchmarking: {e}")
        print(f"An error occurred during benchmarking: {e}")
        raise
if __name__ == '__main__':
    # Run the benchmark only when executed as a script, not when imported.
    main()