
document_service.py

#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import hashlib
import json
import os
import random
import re
import traceback
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO

from elasticsearch_dsl import Q
from peewee import fn

from api.db import FileType, TaskStatus, ParserType, LLMType, StatusEnum
from api.db.db_models import DB, Knowledgebase, Tenant, Task, Document
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import stat_logger
from api.utils import current_timestamp, get_format_time, get_uuid
from api.utils.file_utils import get_project_base_directory
from graphrag.mind_map_extractor import MindMapExtractor
from rag.nlp import search, rag_tokenizer
from rag.settings import SVR_QUEUE_NAME
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL


class DocumentService(CommonService):
    model = Document
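
    # Paginated listing of the documents in one knowledgebase, optionally
    # narrowed to a single document id and/or a case-insensitive keyword
    # match on the name. Returns (rows, count), where count is the number
    # of matching rows before pagination.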
    @classmethod
    @DB.connection_context()
    def get_list(cls, kb_id, page_number, items_per_page,
                 orderby, desc, keywords, id):
        docs = cls.model.select().where(cls.model.kb_id == kb_id)
        if id:
            docs = docs.where(cls.model.id == id)
        if keywords:
            docs = docs.where(
                fn.LOWER(cls.model.name).contains(keywords.lower())
            )
        count = docs.count()
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count
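
    # Like get_list(), but filters only by knowledgebase id and optional
    # keyword; `desc` is treated as a boolean here.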
    @classmethod
    @DB.connection_context()
    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
                     orderby, desc, keywords):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == kb_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)
        count = docs.count()
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count
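
    # SDK-style listing with offset/count slicing done in Python after the
    # rows are fetched. Note that `descend` is compared against the literal
    # strings 'True'/'False' (any other value leaves the rows unsorted) and
    # that count == -1 means "all remaining rows from offset".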
    @classmethod
    @DB.connection_context()
    def list_documents_in_dataset(cls, dataset_id, offset, count, order_by, descend, keywords):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == dataset_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == dataset_id)
        total = docs.count()
        if descend == 'True':
            docs = docs.order_by(cls.model.getter_by(order_by).desc())
        if descend == 'False':
            docs = docs.order_by(cls.model.getter_by(order_by).asc())
        docs = list(docs.dicts())
        docs_length = len(docs)
        if offset < 0 or offset > docs_length:
            raise IndexError("Offset is out of the valid range.")
        if count == -1:
            return docs[offset:], total
        return docs[offset:offset + count], total
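
    # Inserts a document row and bumps the owning knowledgebase's doc_num
    # by one; raises RuntimeError if any of the writes fail.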
    @classmethod
    @DB.connection_context()
    def insert(cls, doc):
        if not cls.save(**doc):
            raise RuntimeError("Database error (Document)!")
        e, doc = cls.get_by_id(doc["id"])
        if not e:
            raise RuntimeError("Database error (Document retrieval)!")
        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        if not KnowledgebaseService.update_by_id(
                kb.id, {"doc_num": kb.doc_num + 1}):
            raise RuntimeError("Database error (Knowledgebase)!")
        return doc
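
    # Deletes a document: removes its chunks from Elasticsearch, rolls its
    # chunk/token/doc counters back out of the knowledgebase, then drops
    # the row itself.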
    @classmethod
    @DB.connection_context()
    def remove_document(cls, doc, tenant_id):
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
        cls.clear_chunk_num(doc.id)
        return cls.delete_by_id(doc.id)
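
    # Documents that were queued for parsing within the last 10 minutes
    # (update_time is in milliseconds, hence 1000 * 600) but have made no
    # progress yet. Joined with tenant info so the caller knows which
    # embedding/img2txt/ASR models to use.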
    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls):
        fields = [
            cls.model.id,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.name,
            cls.model.type,
            cls.model.location,
            cls.model.size,
            Knowledgebase.tenant_id,
            Tenant.embd_id,
            Tenant.img2txt_id,
            Tenant.asr_id,
            cls.model.update_time]
        docs = cls.model.select(*fields) \
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress == 0,
                cls.model.update_time >= current_timestamp() - 1000 * 600,
                cls.model.run == TaskStatus.RUNNING.value) \
            .order_by(cls.model.update_time.asc())
        return list(docs.dicts())
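
    # Documents whose parsing has started (progress > 0) but not finished
    # (progress < 1); update_progress() polls these.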
    @classmethod
    @DB.connection_context()
    def get_unfinished_docs(cls):
        fields = [cls.model.id, cls.model.process_begin_at,
                  cls.model.parser_config, cls.model.progress_msg,
                  cls.model.run]
        docs = cls.model.select(*fields) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress < 1,
                cls.model.progress > 0)
        return list(docs.dicts())
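
    # Adds parsed-chunk statistics to a document and its knowledgebase.
    # The model column is spelled `process_duation` in db_models, so that
    # spelling is kept to match the schema; the local parameter is named
    # `duration`.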
    @classmethod
    @DB.connection_context()
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
        num = cls.model.update(token_num=cls.model.token_num + token_num,
                               chunk_num=cls.model.chunk_num + chunk_num,
                               process_duation=cls.model.process_duation + duration).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num + token_num,
            chunk_num=Knowledgebase.chunk_num + chunk_num).where(
            Knowledgebase.id == kb_id).execute()
        return num
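
    # Mirror of increment_chunk_num() for chunk deletion; the processing
    # duration still accumulates, since it is elapsed time, not a count.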
    @classmethod
    @DB.connection_context()
    def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
        num = cls.model.update(token_num=cls.model.token_num - token_num,
                               chunk_num=cls.model.chunk_num - chunk_num,
                               process_duation=cls.model.process_duation + duration).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - token_num,
            chunk_num=Knowledgebase.chunk_num - chunk_num
        ).where(
            Knowledgebase.id == kb_id).execute()
        return num
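
    # Rolls a document's token/chunk counters (and one doc_num) back out of
    # its knowledgebase; called before the document row itself is deleted.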
    @classmethod
    @DB.connection_context()
    def clear_chunk_num(cls, doc_id):
        doc = cls.model.get_by_id(doc_id)
        assert doc, "Can't find document in database."
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - doc.token_num,
            chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
            doc_num=Knowledgebase.doc_num - 1
        ).where(
            Knowledgebase.id == doc.kb_id).execute()
        return num
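
    # Lookup helpers that resolve a document (by id or name) to properties
    # of its owning knowledgebase; each returns None when nothing matches.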
    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_tenant_id_by_name(cls, name):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_embd_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.embd_id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["embd_id"]

    @classmethod
    @DB.connection_context()
    def get_doc_id_by_doc_name(cls, doc_name):
        fields = [cls.model.id]
        doc_id = cls.model.select(*fields) \
            .where(cls.model.name == doc_name)
        doc_id = doc_id.dicts()
        if not doc_id:
            return
        return doc_id[0]["id"]

    @classmethod
    @DB.connection_context()
    def get_thumbnails(cls, docids):
        fields = [cls.model.id, cls.model.thumbnail]
        return list(cls.model.select(
            *fields).where(cls.model.id.in_(docids)).dicts())
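
    # Deep-merges `config` into the stored parser_config: nested dicts are
    # merged recursively, scalar values are overwritten.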
    @classmethod
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        e, d = cls.get_by_id(id)
        if not e:
            raise LookupError(f"Document({id}) not found.")

        def dfs_update(old, new):
            for k, v in new.items():
                if k not in old:
                    old[k] = v
                    continue
                if isinstance(v, dict):
                    assert isinstance(old[k], dict)
                    dfs_update(old[k], v)
                else:
                    old[k] = v

        dfs_update(d.parser_config, config)
        cls.update_by_id(id, {"parser_config": d.parser_config})

    @classmethod
    @DB.connection_context()
    def get_doc_count(cls, tenant_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            Knowledgebase.tenant_id == tenant_id)
        return len(docs)

    @classmethod
    @DB.connection_context()
    def begin2parse(cls, docid):
        cls.update_by_id(
            docid, {"progress": random.random() / 100.,
                    "progress_msg": "Task dispatched...",
                    "process_begin_at": get_format_time()
                    })
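
    # Aggregates per-task progress into the document row: average progress,
    # concatenated progress messages, failure if any task ended at -1, and
    # a follow-up RAPTOR task when the parser config asks for one and it
    # has not run yet.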
    @classmethod
    @DB.connection_context()
    def update_progress(cls):
        docs = cls.get_unfinished_docs()
        for d in docs:
            try:
                tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
                if not tsks:
                    continue
                msg = []
                prg = 0
                finished = True
                bad = 0
                e, doc = DocumentService.get_by_id(d["id"])
                status = doc.run  # TaskStatus.RUNNING.value
                for t in tsks:
                    if 0 <= t.progress < 1:
                        finished = False
                    prg += t.progress if t.progress >= 0 else 0
                    if t.progress_msg not in msg:
                        msg.append(t.progress_msg)
                    if t.progress == -1:
                        bad += 1
                prg /= len(tsks)
                if finished and bad:
                    prg = -1
                    status = TaskStatus.FAIL.value
                elif finished:
                    if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(" raptor") < 0:
                        queue_raptor_tasks(d)
                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                        msg.append("------ RAPTOR -------")
                    else:
                        status = TaskStatus.DONE.value

                msg = "\n".join(msg)
                info = {
                    "process_duation": datetime.timestamp(
                        datetime.now()) - d["process_begin_at"].timestamp(),
                    "run": status}
                if prg != 0:
                    info["progress"] = prg
                if msg:
                    info["progress_msg"] = msg
                cls.update_by_id(d["id"], info)
            except Exception as e:
                if str(e).find("'0'") < 0:
                    stat_logger.error("fetch task exception: " + str(e))

    @classmethod
    @DB.connection_context()
    def get_kb_doc_count(cls, kb_id):
        return len(cls.model.select(cls.model.id).where(
            cls.model.kb_id == kb_id).dicts())

    @classmethod
    @DB.connection_context()
    def do_cancel(cls, doc_id):
        try:
            _, doc = DocumentService.get_by_id(doc_id)
            return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
        except Exception:
            pass
        return False
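

# Queues a RAPTOR post-processing task for a freshly parsed document: the
# task row is persisted first, then pushed onto the Redis queue consumed
# by the task executor.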
def queue_raptor_tasks(doc):
    def new_task():
        nonlocal doc
        return {
            "id": get_uuid(),
            "doc_id": doc["id"],
            "from_page": 0,
            "to_page": -1,
            "progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing For Tree-Organized Retrieval)."
        }

    task = new_task()
    bulk_insert_into_db(Task, [task], True)
    task["type"] = "raptor"
    assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
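

# End-to-end helper behind conversational file upload: stores the files,
# chunks each one with the parser matching its parser_id, embeds the
# chunks, optionally builds a mind-map chunk per document, and bulk-indexes
# everything into Elasticsearch. Returns the new document ids.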
def doc_upload_and_parse(conversation_id, file_objs, user_id):
    from rag.app import presentation, picture, naive, audio, email
    from api.db.services.dialog_service import ConversationService, DialogService
    from api.db.services.file_service import FileService
    from api.db.services.llm_service import LLMBundle
    from api.db.services.user_service import TenantService
    from api.db.services.api_service import API4ConversationService

    e, conv = ConversationService.get_by_id(conversation_id)
    if not e:
        e, conv = API4ConversationService.get_by_id(conversation_id)
    assert e, "Conversation not found!"

    e, dia = DialogService.get_by_id(conv.dialog_id)
    kb_id = dia.kb_ids[0]
    e, kb = KnowledgebaseService.get_by_id(kb_id)
    if not e:
        raise LookupError("Can't find this knowledgebase!")

    idxnm = search.index_name(kb.tenant_id)
    if not ELASTICSEARCH.indexExist(idxnm):
        ELASTICSEARCH.createIdx(idxnm, json.load(
            open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))

    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)

    err, files = FileService.upload_document(kb, file_objs, user_id)
    assert not err, "\n".join(err)

    def dummy(prog=None, msg=""):
        pass

    FACTORY = {
        ParserType.PRESENTATION.value: presentation,
        ParserType.PICTURE.value: picture,
        ParserType.AUDIO.value: audio,
        ParserType.EMAIL.value: email
    }
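
    # Parsers not listed in FACTORY fall back to the `naive` chunker; each
    # file is chunked in its own worker thread.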
    parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
    exe = ThreadPoolExecutor(max_workers=12)
    threads = []
    doc_nm = {}
    for d, blob in files:
        doc_nm[d["id"]] = d["name"]
    for d, blob in files:
        kwargs = {
            "callback": dummy,
            "parser_config": parser_config,
            "from_page": 0,
            "to_page": 100000,
            "tenant_id": kb.tenant_id,
            "lang": kb.language
        }
        threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
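
    # Collect chunks from the worker threads; th.result() re-raises any
    # parsing exception. `docs` accumulates the chunks of every uploaded
    # file.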
    docs = []
    for (docinfo, _), th in zip(files, threads):
        doc = {
            "doc_id": docinfo["id"],
            "kb_id": [kb.id]
        }
        for ck in th.result():
            d = deepcopy(doc)
            d.update(ck)
            md5 = hashlib.md5()
            md5.update((ck["content_with_weight"] +
                        str(d["doc_id"])).encode("utf-8"))
            d["_id"] = md5.hexdigest()
            d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
            d["create_timestamp_flt"] = datetime.now().timestamp()
            if not d.get("image"):
                docs.append(d)
                continue

            output_buffer = BytesIO()
            if isinstance(d["image"], bytes):
                output_buffer = BytesIO(d["image"])
            else:
                d["image"].save(output_buffer, format='JPEG')

            STORAGE_IMPL.put(kb.id, d["_id"], output_buffer.getvalue())
            d["img_id"] = "{}-{}".format(kb.id, d["_id"])
            del d["image"]
            docs.append(d)

    parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
    docids = [d["id"] for d, _ in files]
    chunk_counts = {id: 0 for id in docids}
    token_counts = {id: 0 for id in docids}
    es_bulk_size = 64
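
    # Embeds chunk contents in batches, tallying chunk and token counts per
    # document as it goes; returns one vector per chunk.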
    def embedding(doc_id, cnts, batch_size=16):
        nonlocal embd_mdl, chunk_counts, token_counts
        vects = []
        for i in range(0, len(cnts), batch_size):
            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
            token_counts[doc_id] += c
        return vects

    _, tenant = TenantService.get_by_id(kb.tenant_id)
    llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
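
    # For every non-picture document, try to add one extra "mind map" chunk
    # built by MindMapExtractor over the document's chunk contents; failures
    # are logged and skipped. Then embed, attach the vectors as q_<dim>_vec,
    # and bulk-index into Elasticsearch.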
    for doc_id in docids:
        cks = [c for c in docs if c["doc_id"] == doc_id]

        if parser_ids[doc_id] != ParserType.PICTURE.value:
            mindmap = MindMapExtractor(llm_bdl)
            try:
                mind_map = json.dumps(
                    mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
                    ensure_ascii=False, indent=2)
                if len(mind_map) < 32:
                    raise Exception("Few content: " + mind_map)
                cks.append({
                    "id": get_uuid(),
                    "doc_id": doc_id,
                    "kb_id": [kb.id],
                    "docnm_kwd": doc_nm[doc_id],
                    "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
                    "content_ltks": "",
                    "content_with_weight": mind_map,
                    "knowledge_graph_kwd": "mind_map"
                })
            except Exception:
                stat_logger.error("Mind map generation error: " + traceback.format_exc())

        vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
        assert len(cks) == len(vects)
        for i, d in enumerate(cks):
            v = vects[i]
            d["q_%d_vec" % len(v)] = v
        for b in range(0, len(cks), es_bulk_size):
            ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)
        DocumentService.increment_chunk_num(
            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

    return [d["id"] for d, _ in files]
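

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). Assuming a
# chat-style upload endpoint already holds a conversation id, werkzeug-style
# file objects and the current user's id, the flow would look roughly like:
#
#     doc_ids = doc_upload_and_parse(conversation_id, file_objs, user_id)
#     # chunks are already embedded and indexed when this returns
#     thumbnails = DocumentService.get_thumbnails(doc_ids)
# ---------------------------------------------------------------------------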