瀏覽代碼

Refa: token similarity calculations. (#6614)

### What problem does this PR solve?

#6507

### Type of change

- [x] Performance Improvement
tags/v0.18.0
Kevin Hu 7 月之前
父節點
當前提交
0758c04941
No account linked to committer's email address
共有 2 個文件被更改,包括 11 次插入 和 9 次删除
  1. 9
    8
      rag/nlp/query.py
  2. 2
    1
      rag/nlp/search.py

+ 9
- 8
rag/nlp/query.py 查看文件



import logging import logging
import json import json
import math
import re import re
from rag.utils.doc_store_conn import MatchTextExpr
from collections import defaultdict


from rag.utils.doc_store_conn import MatchTextExpr
from rag.nlp import rag_tokenizer, term_weight, synonym from rag.nlp import rag_tokenizer, term_weight, synonym






def token_similarity(self, atks, btkss): def token_similarity(self, atks, btkss):
def toDict(tks): def toDict(tks):
d = {}
if isinstance(tks, str): if isinstance(tks, str):
tks = tks.split() tks = tks.split()
for t, c in self.tw.weights(tks, preprocess=False):
if t not in d:
d[t] = 0
d = defaultdict(int)
wts = self.tw.weights(tks, preprocess=False)
for i, (t, c) in enumerate(wts):
d[t] += c d[t] += c
return d return d


s = 1e-9 s = 1e-9
for k, v in qtwt.items(): for k, v in qtwt.items():
if k in dtwt: if k in dtwt:
s += v # * dtwt[k]
s += v * dtwt[k]
q = 1e-9 q = 1e-9
for k, v in qtwt.items(): for k, v in qtwt.items():
q += v
return s / q
q += v * v
return math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 )))


def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30): def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
if isinstance(content_tks, str): if isinstance(content_tks, str):

+ 2
- 1
rag/nlp/search.py 查看文件

# #
import logging import logging
import re import re
from collections import OrderedDict
from dataclasses import dataclass from dataclasses import dataclass


from rag.settings import TAG_FLD, PAGERANK_FLD from rag.settings import TAG_FLD, PAGERANK_FLD
sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]] sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
ins_tw = [] ins_tw = []
for i in sres.ids: for i in sres.ids:
content_ltks = sres.field[i][cfield].split()
content_ltks = list(OrderedDict.fromkeys(sres.field[i][cfield].split()))
title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t] title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t]
question_tks = [t for t in sres.field[i].get("question_tks", "").split() if t] question_tks = [t for t in sres.field[i].get("question_tks", "").split() if t]
important_kwd = sres.field[i].get("important_kwd", []) important_kwd = sres.field[i].get("important_kwd", [])

Loading…
取消
儲存