浏览代码

make highlight friendly to English (#2417)

### What problem does this PR solve?

#2415

### Type of change

- [x] Performance Improvement
tags/v0.11.0
Kevin Hu 1年前
父节点
当前提交
9d4bb5767c
没有帐户链接到提交者的电子邮件
共有 3 个文件被更改,包括 22 次插入20 次删除
  1. 1
    1
      rag/nlp/__init__.py
  2. 19
    17
      rag/nlp/search.py
  3. 2
    2
      rag/utils/__init__.py

+ 1
- 1
rag/nlp/__init__.py 查看文件

@@ -214,7 +214,7 @@ def is_english(texts):
eng = 0
if not texts: return False
for t in texts:
if re.match(r"[a-zA-Z]{2,}", t.strip()):
if re.match(r"[ `a-zA-Z.,':;/\"?<>!\(\)-]", t.strip()):
eng += 1
if eng / len(texts) > 0.8:
return True

+ 19
- 17
rag/nlp/search.py 查看文件

@@ -24,7 +24,7 @@ from dataclasses import dataclass

from rag.settings import es_logger
from rag.utils import rmSpace
from rag.nlp import rag_tokenizer, query
from rag.nlp import rag_tokenizer, query, is_english
import numpy as np


@@ -164,7 +164,7 @@ class Dealer:
ids=self.es.getDocIds(res),
query_vector=q_vec,
aggregation=aggs,
highlight=self.getHighlight(res),
highlight=self.getHighlight(res, keywords, "content_with_weight"),
field=self.getFields(res, src),
keywords=list(kwds)
)
@@ -175,26 +175,28 @@ class Dealer:
bkts = res["aggregations"]["aggs_" + g]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts]

def getHighlight(self, res):
def rmspace(line):
eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
r = []
for t in line.split(" "):
if not t:
continue
if len(r) > 0 and len(
t) > 0 and r[-1][-1] in eng and t[0] in eng:
r.append(" ")
r.append(t)
r = "".join(r)
return r

def getHighlight(self, res, keywords, fieldnm):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
if not hlts:
continue
ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
txt = "...".join([a for a in list(hlts.items())[0][1]])
if not is_english(txt.split(" ")):
ans[d["_id"]] = txt
continue

txt = d["_source"][fieldnm]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
txts = []
for w in keywords:
txt = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", txt, flags=re.IGNORECASE|re.MULTILINE)

for t in re.split(r"[.?!;\n]", txt):
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE): continue
txts.append(t)
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])

return ans

def getFields(self, sres, flds):

+ 2
- 2
rag/utils/__init__.py 查看文件

@@ -32,8 +32,8 @@ def singleton(cls, *args, **kw):


def rmSpace(txt):
txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt, flags=re.IGNORECASE)
return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt, flags=re.IGNORECASE)
txt = re.sub(r"([^a-z0-9.,\)>]) +([^ ])", r"\1\2", txt, flags=re.IGNORECASE)
return re.sub(r"([^ ]) +([^a-z0-9.,\(<])", r"\1\2", txt, flags=re.IGNORECASE)


def findMaxDt(fnm):

正在加载...
取消
保存