
document_service.py 19KB

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib
import json
import os
import random
import re
import traceback
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO

from elasticsearch_dsl import Q
from peewee import fn

from api.db.db_utils import bulk_insert_into_db
from api.settings import stat_logger
from api.utils import current_timestamp, get_format_time, get_uuid
from api.utils.file_utils import get_project_base_directory
from graphrag.mind_map_extractor import MindMapExtractor
from rag.settings import SVR_QUEUE_NAME
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.minio_conn import MINIO
from rag.nlp import search, rag_tokenizer
from api.db import FileType, TaskStatus, ParserType, LLMType
from api.db.db_models import DB, Knowledgebase, Tenant, Task
from api.db.db_models import Document
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db import StatusEnum
from rag.utils.redis_conn import REDIS_CONN
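
# DocumentService wraps persistence and bookkeeping for Document rows: listing
# and paging per knowledgebase, chunk/token counters, and parsing-progress
# updates. The module-level helpers below queue RAPTOR tasks and handle the
# synchronous ingest path for files uploaded through a conversation.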
class DocumentService(CommonService):
    model = Document

    @classmethod
    @DB.connection_context()
    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
                     orderby, desc, keywords):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == kb_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)
        count = docs.count()
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def list_documents_in_dataset(cls, dataset_id, offset, count, order_by, descend, keywords):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == dataset_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == dataset_id)
        total = docs.count()
        # `descend` is expected as the string 'True' or 'False'.
        if descend == 'True':
            docs = docs.order_by(cls.model.getter_by(order_by).desc())
        if descend == 'False':
            docs = docs.order_by(cls.model.getter_by(order_by).asc())
        docs = list(docs.dicts())
        docs_length = len(docs)
        if offset < 0 or offset > docs_length:
            raise IndexError("Offset is out of the valid range.")
        # count == -1 means "return everything from offset onwards".
        if count == -1:
            return docs[offset:], total
        return docs[offset:offset + count], total

    @classmethod
    @DB.connection_context()
    def insert(cls, doc):
        if not cls.save(**doc):
            raise RuntimeError("Database error (Document)!")
        e, doc = cls.get_by_id(doc["id"])
        if not e:
            raise RuntimeError("Database error (Document retrieval)!")
        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        if not KnowledgebaseService.update_by_id(
                kb.id, {"doc_num": kb.doc_num + 1}):
            raise RuntimeError("Database error (Knowledgebase)!")
        return doc

    @classmethod
    @DB.connection_context()
    def remove_document(cls, doc, tenant_id):
        # Drop the document's chunks from Elasticsearch, roll back the
        # knowledgebase counters, then delete the row itself.
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
        cls.clear_chunk_num(doc.id)
        return cls.delete_by_id(doc.id)

    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls):
        fields = [
            cls.model.id,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.name,
            cls.model.type,
            cls.model.location,
            cls.model.size,
            Knowledgebase.tenant_id,
            Tenant.embd_id,
            Tenant.img2txt_id,
            Tenant.asr_id,
            cls.model.update_time]
        docs = cls.model.select(*fields) \
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress == 0,
                # only documents updated within the last 10 minutes (millisecond timestamps)
                cls.model.update_time >= current_timestamp() - 1000 * 600,
                cls.model.run == TaskStatus.RUNNING.value) \
            .order_by(cls.model.update_time.asc())
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def get_unfinished_docs(cls):
        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run]
        docs = cls.model.select(*fields) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress < 1,
                cls.model.progress > 0)
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
        # NOTE: the `duation` spelling mirrors the model field `process_duation`.
        num = cls.model.update(token_num=cls.model.token_num + token_num,
                               chunk_num=cls.model.chunk_num + chunk_num,
                               process_duation=cls.model.process_duation + duation).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num + token_num,
            chunk_num=Knowledgebase.chunk_num + chunk_num).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
        num = cls.model.update(token_num=cls.model.token_num - token_num,
                               chunk_num=cls.model.chunk_num - chunk_num,
                               process_duation=cls.model.process_duation + duation).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - token_num,
            chunk_num=Knowledgebase.chunk_num - chunk_num
        ).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def clear_chunk_num(cls, doc_id):
        doc = cls.model.get_by_id(doc_id)
        assert doc, "Can't find document in database."
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - doc.token_num,
            chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
            doc_num=Knowledgebase.doc_num - 1
        ).where(
            Knowledgebase.id == doc.kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_tenant_id_by_name(cls, name):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_embd_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.embd_id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["embd_id"]

    @classmethod
    @DB.connection_context()
    def get_doc_id_by_doc_name(cls, doc_name):
        fields = [cls.model.id]
        doc_id = cls.model.select(*fields) \
            .where(cls.model.name == doc_name)
        doc_id = doc_id.dicts()
        if not doc_id:
            return
        return doc_id[0]["id"]

    @classmethod
    @DB.connection_context()
    def get_thumbnails(cls, docids):
        fields = [cls.model.id, cls.model.thumbnail]
        return list(cls.model.select(
            *fields).where(cls.model.id.in_(docids)).dicts())

    @classmethod
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        e, d = cls.get_by_id(id)
        if not e:
            raise LookupError(f"Document({id}) not found.")

        def dfs_update(old, new):
            # Recursively merge `new` into `old`: descend into nested dicts,
            # otherwise overwrite (or add) the value.
            for k, v in new.items():
                if k not in old:
                    old[k] = v
                    continue
                if isinstance(v, dict):
                    assert isinstance(old[k], dict)
                    dfs_update(old[k], v)
                else:
                    old[k] = v

        dfs_update(d.parser_config, config)
        cls.update_by_id(id, {"parser_config": d.parser_config})

    @classmethod
    @DB.connection_context()
    def get_doc_count(cls, tenant_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            Knowledgebase.tenant_id == tenant_id)
        return len(docs)

    @classmethod
    @DB.connection_context()
    def begin2parse(cls, docid):
        cls.update_by_id(
            docid, {"progress": random.random() * 1 / 100.,
                    "progress_msg": "Task dispatched...",
                    "process_begin_at": get_format_time()
                    })

    @classmethod
    @DB.connection_context()
    def update_progress(cls):
        docs = cls.get_unfinished_docs()
        for d in docs:
            try:
                tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
                if not tsks:
                    continue
                msg = []
                prg = 0
                finished = True
                bad = 0
                e, doc = DocumentService.get_by_id(d["id"])
                status = doc.run  # TaskStatus.RUNNING.value
                # Aggregate per-task progress into one document-level figure.
                for t in tsks:
                    if 0 <= t.progress < 1:
                        finished = False
                    prg += t.progress if t.progress >= 0 else 0
                    if t.progress_msg not in msg:
                        msg.append(t.progress_msg)
                    if t.progress == -1:
                        bad += 1
                prg /= len(tsks)
                if finished and bad:
                    prg = -1
                    status = TaskStatus.FAIL.value
                elif finished:
                    if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(" raptor") < 0:
                        queue_raptor_tasks(d)
                        prg *= 0.98
                        msg.append("------ RAPTOR -------")
                    else:
                        status = TaskStatus.DONE.value

                msg = "\n".join(msg)
                info = {
                    "process_duation": datetime.timestamp(datetime.now()) - d["process_begin_at"].timestamp(),
                    "run": status}
                if prg != 0:
                    info["progress"] = prg
                if msg:
                    info["progress_msg"] = msg
                cls.update_by_id(d["id"], info)
            except Exception as e:
                stat_logger.error("fetch task exception:" + str(e))

    @classmethod
    @DB.connection_context()
    def get_kb_doc_count(cls, kb_id):
        return len(cls.model.select(cls.model.id).where(
            cls.model.kb_id == kb_id).dicts())

    @classmethod
    @DB.connection_context()
    def do_cancel(cls, doc_id):
        try:
            _, doc = DocumentService.get_by_id(doc_id)
            return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
        except Exception as e:
            pass
        return False
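
# A minimal usage sketch of the service (hypothetical ids, shown only as a
# comment so the module itself is unchanged):
#
#     docs, total = DocumentService.get_by_kb_id("<kb id>", 1, 20, "create_time", True, "")
#     DocumentService.begin2parse(docs[0]["id"])   # mark a document as dispatched
#     DocumentService.update_progress()            # roll task progress up to documents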

def queue_raptor_tasks(doc):
    def new_task():
        nonlocal doc
        return {
            "id": get_uuid(),
            "doc_id": doc["id"],
            "from_page": 0,
            "to_page": -1,
            "progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing For Tree-Organized Retrieval)."
        }

    task = new_task()
    bulk_insert_into_db(Task, [task], True)
    task["type"] = "raptor"
    assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
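
# doc_upload_and_parse() below is the synchronous ingest path for files attached
# to a conversation: it uploads the files into the dialog's knowledgebase, chunks
# them in-process with a thread pool, embeds the chunks, optionally adds a
# mind-map chunk per document, and bulk-indexes everything into Elasticsearch.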
def doc_upload_and_parse(conversation_id, file_objs, user_id):
    from rag.app import presentation, picture, naive, audio, email
    from api.db.services.dialog_service import ConversationService, DialogService
    from api.db.services.file_service import FileService
    from api.db.services.llm_service import LLMBundle
    from api.db.services.user_service import TenantService
    from api.db.services.api_service import API4ConversationService

    e, conv = ConversationService.get_by_id(conversation_id)
    if not e:
        e, conv = API4ConversationService.get_by_id(conversation_id)
        assert e, "Conversation not found!"

    e, dia = DialogService.get_by_id(conv.dialog_id)
    kb_id = dia.kb_ids[0]
    e, kb = KnowledgebaseService.get_by_id(kb_id)
    if not e:
        raise LookupError("Can't find this knowledgebase!")

    idxnm = search.index_name(kb.tenant_id)
    if not ELASTICSEARCH.indexExist(idxnm):
        ELASTICSEARCH.createIdx(idxnm, json.load(
            open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))

    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)

    err, files = FileService.upload_document(kb, file_objs, user_id)
    assert not err, "\n".join(err)

    def dummy(prog=None, msg=""):
        pass

    FACTORY = {
        ParserType.PRESENTATION.value: presentation,
        ParserType.PICTURE.value: picture,
        ParserType.AUDIO.value: audio,
        ParserType.EMAIL.value: email
    }
    parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
    exe = ThreadPoolExecutor(max_workers=12)
    threads = []
    doc_nm = {}
    for d, blob in files:
        doc_nm[d["id"]] = d["name"]
    # Chunk every uploaded file concurrently; unknown parser ids fall back to the naive parser.
    for d, blob in files:
        kwargs = {
            "callback": dummy,
            "parser_config": parser_config,
            "from_page": 0,
            "to_page": 100000,
            "tenant_id": kb.tenant_id,
            "lang": kb.language
        }
        threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))

    docs = []  # accumulate chunks from every uploaded file
    for (docinfo, _), th in zip(files, threads):
        doc = {
            "doc_id": docinfo["id"],
            "kb_id": [kb.id]
        }
        for ck in th.result():
            d = deepcopy(doc)
            d.update(ck)
            # Chunk id: MD5 of the chunk content plus its document id.
            md5 = hashlib.md5()
            md5.update((ck["content_with_weight"] +
                        str(d["doc_id"])).encode("utf-8"))
            d["_id"] = md5.hexdigest()
            d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
            d["create_timestamp_flt"] = datetime.now().timestamp()
            if not d.get("image"):
                docs.append(d)
                continue
            # Persist any chunk image to MinIO and keep only its object id.
            output_buffer = BytesIO()
            if isinstance(d["image"], bytes):
                output_buffer = BytesIO(d["image"])
            else:
                d["image"].save(output_buffer, format='JPEG')
            MINIO.put(kb.id, d["_id"], output_buffer.getvalue())
            d["img_id"] = "{}-{}".format(kb.id, d["_id"])
            del d["image"]
            docs.append(d)

    parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
    docids = [d["id"] for d, _ in files]
    chunk_counts = {id: 0 for id in docids}
    token_counts = {id: 0 for id in docids}
    es_bulk_size = 64

    def embedding(doc_id, cnts, batch_size=16):
        # Encode chunk texts in batches, tallying chunk and token counts per document.
        nonlocal embd_mdl, chunk_counts, token_counts
        vects = []
        for i in range(0, len(cnts), batch_size):
            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
            token_counts[doc_id] += c
        return vects

    _, tenant = TenantService.get_by_id(kb.tenant_id)
    llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
    for doc_id in docids:
        cks = [c for c in docs if c["doc_id"] == doc_id]

        if parser_ids[doc_id] != ParserType.PICTURE.value:
            # Build a mind-map chunk from the document's content (skipped for picture documents).
            mindmap = MindMapExtractor(llm_bdl)
            try:
                mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
                                      ensure_ascii=False, indent=2)
                if len(mind_map) < 32:
                    raise Exception("Few content: " + mind_map)
                cks.append({
                    "id": get_uuid(),
                    "doc_id": doc_id,
                    "kb_id": [kb.id],
                    "docnm_kwd": doc_nm[doc_id],
                    "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
                    "content_ltks": "",
                    "content_with_weight": mind_map,
                    "knowledge_graph_kwd": "mind_map"
                })
            except Exception as e:
                stat_logger.error("Mind map generation error: " + traceback.format_exc())

        vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
        assert len(cks) == len(vects)
        for i, d in enumerate(cks):
            v = vects[i]
            d["q_%d_vec" % len(v)] = v
        for b in range(0, len(cks), es_bulk_size):
            ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)
        DocumentService.increment_chunk_num(
            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

    return [d["id"] for d, _ in files]
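
# A minimal calling sketch (hypothetical values; normally the conversation
# file-upload API route supplies these arguments):
#
#     doc_ids = doc_upload_and_parse(
#         conversation_id="<existing conversation id>",
#         file_objs=uploaded_files,   # file-like objects from the HTTP request
#         user_id="<tenant user id>",
#     )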