
# fix bug about fetching knowledge graph (#3394)

### What problem does this PR solve?


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

tags/v0.14.0
Kevin Hu, 11 months ago, commit 4caf932808
(No account is linked to the committer's email.)

#### api/apps/chunk_app.py (+1, -4)

```diff
 @login_required
 def knowledge_graph():
     doc_id = request.args["doc_id"]
-    e, doc = DocumentService.get_by_id(doc_id)
-    if not e:
-        return get_data_error_result(message="Document not found!")
     tenant_id = DocumentService.get_tenant_id(doc_id)
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
     req = {
         "doc_ids":[doc_id],
         "knowledge_graph_kwd": ["graph", "mind_map"]
     }
-    sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids, doc.kb_id)
+    sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids)
     obj = {"graph": {}, "mind_map": {}}
     for id in sres.ids[:2]:
         ty = sres.field[id]["knowledge_graph_kwd"]
```
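
A hedged reading of this hunk, from the diff alone: the handler previously fetched the document only to pass `doc.kb_id` as an extra positional argument to `retrievaler.search`, and the fix searches across all of the tenant's `kb_ids` instead, dropping the now-unneeded lookup and its "Document not found!" error path. The companion change in `rag/utils/es_conn.py` below is what lets this filter-only request (`doc_ids` plus `knowledge_graph_kwd`, with no text or vector match) return results at all.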

#### api/apps/document_app.py (+2, -2)

```diff
 @manager.route('/parse', methods=['POST'])
 @login_required
 def parse():
-    url = request.json.get("url")
+    url = request.json.get("url") if request.json else ""
     if url:
         if not is_valid_url(url):
             return get_json_result(
@@ ... @@
         options.add_argument('--disable-dev-shm-usage')
         driver = Chrome(options=options)
         driver.get(url)
-        sections = RAGFlowHtmlParser()("", binary=driver.page_source)
+        sections = RAGFlowHtmlParser().parser_txt(driver.page_source)
         return get_json_result(data="\n".join(sections))

     if 'file' not in request.files:
```
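
Why the first hunk matters (hedged): in Flask, `request.json` can be `None` when the request carries no JSON body, so chaining `.get("url")` onto it raises `AttributeError`; the guard falls back to an empty string. A minimal sketch of an equivalent spelling inside a view function, assuming only Flask's standard `request.get_json`:

```python
from flask import request

# Hedged sketch: tolerate a missing or non-JSON body instead of raising.
# get_json(silent=True) returns None for a missing/unparsable body rather
# than aborting the request, so the fallback to {} keeps .get() safe.
data = request.get_json(silent=True) or {}
url = data.get("url", "")
```

The second hunk swaps the `RAGFlowHtmlParser()("", binary=...)` invocation for a direct `parser_txt(...)` call on the fetched page source; the parser's exact signatures aren't visible in this diff, but the old call reads like an interface mismatch in how the HTML parser was being invoked.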

#### api/db/services/file_service.py (+37, -0)

```diff
 #
 import re
 import os
+from concurrent.futures import ThreadPoolExecutor

 from flask_login import current_user
 from peewee import fn
@@ ... @@
         return err, files

+    @staticmethod
+    def parse_docs(file_objs, user_id):
+        from rag.app import presentation, picture, naive, audio, email
+
+        def dummy(prog=None, msg=""):
+            pass
+
+        FACTORY = {
+            ParserType.PRESENTATION.value: presentation,
+            ParserType.PICTURE.value: picture,
+            ParserType.AUDIO.value: audio,
+            ParserType.EMAIL.value: email
+        }
+        parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
+        exe = ThreadPoolExecutor(max_workers=12)
+        threads = []
+        for file in file_objs:
+            kwargs = {
+                "lang": "English",
+                "callback": dummy,
+                "parser_config": parser_config,
+                "from_page": 0,
+                "to_page": 100000,
+                "tenant_id": user_id
+            }
+            filetype = filename_type(file.filename)
+            blob = file.read()
+            threads.append(exe.submit(FACTORY.get(FileService.get_parser(filetype, file.filename, ""), naive).chunk, file.filename, blob, **kwargs))
+
+        res = []
+        for th in threads:
+            res.append("\n".join([ck["content_with_weight"] for ck in th.result()]))
+
+        return "\n\n".join(res)
+
     @staticmethod
     def get_parser(doc_type, filename, default):
         if doc_type == FileType.VISUAL:
```
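
The new `parse_docs` fans chunking out over a thread pool, one job per uploaded file, then joins the chunk text back in submission order. A minimal, self-contained sketch of that pattern (the `parse_one` callable is hypothetical; `ThreadPoolExecutor` is standard library):

```python
from concurrent.futures import ThreadPoolExecutor

def parse_all(items, parse_one, max_workers=12):
    """Submit one job per item; collect results in submission order."""
    with ThreadPoolExecutor(max_workers=max_workers) as exe:
        futures = [exe.submit(parse_one, it) for it in items]
        # result() blocks until that job finishes and re-raises any
        # exception the worker hit, so failures surface at collection time.
        return [f.result() for f in futures]
```

Two details of the diff worth knowing: the executor is never shut down (the `with` form above handles that), and a worker exception only surfaces when `th.result()` is called.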

#### api/db/services/knowledgebase_service.py (+1, -1)

```diff
             cls.model.id,
         ]
         kbs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
-        kb_ids = [kb["id"] for kb in kbs]
+        kb_ids = [kb.id for kb in kbs]
         return kb_ids

     @classmethod
```
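
This one-liner is the whole fix: a peewee `select()` yields model instances, which expose columns as attributes, while dict-style subscripting only works after `.dicts()` on the query. A self-contained sketch (hypothetical `Knowledgebase` model on in-memory SQLite, assuming peewee):

```python
from peewee import Model, CharField, SqliteDatabase

db = SqliteDatabase(":memory:")

class Knowledgebase(Model):
    name = CharField()  # an implicit auto-increment `id` primary key is added

    class Meta:
        database = db

db.create_tables([Knowledgebase])
Knowledgebase.create(name="demo")

ids = [kb.id for kb in Knowledgebase.select()]             # attribute access works
rows = [r["id"] for r in Knowledgebase.select().dicts()]   # dict rows need .dicts()
# [kb["id"] for kb in Knowledgebase.select()] raises TypeError:
# model instances are not subscriptable.
```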

#### deepdoc/parser/txt_parser.py (+17, -13)

```diff
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import re
+
 from deepdoc.parser.utils import get_text
 from rag.nlp import num_tokens_from_string
@@ ... @@
         def add_chunk(t):
             nonlocal cks, tk_nums, delimiter
             tnum = num_tokens_from_string(t)
-            if tnum < 8:
-                pos = ""
             if tk_nums[-1] > chunk_token_num:
                 cks.append(t)
                 tk_nums.append(tnum)
             else:
                 cks[-1] += t
                 tk_nums[-1] += tnum

-        s, e = 0, 1
-        while e < len(txt):
-            if txt[e] in delimiter:
-                add_chunk(txt[s: e + 1])
-                s = e + 1
-                e = s + 1
-            else:
-                e += 1
-        if s < e:
-            add_chunk(txt[s: e + 1])
+        dels = []
+        s = 0
+        for m in re.finditer(r"`([^`]+)`", delimiter, re.I):
+            f, t = m.span()
+            dels.append(m.group(1))
+            dels.extend(list(delimiter[s: f]))
+            s = t
+        if s < len(delimiter):
+            dels.extend(list(delimiter[s:]))
+        dels = [re.escape(d) for d in dels if d]
+        dels = [d for d in dels if d]
+        dels = "|".join(dels)
+        secs = re.split(r"(%s)" % dels, txt)
+        for sec in secs: add_chunk(sec)

-        return [[c,""] for c in cks]
+        return [[c, ""] for c in cks]
```

#### rag/utils/es_conn.py (+62, -42)

```diff
 from rag.utils import singleton
 from api.utils.file_utils import get_project_base_directory
 import polars as pl
-from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
+from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
+    FusionExpr
 from rag.nlp import is_english, rag_tokenizer
@@ ... @@
         try:
             self.es = Elasticsearch(
                 settings.ES["hosts"].split(","),
-                basic_auth=(settings.ES["username"], settings.ES["password"]) if "username" in settings.ES and "password" in settings.ES else None,
+                basic_auth=(settings.ES["username"], settings.ES[
+                    "password"]) if "username" in settings.ES and "password" in settings.ES else None,
                 verify_certs=False,
                 timeout=600
             )
@@ ... @@
     """
     Database operations
     """

     def dbType(self) -> str:
         return "elasticsearch"
@@ ... @@
     """
     Table operations
     """

     def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
         if self.indexExist(indexName, knowledgebaseId):
             return True
@@ ... @@
     """
     CRUD operations
     """
-    def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
+
+    def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr],
+               orderBy: OrderByExpr, offset: int, limit: int, indexNames: str | list[str],
+               knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
         """
         Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
         """
         bqry = None
         vector_similarity_weight = 0.5
         for m in matchExprs:
-            if isinstance(m, FusionExpr) and m.method=="weighted_sum" and "weights" in m.fusion_params:
-                assert len(matchExprs)==3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1], MatchDenseExpr) and isinstance(matchExprs[2], FusionExpr)
+            if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
+                assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1],
+                                                                                                        MatchDenseExpr) and isinstance(
+                    matchExprs[2], FusionExpr)
                 weights = m.fusion_params["weights"]
                 vector_similarity_weight = float(weights.split(",")[1])
         for m in matchExprs:
@@ ... @@
                 if "minimum_should_match" in m.extra_options:
                     minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%"
                 bqry = Q("bool",
-                     must=Q("query_string", fields=m.fields,
+                         must=Q("query_string", fields=m.fields,
                                 type="best_fields", query=m.matching_text,
-                                minimum_should_match = minimum_should_match,
+                                minimum_should_match=minimum_should_match,
                                 boost=1),
-                     boost = 1.0 - vector_similarity_weight,
-                     )
-                if condition:
-                    for k, v in condition.items():
-                        if not isinstance(k, str) or not v:
-                            continue
-                        if isinstance(v, list):
-                            bqry.filter.append(Q("terms", **{k: v}))
-                        elif isinstance(v, str) or isinstance(v, int):
-                            bqry.filter.append(Q("term", **{k: v}))
-                        else:
-                            raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
+                         boost=1.0 - vector_similarity_weight,
+                         )
             elif isinstance(m, MatchDenseExpr):
-                assert(bqry is not None)
+                assert (bqry is not None)
                 similarity = 0.0
                 if "similarity" in m.extra_options:
                     similarity = m.extra_options["similarity"]
                 s = s.knn(m.vector_column_name,
-                      m.topn,
-                      m.topn * 2,
-                      query_vector = list(m.embedding_data),
-                      filter = bqry.to_dict(),
-                      similarity = similarity,
-                      )
-        if matchExprs:
-            s.query = bqry
+                          m.topn,
+                          m.topn * 2,
+                          query_vector=list(m.embedding_data),
+                          filter=bqry.to_dict(),
+                          similarity=similarity,
+                          )
+
+        if condition:
+            if not bqry:
+                bqry = Q("bool", must=[])
+            for k, v in condition.items():
+                if not isinstance(k, str) or not v:
+                    continue
+                if isinstance(v, list):
+                    bqry.filter.append(Q("terms", **{k: v}))
+                elif isinstance(v, str) or isinstance(v, int):
+                    bqry.filter.append(Q("term", **{k: v}))
+                else:
+                    raise Exception(
+                        f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
+
+        if bqry:
+            s = s.query(bqry)
         for field in highlightFields:
             s = s.highlight(field)
@@ ... @@
         for field, order in orderBy.fields:
             order = "asc" if order == 0 else "desc"
             orders.append({field: {"order": order, "unmapped_type": "float",
-                                   "mode": "avg", "numeric_type": "double"}})
+                                    "mode": "avg", "numeric_type": "double"}})
         s = s.sort(*orders)
@@ ... @@
         if limit > 0:
             s = s[offset:limit]
         q = s.to_dict()
+        print(json.dumps(q), flush=True)
         # logger.info("ESConnection.search [Q]: " + json.dumps(q))

         for i in range(3):
@@ ... @@
         for i in range(3):
             try:
                 res = self.es.get(index=(indexName),
-                                  id=chunkId, source=True,)
+                                  id=chunkId, source=True, )
                 if str(res.get("timed_out", "")).lower() == "true":
                     raise Exception("Es Timeout.")
                 if not res.get("found"):
@@ ... @@
         for _ in range(100):
             try:
                 r = self.es.bulk(index=(indexName), operations=operations,
-                                refresh=False, timeout="600s")
+                                 refresh=False, timeout="600s")
                 if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                     return res
@@ ... @@
                 self.es.update(index=indexName, id=chunkId, doc=doc)
                 return True
             except Exception as e:
-                logger.exception(f"ES failed to update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)})")
+                logger.exception(
+                    f"ES failed to update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)})")
                 if str(e).find("Timeout") > 0:
                     continue
         else:
@@ ... @@
             elif isinstance(v, str) or isinstance(v, int):
                 bqry.filter.append(Q("term", **{k: v}))
             else:
-                raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
+                raise Exception(
+                    f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
         scripts = []
         for k, v in newValue.items():
             if not isinstance(k, str) or not v:
@@ ... @@
             elif isinstance(v, int):
                 scripts.append(f"ctx._source.{k} = {v}")
             else:
-                raise Exception(f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
+                raise Exception(
+                    f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
         ubq = UpdateByQuery(
             index=indexName).using(
             self.es).query(bqry)
@@ ... @@
         try:
             res = self.es.delete_by_query(
                 index=indexName,
-                body = Search().query(qry).to_dict(),
+                body=Search().query(qry).to_dict(),
                 refresh=True)
             return res["deleted"]
         except Exception as e:
@@ ... @@
             return 0
         return 0

     """
     Helper functions for search result
     """

     def getTotal(self, res):
         if isinstance(res["hits"]["total"], type({})):
             return res["hits"]["total"]["value"]
@@ ... @@
                 continue

             txt = d["_source"][fieldnm]
-            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
+            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
             txts = []
             for t in re.split(r"[.?!;\n]", txt):
                 for w in keywords:
-                    t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
-                if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE):
+                    t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
+                               flags=re.IGNORECASE | re.MULTILINE)
+                if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
                     continue
                 txts.append(t)
             ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
@@ ... @@
         bkts = res["aggregations"][agg_field]["buckets"]
         return [(b["key"], b["doc_count"]) for b in bkts]

     """
     SQL
     """

     def sql(self, sql: str, fetch_size: int, format: str):
         logger.info(f"ESConnection.sql get sql: {sql}")
         sql = re.sub(r"[ `]+", " ", sql)
@@ ... @@
                             r.group(1),
                             r.group(2),
                             r.group(3)),
-                            match))
+                         match))

         for p, r in replaces:
             sql = sql.replace(p, r, 1)

         for i in range(3):
             try:
-                res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout="2s")
+                res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
+                                        request_timeout="2s")
                 return res
             except ConnectionTimeout:
                 logger.exception("ESConnection.sql timeout [Q]: " + sql)
```
