Skip to content

Commit 7a9836f

Browse files
Wang-Daojiyuan.wangfridayL
authored
Feat/fix playground bug (#684)
* fix playground bug, internet search judge * fix playground internet bug * modify delete mem * modify tool resp bug in multi cube * fix bug in playground chat handle and search inter * modify prompt * fix bug in playground * fix bug in playground * fix bug * fix code * fix model bug in playground * modify plan b * llm param modify * add logger in playground * modify code * fix bug * modify code * modify code * fix bug * fix search bug in playground * fix bug * move scheduler to back * modify pref location * modify fast net search * add tags and new package * modify prompt fix bug * remove nltk due to image problem --------- Co-authored-by: yuan.wang <yuan.wang@yuanwangdebijibendiannao.local> Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com>
1 parent 4d9aa5b commit 7a9836f

File tree

1 file changed

+110
-14
lines changed

1 file changed

+110
-14
lines changed

src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py

Lines changed: 110 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -154,26 +154,122 @@ def __init__(
154154
reader: MemReader instance for processing internet content
155155
max_results: Maximum number of search results to retrieve
156156
"""
157-
import nltk
158-
159-
try:
160-
nltk.download("averaged_perceptron_tagger_eng")
161-
except Exception as err:
162-
raise Exception("Failed to download nltk averaged_perceptron_tagger_eng") from err
163-
try:
164-
nltk.download("stopwords")
165-
except Exception as err:
166-
raise Exception("Failed to download nltk stopwords") from err
167157

168158
from jieba.analyse import TextRank
169-
from rake_nltk import Rake
170159

171160
self.bocha_api = BochaAISearchAPI(access_key, max_results=max_results)
172161
self.embedder = embedder
173162
self.reader = reader
174-
self.en_fast_keywords_extractor = Rake()
175163
self.zh_fast_keywords_extractor = TextRank()
176164

165+
def _extract_tags(self, title: str, content: str, summary: str, parsed_goal=None) -> list[str]:
166+
"""
167+
Extract tags from title, content and summary
168+
169+
Args:
170+
title: Article title
171+
content: Article content
172+
summary: Article summary
173+
parsed_goal: Parsed task goal (optional)
174+
175+
Returns:
176+
List of extracted tags
177+
"""
178+
tags = []
179+
180+
# Add source-based tags
181+
tags.append("bocha_search")
182+
tags.append("news")
183+
184+
# Add content-based tags
185+
text = f"{title} {content} {summary}".lower()
186+
187+
# Simple keyword-based tagging
188+
keywords = {
189+
"economy": [
190+
"economy",
191+
"GDP",
192+
"growth",
193+
"production",
194+
"industry",
195+
"investment",
196+
"consumption",
197+
"market",
198+
"trade",
199+
"finance",
200+
],
201+
"politics": [
202+
"politics",
203+
"government",
204+
"policy",
205+
"meeting",
206+
"leader",
207+
"election",
208+
"parliament",
209+
"ministry",
210+
],
211+
"technology": [
212+
"technology",
213+
"tech",
214+
"innovation",
215+
"digital",
216+
"internet",
217+
"AI",
218+
"artificial intelligence",
219+
"software",
220+
"hardware",
221+
],
222+
"sports": [
223+
"sports",
224+
"game",
225+
"athlete",
226+
"olympic",
227+
"championship",
228+
"tournament",
229+
"team",
230+
"player",
231+
],
232+
"culture": [
233+
"culture",
234+
"education",
235+
"art",
236+
"history",
237+
"literature",
238+
"music",
239+
"film",
240+
"museum",
241+
],
242+
"health": [
243+
"health",
244+
"medical",
245+
"pandemic",
246+
"hospital",
247+
"doctor",
248+
"medicine",
249+
"disease",
250+
"treatment",
251+
],
252+
"environment": [
253+
"environment",
254+
"ecology",
255+
"pollution",
256+
"green",
257+
"climate",
258+
"sustainability",
259+
"renewable",
260+
],
261+
}
262+
263+
for category, words in keywords.items():
264+
if any(word in text for word in words):
265+
tags.append(category)
266+
267+
# Add goal-based tags if available
268+
if parsed_goal and hasattr(parsed_goal, "tags"):
269+
tags.extend(parsed_goal.tags)
270+
271+
return list(set(tags))[:15] # Limit to 15 tags
272+
177273
def retrieve_from_internet(
178274
self, query: str, top_k: int = 10, parsed_goal=None, info=None, mode="fast"
179275
) -> list[TextualMemoryItem]:
@@ -259,9 +355,9 @@ def _process_result(
259355
session_id = info_.pop("session_id", "")
260356
lang = detect_lang(summary)
261357
tags = (
262-
self.zh_fast_keywords_extractor.textrank(summary)[:3]
358+
self.zh_fast_keywords_extractor.textrank(summary, topK=3)[:3]
263359
if lang == "zh"
264-
else self.en_fast_keywords_extractor.extract_keywords_from_text(summary)[:3]
360+
else self._extract_tags(title, content, summary)[:3]
265361
)
266362

267363
return [

0 commit comments

Comments
 (0)