ソースを参照

make highlight friendly to English (#2417)

### What problem does this PR solve?

#2415

### Type of change

- [x] Performance Improvement
tags/v0.11.0
Kevin Hu 1年前
コミット
9d4bb5767c
コミッターのメールアドレスに関連付けられたアカウントが存在しません
3個のファイルの変更22行の追加20行の削除
  1. 1
    1
      rag/nlp/__init__.py
  2. 19
    17
      rag/nlp/search.py
  3. 2
    2
      rag/utils/__init__.py

+ 1
- 1
rag/nlp/__init__.py ファイルの表示

eng = 0 eng = 0
if not texts: return False if not texts: return False
for t in texts: for t in texts:
if re.match(r"[a-zA-Z]{2,}", t.strip()):
if re.match(r"[ `a-zA-Z.,':;/\"?<>!\(\)-]", t.strip()):
eng += 1 eng += 1
if eng / len(texts) > 0.8: if eng / len(texts) > 0.8:
return True return True

+ 19
- 17
rag/nlp/search.py ファイルの表示



from rag.settings import es_logger from rag.settings import es_logger
from rag.utils import rmSpace from rag.utils import rmSpace
from rag.nlp import rag_tokenizer, query
from rag.nlp import rag_tokenizer, query, is_english
import numpy as np import numpy as np




ids=self.es.getDocIds(res), ids=self.es.getDocIds(res),
query_vector=q_vec, query_vector=q_vec,
aggregation=aggs, aggregation=aggs,
highlight=self.getHighlight(res),
highlight=self.getHighlight(res, keywords, "content_with_weight"),
field=self.getFields(res, src), field=self.getFields(res, src),
keywords=list(kwds) keywords=list(kwds)
) )
bkts = res["aggregations"]["aggs_" + g]["buckets"] bkts = res["aggregations"]["aggs_" + g]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts] return [(b["key"], b["doc_count"]) for b in bkts]


def getHighlight(self, res):
def rmspace(line):
eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
r = []
for t in line.split(" "):
if not t:
continue
if len(r) > 0 and len(
t) > 0 and r[-1][-1] in eng and t[0] in eng:
r.append(" ")
r.append(t)
r = "".join(r)
return r

def getHighlight(self, res, keywords, fieldnm):
ans = {} ans = {}
for d in res["hits"]["hits"]: for d in res["hits"]["hits"]:
hlts = d.get("highlight") hlts = d.get("highlight")
if not hlts: if not hlts:
continue continue
ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
txt = "...".join([a for a in list(hlts.items())[0][1]])
if not is_english(txt.split(" ")):
ans[d["_id"]] = txt
continue

txt = d["_source"][fieldnm]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
txts = []
for w in keywords:
txt = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", txt, flags=re.IGNORECASE|re.MULTILINE)

for t in re.split(r"[.?!;\n]", txt):
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE): continue
txts.append(t)
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])

return ans return ans


def getFields(self, sres, flds): def getFields(self, sres, flds):

+ 2
- 2
rag/utils/__init__.py ファイルの表示





def rmSpace(txt): def rmSpace(txt):
txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt, flags=re.IGNORECASE)
return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt, flags=re.IGNORECASE)
txt = re.sub(r"([^a-z0-9.,\)>]) +([^ ])", r"\1\2", txt, flags=re.IGNORECASE)
return re.sub(r"([^ ]) +([^a-z0-9.,\(<])", r"\1\2", txt, flags=re.IGNORECASE)




def findMaxDt(fnm): def findMaxDt(fnm):

読み込み中…
キャンセル
保存