#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib
import json
import os
import random
import re
import traceback
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO

from elasticsearch_dsl import Q
from peewee import fn

from api.db import FileType, TaskStatus, ParserType, LLMType, StatusEnum
from api.db.db_models import DB, Knowledgebase, Tenant, Task, UserTenant
from api.db.db_models import Document
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import stat_logger
from api.utils import current_timestamp, get_format_time, get_uuid
from api.utils.file_utils import get_project_base_directory
from graphrag.mind_map_extractor import MindMapExtractor
from rag.nlp import search, rag_tokenizer
from rag.settings import SVR_QUEUE_NAME
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
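

# DocumentService wraps CRUD and bookkeeping for Document rows: listing and
# paging, per-knowledgebase chunk/token counter maintenance, and progress
# aggregation for the background parsing pipeline.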
class DocumentService(CommonService):
    model = Document

    @classmethod
    @DB.connection_context()
    def get_list(cls, kb_id, page_number, items_per_page,
                 orderby, desc, keywords, id, name):
        docs = cls.model.select().where(cls.model.kb_id == kb_id)
        if id:
            docs = docs.where(cls.model.id == id)
        if name:
            docs = docs.where(cls.model.name == name)
        if keywords:
            docs = docs.where(
                fn.LOWER(cls.model.name).contains(keywords.lower()))
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        # Count before paginating; counting afterwards only reflects one page.
        count = docs.count()
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
                     orderby, desc, keywords):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == kb_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)
        count = docs.count()
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def list_documents_in_dataset(cls, dataset_id, offset, count, order_by, descend, keywords):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == dataset_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == dataset_id)
        total = docs.count()
        if descend == 'True':
            docs = docs.order_by(cls.model.getter_by(order_by).desc())
        elif descend == 'False':
            docs = docs.order_by(cls.model.getter_by(order_by).asc())
        docs = list(docs.dicts())
        if offset < 0 or offset > len(docs):
            raise IndexError("Offset is out of the valid range.")
        if count == -1:
            return docs[offset:], total
        return docs[offset:offset + count], total
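
    # insert() bumps the owning knowledgebase's doc_num alongside the row
    # insert, keeping kb-level counters consistent with the Document table.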

    @classmethod
    @DB.connection_context()
    def insert(cls, doc):
        if not cls.save(**doc):
            raise RuntimeError("Database error (Document)!")
        e, doc = cls.get_by_id(doc["id"])
        if not e:
            raise RuntimeError("Database error (Document retrieval)!")
        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        if not KnowledgebaseService.update_by_id(
                kb.id, {"doc_num": kb.doc_num + 1}):
            raise RuntimeError("Database error (Knowledgebase)!")
        return doc

    @classmethod
    @DB.connection_context()
    def remove_document(cls, doc, tenant_id):
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
        cls.clear_chunk_num(doc.id)
        return cls.delete_by_id(doc.id)
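
    # Scheduler-facing queries: get_newly_uploaded returns valid, non-virtual
    # documents queued for parsing within the last ten minutes (1000 * 600,
    # assuming current_timestamp() is in milliseconds), and
    # get_unfinished_docs returns documents whose parsing is still in flight
    # (0 < progress < 1).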

    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls):
        fields = [
            cls.model.id,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.name,
            cls.model.type,
            cls.model.location,
            cls.model.size,
            Knowledgebase.tenant_id,
            Tenant.embd_id,
            Tenant.img2txt_id,
            Tenant.asr_id,
            cls.model.update_time]
        docs = cls.model.select(*fields) \
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress == 0,
                cls.model.update_time >= current_timestamp() - 1000 * 600,
                cls.model.run == TaskStatus.RUNNING.value) \
            .order_by(cls.model.update_time.asc())
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def get_unfinished_docs(cls):
        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config,
                  cls.model.progress_msg, cls.model.run]
        docs = cls.model.select(*fields) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress < 1,
                cls.model.progress > 0)
        return list(docs.dicts())
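
    # Counter bookkeeping: the two methods below adjust chunk/token totals on
    # both the Document row and its Knowledgebase in one pass; clear_chunk_num
    # rolls a document's entire contribution back out of the knowledgebase.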

    @classmethod
    @DB.connection_context()
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
        # "process_duation" mirrors the (misspelled) column name on the model.
        num = cls.model.update(token_num=cls.model.token_num + token_num,
                               chunk_num=cls.model.chunk_num + chunk_num,
                               process_duation=cls.model.process_duation + duration).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num + token_num,
            chunk_num=Knowledgebase.chunk_num + chunk_num).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
        num = cls.model.update(token_num=cls.model.token_num - token_num,
                               chunk_num=cls.model.chunk_num - chunk_num,
                               process_duation=cls.model.process_duation + duration).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - token_num,
            chunk_num=Knowledgebase.chunk_num - chunk_num
        ).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def clear_chunk_num(cls, doc_id):
        doc = cls.model.get_by_id(doc_id)
        assert doc, "Can't find document in database."
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - doc.token_num,
            chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
            doc_num=Knowledgebase.doc_num - 1
        ).where(
            Knowledgebase.id == doc.kb_id).execute()
        return num
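
    # Joined lookups below resolve the tenant id / embedding-model id for a
    # document through its knowledgebase, returning None when no valid row
    # matches.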

    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_tenant_id_by_name(cls, name):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def accessible(cls, doc_id, user_id):
        docs = cls.model.select(
            cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).join(
            UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
        ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
        docs = docs.dicts()
        if not docs:
            return False
        return True

    @classmethod
    @DB.connection_context()
    def accessible4deletion(cls, doc_id, user_id):
        docs = cls.model.select(
            cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
        docs = docs.dicts()
        if not docs:
            return False
        return True

    @classmethod
    @DB.connection_context()
    def get_embd_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.embd_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["embd_id"]

    @classmethod
    @DB.connection_context()
    def get_doc_id_by_doc_name(cls, doc_name):
        fields = [cls.model.id]
        doc_id = cls.model.select(*fields) \
            .where(cls.model.name == doc_name)
        doc_id = doc_id.dicts()
        if not doc_id:
            return
        return doc_id[0]["id"]

    @classmethod
    @DB.connection_context()
    def get_thumbnails(cls, docids):
        fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
        return list(cls.model.select(
            *fields).where(cls.model.id.in_(docids)).dicts())

    @classmethod
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        e, d = cls.get_by_id(id)
        if not e:
            raise LookupError(f"Document({id}) not found.")

        def dfs_update(old, new):
            for k, v in new.items():
                if k not in old:
                    old[k] = v
                    continue
                if isinstance(v, dict):
                    assert isinstance(old[k], dict)
                    dfs_update(old[k], v)
                else:
                    old[k] = v

        dfs_update(d.parser_config, config)
        cls.update_by_id(id, {"parser_config": d.parser_config})
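
    # dfs_update above is a recursive dict merge: merging
    # {"raptor": {"use_raptor": True}} into {"raptor": {"max_token": 256}}
    # keeps max_token and adds use_raptor instead of replacing the whole
    # "raptor" sub-dict.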

    @classmethod
    @DB.connection_context()
    def get_doc_count(cls, tenant_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            Knowledgebase.tenant_id == tenant_id)
        return len(docs)

    @classmethod
    @DB.connection_context()
    def begin2parse(cls, docid):
        cls.update_by_id(
            docid, {"progress": random.random() / 100.,
                    "progress_msg": "Task dispatched...",
                    "process_begin_at": get_format_time()
                    })
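
    # update_progress folds per-task progress into a single document-level
    # value: the mean of task progresses, -1 (with run=FAIL) once all tasks
    # finish and any failed, and DONE otherwise. If RAPTOR is enabled and has
    # not run yet, a RAPTOR task is queued and progress is held just below 1.0
    # (0.98 * n / (n + 1)) until it completes.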

    @classmethod
    @DB.connection_context()
    def update_progress(cls):
        docs = cls.get_unfinished_docs()
        for d in docs:
            try:
                tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
                if not tsks:
                    continue
                msg = []
                prg = 0
                finished = True
                bad = 0
                e, doc = DocumentService.get_by_id(d["id"])
                status = doc.run  # TaskStatus.RUNNING.value
                for t in tsks:
                    if 0 <= t.progress < 1:
                        finished = False
                    prg += t.progress if t.progress >= 0 else 0
                    if t.progress_msg not in msg:
                        msg.append(t.progress_msg)
                    if t.progress == -1:
                        bad += 1
                prg /= len(tsks)
                if finished and bad:
                    prg = -1
                    status = TaskStatus.FAIL.value
                elif finished:
                    if d["parser_config"].get("raptor", {}).get("use_raptor") and \
                            d["progress_msg"].lower().find(" raptor") < 0:
                        queue_raptor_tasks(d)
                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                        msg.append("------ RAPTOR -------")
                    else:
                        status = TaskStatus.DONE.value

                msg = "\n".join(msg)
                info = {
                    "process_duation": datetime.timestamp(datetime.now()) -
                    d["process_begin_at"].timestamp(),
                    "run": status}
                if prg != 0:
                    info["progress"] = prg
                if msg:
                    info["progress_msg"] = msg
                cls.update_by_id(d["id"], info)
            except Exception as e:
                if str(e).find("'0'") < 0:
                    stat_logger.error("fetch task exception:" + str(e))

    @classmethod
    @DB.connection_context()
    def get_kb_doc_count(cls, kb_id):
        return len(cls.model.select(cls.model.id).where(
            cls.model.kb_id == kb_id).dicts())

    @classmethod
    @DB.connection_context()
    def do_cancel(cls, doc_id):
        try:
            _, doc = DocumentService.get_by_id(doc_id)
            return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
        except Exception:
            pass
        return False
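

# queue_raptor_tasks persists a new Task row via bulk_insert_into_db and then
# pushes the same payload onto the Redis task queue, so a background worker
# can run the RAPTOR (recursive abstractive summarization) pass for the
# document.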
def queue_raptor_tasks(doc):
    def new_task():
        nonlocal doc
        return {
            "id": get_uuid(),
            "doc_id": doc["id"],
            "from_page": 0,
            "to_page": -1,
            "progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing For Tree-Organized Retrieval)."
        }

    task = new_task()
    bulk_insert_into_db(Task, [task], True)
    task["type"] = "raptor"
    assert REDIS_CONN.queue_product(
        SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
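

# doc_upload_and_parse is the synchronous "upload and index now" path used by
# conversations: it uploads files into the dialog's first knowledgebase,
# chunks them in a thread pool, embeds the chunks, optionally builds a mind
# map per document, and bulk-indexes everything into Elasticsearch.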
def doc_upload_and_parse(conversation_id, file_objs, user_id):
    from rag.app import presentation, picture, naive, audio, email
    from api.db.services.dialog_service import ConversationService, DialogService
    from api.db.services.file_service import FileService
    from api.db.services.llm_service import LLMBundle
    from api.db.services.user_service import TenantService
    from api.db.services.api_service import API4ConversationService

    e, conv = ConversationService.get_by_id(conversation_id)
    if not e:
        e, conv = API4ConversationService.get_by_id(conversation_id)
    assert e, "Conversation not found!"

    e, dia = DialogService.get_by_id(conv.dialog_id)
    kb_id = dia.kb_ids[0]
    e, kb = KnowledgebaseService.get_by_id(kb_id)
    if not e:
        raise LookupError("Can't find this knowledgebase!")

    idxnm = search.index_name(kb.tenant_id)
    if not ELASTICSEARCH.indexExist(idxnm):
        ELASTICSEARCH.createIdx(idxnm, json.load(
            open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))

    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)

    err, files = FileService.upload_document(kb, file_objs, user_id)
    assert not err, "\n".join(err)

    def dummy(prog=None, msg=""):
        pass

    FACTORY = {
        ParserType.PRESENTATION.value: presentation,
        ParserType.PICTURE.value: picture,
        ParserType.AUDIO.value: audio,
        ParserType.EMAIL.value: email
    }
    parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
    exe = ThreadPoolExecutor(max_workers=12)
    threads = []
    doc_nm = {}
    for d, blob in files:
        doc_nm[d["id"]] = d["name"]
    for d, blob in files:
        kwargs = {
            "callback": dummy,
            "parser_config": parser_config,
            "from_page": 0,
            "to_page": 100000,
            "tenant_id": kb.tenant_id,
            "lang": kb.language
        }
        threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
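
    # Collect parser output across all files: image chunks are written to
    # object storage via STORAGE_IMPL and replaced by an "img_id" reference so
    # the Elasticsearch payload stays JSON-serializable.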
    docs = []  # accumulated across files; resetting inside the loop would drop all but the last file
    for (docinfo, _), th in zip(files, threads):
        doc = {
            "doc_id": docinfo["id"],
            "kb_id": [kb.id]
        }
        for ck in th.result():
            d = deepcopy(doc)
            d.update(ck)
            md5 = hashlib.md5()
            md5.update((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8"))
            d["_id"] = md5.hexdigest()
            d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
            d["create_timestamp_flt"] = datetime.now().timestamp()
            if not d.get("image"):
                docs.append(d)
                continue

            output_buffer = BytesIO()
            if isinstance(d["image"], bytes):
                output_buffer = BytesIO(d["image"])
            else:
                d["image"].save(output_buffer, format='JPEG')

            STORAGE_IMPL.put(kb.id, d["_id"], output_buffer.getvalue())
            d["img_id"] = "{}-{}".format(kb.id, d["_id"])
            del d["image"]
            docs.append(d)

    parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
    docids = [d["id"] for d, _ in files]
    chunk_counts = {id: 0 for id in docids}
    token_counts = {id: 0 for id in docids}
    es_bulk_size = 64

    def embedding(doc_id, cnts, batch_size=16):
        nonlocal embd_mdl, chunk_counts, token_counts
        vects = []
        for i in range(0, len(cnts), batch_size):
            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
            token_counts[doc_id] += c
        return vects
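
    # Final stage per document: optionally add a mind-map summary chunk, embed
    # every chunk, attach the vector under a dimension-keyed field
    # ("q_<dim>_vec"), and bulk-index into Elasticsearch in batches of
    # es_bulk_size.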
    _, tenant = TenantService.get_by_id(kb.tenant_id)
    llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
    for doc_id in docids:
        cks = [c for c in docs if c["doc_id"] == doc_id]

        if parser_ids[doc_id] != ParserType.PICTURE.value:
            mindmap = MindMapExtractor(llm_bdl)
            try:
                mind_map = json.dumps(
                    mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
                    ensure_ascii=False, indent=2)
                if len(mind_map) < 32:
                    raise Exception("Few content: " + mind_map)
                cks.append({
                    "id": get_uuid(),
                    "doc_id": doc_id,
                    "kb_id": [kb.id],
                    "docnm_kwd": doc_nm[doc_id],
                    "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
                    "content_ltks": "",
                    "content_with_weight": mind_map,
                    "knowledge_graph_kwd": "mind_map"
                })
            except Exception:
                stat_logger.error("Mind map generation error: " + traceback.format_exc())

        vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
        assert len(cks) == len(vects)
        for i, d in enumerate(cks):
            v = vects[i]
            d["q_%d_vec" % len(v)] = v
        for b in range(0, len(cks), es_bulk_size):
            ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)
        DocumentService.increment_chunk_num(
            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

    return [d["id"] for d, _ in files]
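

# A minimal usage sketch (hypothetical; the Flask-style `request` and
# `current_user` objects below are assumptions, not part of this module):
#
#     doc_ids = doc_upload_and_parse(
#         conversation_id=conv_id,
#         file_objs=request.files.getlist("file"),
#         user_id=current_user.id)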