@@ -1,5 +1,5 @@
 import re
-from typing import Optional
+from typing import Optional, cast
 
 
 class JiebaKeywordTableHandler:
@@ -8,18 +8,20 @@ class JiebaKeywordTableHandler:
         from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
 
-        jieba.analyse.default_tfidf.stop_words = STOPWORDS
+        jieba.analyse.default_tfidf.stop_words = STOPWORDS  # type: ignore
 
     def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
         import jieba  # type: ignore
         import jieba.analyse  # type: ignore
 
         keywords = jieba.analyse.extract_tags(
             sentence=text,
             topK=max_keywords_per_chunk,
         )
+        # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
+        keywords = cast(list[str], keywords)
 
-        return set(self._expand_tokens_with_subtokens(keywords))
+        return set(self._expand_tokens_with_subtokens(set(keywords)))
 
     def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
         """Get subtokens from a list of tokens., filtering for stopwords."""