#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
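# Service-layer helpers around the Document model: CRUD and counter bookkeeping
# for documents and knowledge bases, parsing-progress tracking, and an inline
# upload-and-parse path used by conversations.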
import hashlib
import json
import os
import random
import traceback
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO

from elasticsearch_dsl import Q
from peewee import fn

from api.db import FileType, TaskStatus, ParserType, LLMType, StatusEnum
from api.db.db_models import DB, Knowledgebase, Tenant, Task, Document
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import stat_logger
from api.utils import current_timestamp, get_format_time, get_uuid
from api.utils.file_utils import get_project_base_directory
from graphrag.mind_map_extractor import MindMapExtractor
from rag.nlp import search
from rag.settings import SVR_QUEUE_NAME
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.minio_conn import MINIO
from rag.utils.redis_conn import REDIS_CONN
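
# DocumentService wraps peewee queries against the Document model; every
# public method runs inside DB.connection_context(), so callers do not have
# to manage database connections themselves.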
class DocumentService(CommonService):
    model = Document

    @classmethod
    @DB.connection_context()
    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
                     orderby, desc, keywords):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == kb_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)

        count = docs.count()
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())

        docs = docs.paginate(page_number, items_per_page)

        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def list_documents_in_dataset(cls, dataset_id, offset, count, order_by, descend, keywords):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == dataset_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == dataset_id)

        total = docs.count()

        if descend == 'True':
            docs = docs.order_by(cls.model.getter_by(order_by).desc())
        if descend == 'False':
            docs = docs.order_by(cls.model.getter_by(order_by).asc())

        docs = list(docs.dicts())
        docs_length = len(docs)

        if offset < 0 or offset > docs_length:
            raise IndexError("Offset is out of the valid range.")

        if count == -1:
            return docs[offset:], total

        return docs[offset:offset + count], total

    @classmethod
    @DB.connection_context()
    def insert(cls, doc):
        if not cls.save(**doc):
            raise RuntimeError("Database error (Document)!")
        e, doc = cls.get_by_id(doc["id"])
        if not e:
            raise RuntimeError("Database error (Document retrieval)!")
        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        if not KnowledgebaseService.update_by_id(
                kb.id, {"doc_num": kb.doc_num + 1}):
            raise RuntimeError("Database error (Knowledgebase)!")
        return doc

    @classmethod
    @DB.connection_context()
    def remove_document(cls, doc, tenant_id):
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
        cls.clear_chunk_num(doc.id)
        return cls.delete_by_id(doc.id)
    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls):
        fields = [
            cls.model.id,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.name,
            cls.model.type,
            cls.model.location,
            cls.model.size,
            Knowledgebase.tenant_id,
            Tenant.embd_id,
            Tenant.img2txt_id,
            Tenant.asr_id,
            cls.model.update_time]
        docs = cls.model.select(*fields) \
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress == 0,
                cls.model.update_time >= current_timestamp() - 1000 * 600,
                cls.model.run == TaskStatus.RUNNING.value) \
            .order_by(cls.model.update_time.asc())
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def get_unfinished_docs(cls):
        fields = [cls.model.id, cls.model.process_begin_at,
                  cls.model.parser_config, cls.model.progress_msg, cls.model.run]
        docs = cls.model.select(*fields) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress < 1,
                cls.model.progress > 0)
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
        num = cls.model.update(token_num=cls.model.token_num + token_num,
                               chunk_num=cls.model.chunk_num + chunk_num,
                               process_duation=cls.model.process_duation + duation).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num + token_num,
            chunk_num=Knowledgebase.chunk_num + chunk_num).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
        num = cls.model.update(token_num=cls.model.token_num - token_num,
                               chunk_num=cls.model.chunk_num - chunk_num,
                               process_duation=cls.model.process_duation + duation).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - token_num,
            chunk_num=Knowledgebase.chunk_num - chunk_num
        ).where(
            Knowledgebase.id == kb_id).execute()
        return num
    @classmethod
    @DB.connection_context()
    def clear_chunk_num(cls, doc_id):
        doc = cls.model.get_by_id(doc_id)
        assert doc, "Can't find the document in the database."
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - doc.token_num,
            chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
            doc_num=Knowledgebase.doc_num - 1
        ).where(
            Knowledgebase.id == doc.kb_id).execute()
        return num
    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_tenant_id_by_name(cls, name):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_embd_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.embd_id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["embd_id"]

    @classmethod
    @DB.connection_context()
    def get_doc_id_by_doc_name(cls, doc_name):
        fields = [cls.model.id]
        doc_id = cls.model.select(*fields) \
            .where(cls.model.name == doc_name)
        doc_id = doc_id.dicts()
        if not doc_id:
            return
        return doc_id[0]["id"]

    @classmethod
    @DB.connection_context()
    def get_thumbnails(cls, docids):
        fields = [cls.model.id, cls.model.thumbnail]
        return list(cls.model.select(
            *fields).where(cls.model.id.in_(docids)).dicts())

    @classmethod
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        e, d = cls.get_by_id(id)
        if not e:
            raise LookupError(f"Document({id}) not found.")

        def dfs_update(old, new):
            for k, v in new.items():
                if k not in old:
                    old[k] = v
                    continue
                if isinstance(v, dict):
                    assert isinstance(old[k], dict)
                    dfs_update(old[k], v)
                else:
                    old[k] = v

        dfs_update(d.parser_config, config)
        cls.update_by_id(id, {"parser_config": d.parser_config})

    @classmethod
    @DB.connection_context()
    def get_doc_count(cls, tenant_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            Knowledgebase.tenant_id == tenant_id)
        return len(docs)

    @classmethod
    @DB.connection_context()
    def begin2parse(cls, docid):
        cls.update_by_id(
            docid, {"progress": random.random() * 1 / 100.,
                    "progress_msg": "Task dispatched...",
                    "process_begin_at": get_format_time()
                    })
    @classmethod
    @DB.connection_context()
    def update_progress(cls):
        docs = cls.get_unfinished_docs()
        for d in docs:
            try:
                tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
                if not tsks:
                    continue
                msg = []
                prg = 0
                finished = True
                bad = 0
                e, doc = DocumentService.get_by_id(d["id"])
                status = doc.run  # TaskStatus.RUNNING.value
                for t in tsks:
                    if 0 <= t.progress < 1:
                        finished = False
                    prg += t.progress if t.progress >= 0 else 0
                    if t.progress_msg not in msg:
                        msg.append(t.progress_msg)
                    if t.progress == -1:
                        bad += 1
                prg /= len(tsks)
                if finished and bad:
                    prg = -1
                    status = TaskStatus.FAIL.value
                elif finished:
                    if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(" raptor") < 0:
                        queue_raptor_tasks(d)
                        prg *= 0.98
                        msg.append("------ RAPTOR -------")
                    else:
                        status = TaskStatus.DONE.value

                msg = "\n".join(msg)
                info = {
                    "process_duation": datetime.timestamp(
                        datetime.now()) -
                    d["process_begin_at"].timestamp(),
                    "run": status}
                if prg != 0:
                    info["progress"] = prg
                if msg:
                    info["progress_msg"] = msg
                cls.update_by_id(d["id"], info)
            except Exception as e:
                stat_logger.error("fetch task exception:" + str(e))

    @classmethod
    @DB.connection_context()
    def get_kb_doc_count(cls, kb_id):
        return len(cls.model.select(cls.model.id).where(
            cls.model.kb_id == kb_id).dicts())

    @classmethod
    @DB.connection_context()
    def do_cancel(cls, doc_id):
        try:
            _, doc = DocumentService.get_by_id(doc_id)
            return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
        except Exception as e:
            pass
        return False

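# Create a RAPTOR task record for the given document and push it onto the
# Redis task queue so a background worker picks it up for recursive
# abstractive summarization.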
def queue_raptor_tasks(doc):
    def new_task():
        nonlocal doc
        return {
            "id": get_uuid(),
            "doc_id": doc["id"],
            "from_page": 0,
            "to_page": -1,
            "progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing For Tree-Organized Retrieval)."
        }

    task = new_task()
    bulk_insert_into_db(Task, [task], True)
    task["type"] = "raptor"
    assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."

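# Upload files into the knowledge base bound to a conversation, chunk and
# embed them inline (instead of going through the background task queue),
# and write the resulting chunks straight to Elasticsearch and MinIO.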
def doc_upload_and_parse(conversation_id, file_objs, user_id):
    from rag.app import presentation, picture, naive, audio, email
    from api.db.services.dialog_service import ConversationService, DialogService
    from api.db.services.file_service import FileService
    from api.db.services.llm_service import LLMBundle
    from api.db.services.user_service import TenantService
    from api.db.services.api_service import API4ConversationService

    e, conv = ConversationService.get_by_id(conversation_id)
    if not e:
        e, conv = API4ConversationService.get_by_id(conversation_id)
    assert e, "Conversation not found!"

    e, dia = DialogService.get_by_id(conv.dialog_id)
    kb_id = dia.kb_ids[0]
    e, kb = KnowledgebaseService.get_by_id(kb_id)
    if not e:
        raise LookupError("Can't find this knowledgebase!")

    idxnm = search.index_name(kb.tenant_id)
    if not ELASTICSEARCH.indexExist(idxnm):
        ELASTICSEARCH.createIdx(idxnm, json.load(
            open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))

    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)

    err, files = FileService.upload_document(kb, file_objs, user_id)
    assert not err, "\n".join(err)

    def dummy(prog=None, msg=""):
        pass

    FACTORY = {
        ParserType.PRESENTATION.value: presentation,
        ParserType.PICTURE.value: picture,
        ParserType.AUDIO.value: audio,
        ParserType.EMAIL.value: email
    }
    parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
    exe = ThreadPoolExecutor(max_workers=12)
    threads = []
    for d, blob in files:
        kwargs = {
            "callback": dummy,
            "parser_config": parser_config,
            "from_page": 0,
            "to_page": 100000,
            "tenant_id": kb.tenant_id,
            "lang": kb.language
        }
        threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
    docs = []  # accumulates chunks from every uploaded file; grouped by doc_id below
    for (docinfo, _), th in zip(files, threads):
        doc = {
            "doc_id": docinfo["id"],
            "kb_id": [kb.id]
        }
        for ck in th.result():
            d = deepcopy(doc)
            d.update(ck)
            md5 = hashlib.md5()
            md5.update((ck["content_with_weight"] +
                        str(d["doc_id"])).encode("utf-8"))
            d["_id"] = md5.hexdigest()
            d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
            d["create_timestamp_flt"] = datetime.now().timestamp()
            if not d.get("image"):
                docs.append(d)
                continue

            output_buffer = BytesIO()
            if isinstance(d["image"], bytes):
                output_buffer = BytesIO(d["image"])
            else:
                d["image"].save(output_buffer, format='JPEG')

            MINIO.put(kb.id, d["_id"], output_buffer.getvalue())
            d["img_id"] = "{}-{}".format(kb.id, d["_id"])
            del d["image"]
            docs.append(d)

    parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
    docids = [d["id"] for d, _ in files]
    chunk_counts = {doc_id: 0 for doc_id in docids}
    token_counts = {doc_id: 0 for doc_id in docids}
    es_bulk_size = 64
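
    # Embed chunk texts in small batches with the knowledge base's embedding
    # model, accumulating per-document chunk and token counts as a side effect.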
    def embedding(doc_id, cnts, batch_size=16):
        nonlocal embd_mdl, chunk_counts, token_counts
        vects = []
        for i in range(0, len(cnts), batch_size):
            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
            token_counts[doc_id] += c
        return vects

    _, tenant = TenantService.get_by_id(kb.tenant_id)
    llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
    for doc_id in docids:
        cks = [c for c in docs if c["doc_id"] == doc_id]

        if parser_ids[doc_id] != ParserType.PICTURE.value:
            mindmap = MindMapExtractor(llm_bdl)
            try:
                mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
                                      ensure_ascii=False, indent=2)
                if len(mind_map) < 32:
                    raise Exception("Few content: " + mind_map)
                cks.append({
                    "id": get_uuid(),
                    "doc_id": doc_id,
                    "kb_id": [kb.id],
                    "content_with_weight": mind_map,
                    "knowledge_graph_kwd": "mind_map"
                })
            except Exception as e:
                stat_logger.error("Mind map generation error: " + traceback.format_exc())

        vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
        assert len(cks) == len(vects)
        for i, d in enumerate(cks):
            v = vects[i]
            d["q_%d_vec" % len(v)] = v
        for b in range(0, len(cks), es_bulk_size):
            ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)

        DocumentService.increment_chunk_num(
            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

    return [d["id"] for d, _ in files]