
# document_service.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib
import json
import os
import random
import re
import traceback
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO

from elasticsearch_dsl import Q
from peewee import fn

from api.db import FileType, TaskStatus, ParserType, LLMType, StatusEnum
from api.db.db_models import DB, Knowledgebase, Tenant, Task, UserTenant
from api.db.db_models import Document
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import stat_logger
from api.utils import current_timestamp, get_format_time, get_uuid
from api.utils.file_utils import get_project_base_directory
from graphrag.mind_map_extractor import MindMapExtractor
from rag.nlp import search, rag_tokenizer
from rag.settings import SVR_QUEUE_NAME
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
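

# DocumentService bundles CRUD, listing, counter bookkeeping and
# parsing-progress helpers for the Document model; every method below runs
# inside the peewee connection context provided by @DB.connection_context().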
class DocumentService(CommonService):
    model = Document

    @classmethod
    @DB.connection_context()
    def get_list(cls, kb_id, page_number, items_per_page,
                 orderby, desc, keywords, id):
        docs = cls.model.select().where(cls.model.kb_id == kb_id)
        if id:
            docs = docs.where(cls.model.id == id)
        if keywords:
            docs = docs.where(
                fn.LOWER(cls.model.name).contains(keywords.lower())
            )
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        docs = docs.paginate(page_number, items_per_page)
        count = docs.count()
        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
                     orderby, desc, keywords):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == kb_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)
        count = docs.count()
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def list_documents_in_dataset(cls, dataset_id, offset, count, order_by, descend, keywords):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == dataset_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == dataset_id)
        total = docs.count()
        if descend == 'True':
            docs = docs.order_by(cls.model.getter_by(order_by).desc())
        if descend == 'False':
            docs = docs.order_by(cls.model.getter_by(order_by).asc())
        docs = list(docs.dicts())
        docs_length = len(docs)
        if offset < 0 or offset > docs_length:
            raise IndexError("Offset is out of the valid range.")
        if count == -1:
            return docs[offset:], total
        return docs[offset:offset + count], total
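
    # insert() persists the document row and bumps doc_num on the owning
    # knowledgebase so its counters stay consistent with the new upload.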

    @classmethod
    @DB.connection_context()
    def insert(cls, doc):
        if not cls.save(**doc):
            raise RuntimeError("Database error (Document)!")
        e, doc = cls.get_by_id(doc["id"])
        if not e:
            raise RuntimeError("Database error (Document retrieval)!")
        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        if not KnowledgebaseService.update_by_id(
                kb.id, {"doc_num": kb.doc_num + 1}):
            raise RuntimeError("Database error (Knowledgebase)!")
        return doc
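
    # remove_document() first purges the document's chunks from Elasticsearch,
    # then rolls back knowledgebase counters via clear_chunk_num() before
    # deleting the row itself.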

    @classmethod
    @DB.connection_context()
    def remove_document(cls, doc, tenant_id):
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
        cls.clear_chunk_num(doc.id)
        return cls.delete_by_id(doc.id)
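
    # get_newly_uploaded() picks up valid, non-virtual documents set to RUNNING
    # within what appears to be a 10-minute window (1000 * 600, presumably
    # milliseconds) that have made no parsing progress yet.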

    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls):
        fields = [
            cls.model.id,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.name,
            cls.model.type,
            cls.model.location,
            cls.model.size,
            Knowledgebase.tenant_id,
            Tenant.embd_id,
            Tenant.img2txt_id,
            Tenant.asr_id,
            cls.model.update_time]
        docs = cls.model.select(*fields) \
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress == 0,
                cls.model.update_time >= current_timestamp() - 1000 * 600,
                cls.model.run == TaskStatus.RUNNING.value) \
            .order_by(cls.model.update_time.asc())
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def get_unfinished_docs(cls):
        fields = [cls.model.id, cls.model.process_begin_at,
                  cls.model.parser_config, cls.model.progress_msg,
                  cls.model.run]
        docs = cls.model.select(*fields) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress < 1,
                cls.model.progress > 0)
        return list(docs.dicts())
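
    # increment_chunk_num()/decrement_chunk_num() keep the token and chunk
    # counters on both the Document row and its Knowledgebase in step after
    # chunks are added or removed.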

    @classmethod
    @DB.connection_context()
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
        num = cls.model.update(token_num=cls.model.token_num + token_num,
                               chunk_num=cls.model.chunk_num + chunk_num,
                               process_duation=cls.model.process_duation + duation).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num + token_num,
            chunk_num=Knowledgebase.chunk_num + chunk_num).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
        num = cls.model.update(token_num=cls.model.token_num - token_num,
                               chunk_num=cls.model.chunk_num - chunk_num,
                               process_duation=cls.model.process_duation + duation).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - token_num,
            chunk_num=Knowledgebase.chunk_num - chunk_num
        ).where(
            Knowledgebase.id == kb_id).execute()
        return num
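
    # clear_chunk_num() subtracts a document's whole token/chunk footprint from
    # its knowledgebase and decrements doc_num; used when a document is removed.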

    @classmethod
    @DB.connection_context()
    def clear_chunk_num(cls, doc_id):
        doc = cls.model.get_by_id(doc_id)
        assert doc, "Can't find document in database."
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - doc.token_num,
            chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
            doc_num=Knowledgebase.doc_num - 1
        ).where(
            Knowledgebase.id == doc.kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_tenant_id_by_name(cls, name):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def accessible(cls, doc_id, user_id):
        docs = cls.model.select(
            cls.model.id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)
        ).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
        ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
        docs = docs.dicts()
        if not docs:
            return False
        return True

    @classmethod
    @DB.connection_context()
    def accessible4deletion(cls, doc_id, user_id):
        docs = cls.model.select(
            cls.model.id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)
        ).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
        docs = docs.dicts()
        if not docs:
            return False
        return True

    @classmethod
    @DB.connection_context()
    def get_embd_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.embd_id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["embd_id"]

    @classmethod
    @DB.connection_context()
    def get_doc_id_by_doc_name(cls, doc_name):
        fields = [cls.model.id]
        doc_id = cls.model.select(*fields) \
            .where(cls.model.name == doc_name)
        doc_id = doc_id.dicts()
        if not doc_id:
            return
        return doc_id[0]["id"]

    @classmethod
    @DB.connection_context()
    def get_thumbnails(cls, docids):
        fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
        return list(cls.model.select(
            *fields).where(cls.model.id.in_(docids)).dicts())
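
    # update_parser_config() deep-merges the incoming config into the stored
    # parser_config: nested dicts are merged key by key, scalars are overwritten.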

    @classmethod
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        e, d = cls.get_by_id(id)
        if not e:
            raise LookupError(f"Document({id}) not found.")

        def dfs_update(old, new):
            for k, v in new.items():
                if k not in old:
                    old[k] = v
                    continue
                if isinstance(v, dict):
                    assert isinstance(old[k], dict)
                    dfs_update(old[k], v)
                else:
                    old[k] = v

        dfs_update(d.parser_config, config)
        cls.update_by_id(id, {"parser_config": d.parser_config})

    @classmethod
    @DB.connection_context()
    def get_doc_count(cls, tenant_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            Knowledgebase.tenant_id == tenant_id)
        return len(docs)

    @classmethod
    @DB.connection_context()
    def begin2parse(cls, docid):
        cls.update_by_id(
            docid, {"progress": random.random() * 1 / 100.,
                    "progress_msg": "Task dispatched...",
                    "process_begin_at": get_format_time()
                    })
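
    # update_progress() rolls per-task progress up to a single document-level
    # figure: any task at -1 marks the run failed; when all tasks finish and the
    # parser config asks for RAPTOR, a follow-up RAPTOR task is queued instead
    # of marking the document DONE.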

    @classmethod
    @DB.connection_context()
    def update_progress(cls):
        docs = cls.get_unfinished_docs()
        for d in docs:
            try:
                tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
                if not tsks:
                    continue
                msg = []
                prg = 0
                finished = True
                bad = 0
                e, doc = DocumentService.get_by_id(d["id"])
                status = doc.run  # TaskStatus.RUNNING.value
                for t in tsks:
                    if 0 <= t.progress < 1:
                        finished = False
                    prg += t.progress if t.progress >= 0 else 0
                    if t.progress_msg not in msg:
                        msg.append(t.progress_msg)
                    if t.progress == -1:
                        bad += 1
                prg /= len(tsks)
                if finished and bad:
                    prg = -1
                    status = TaskStatus.FAIL.value
                elif finished:
                    if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(" raptor") < 0:
                        queue_raptor_tasks(d)
                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                        msg.append("------ RAPTOR -------")
                    else:
                        status = TaskStatus.DONE.value

                msg = "\n".join(msg)
                info = {
                    "process_duation": datetime.timestamp(
                        datetime.now()) - d["process_begin_at"].timestamp(),
                    "run": status}
                if prg != 0:
                    info["progress"] = prg
                if msg:
                    info["progress_msg"] = msg
                cls.update_by_id(d["id"], info)
            except Exception as e:
                if str(e).find("'0'") < 0:
                    stat_logger.error("fetch task exception:" + str(e))

    @classmethod
    @DB.connection_context()
    def get_kb_doc_count(cls, kb_id):
        return len(cls.model.select(cls.model.id).where(
            cls.model.kb_id == kb_id).dicts())
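
    # do_cancel() reports whether the run was cancelled by the user or has
    # already failed (progress < 0), so callers can stop processing early.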

    @classmethod
    @DB.connection_context()
    def do_cancel(cls, doc_id):
        try:
            _, doc = DocumentService.get_by_id(doc_id)
            return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
        except Exception:
            pass
        return False
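

# queue_raptor_tasks() creates a synthetic Task row covering the whole document
# and pushes it onto the Redis task queue for the RAPTOR worker to pick up.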
def queue_raptor_tasks(doc):
    def new_task():
        nonlocal doc
        return {
            "id": get_uuid(),
            "doc_id": doc["id"],
            "from_page": 0,
            "to_page": -1,
            "progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing For Tree-Organized Retrieval)."
        }

    task = new_task()
    bulk_insert_into_db(Task, [task], True)
    task["type"] = "raptor"
    assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
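

# doc_upload_and_parse() is the synchronous upload path used by conversations:
# it chunks each file in a thread pool, optionally builds a mind map per
# document, embeds the chunks in batches and bulk-indexes them into
# Elasticsearch, updating the per-document counters at the end.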
def doc_upload_and_parse(conversation_id, file_objs, user_id):
    from rag.app import presentation, picture, naive, audio, email
    from api.db.services.dialog_service import ConversationService, DialogService
    from api.db.services.file_service import FileService
    from api.db.services.llm_service import LLMBundle
    from api.db.services.user_service import TenantService
    from api.db.services.api_service import API4ConversationService

    e, conv = ConversationService.get_by_id(conversation_id)
    if not e:
        e, conv = API4ConversationService.get_by_id(conversation_id)
    assert e, "Conversation not found!"

    e, dia = DialogService.get_by_id(conv.dialog_id)
    kb_id = dia.kb_ids[0]
    e, kb = KnowledgebaseService.get_by_id(kb_id)
    if not e:
        raise LookupError("Can't find this knowledgebase!")

    idxnm = search.index_name(kb.tenant_id)
    if not ELASTICSEARCH.indexExist(idxnm):
        ELASTICSEARCH.createIdx(idxnm, json.load(
            open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))

    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)

    err, files = FileService.upload_document(kb, file_objs, user_id)
    assert not err, "\n".join(err)

    def dummy(prog=None, msg=""):
        pass

    FACTORY = {
        ParserType.PRESENTATION.value: presentation,
        ParserType.PICTURE.value: picture,
        ParserType.AUDIO.value: audio,
        ParserType.EMAIL.value: email
    }
    parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
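
    # Chunk every uploaded file concurrently; the parser module is chosen by
    # parser_id, with the naive parser as the fallback.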
    exe = ThreadPoolExecutor(max_workers=12)
    threads = []
    doc_nm = {}
    for d, blob in files:
        doc_nm[d["id"]] = d["name"]
    for d, blob in files:
        kwargs = {
            "callback": dummy,
            "parser_config": parser_config,
            "from_page": 0,
            "to_page": 100000,
            "tenant_id": kb.tenant_id,
            "lang": kb.language
        }
        threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))

    # Accumulate chunks across all files, so that every doc_id is embedded and
    # indexed below (initializing this inside the loop would drop all but the
    # last file's chunks).
    docs = []
    for (docinfo, _), th in zip(files, threads):
        doc = {
            "doc_id": docinfo["id"],
            "kb_id": [kb.id]
        }
        for ck in th.result():
            d = deepcopy(doc)
            d.update(ck)
            md5 = hashlib.md5()
            md5.update((ck["content_with_weight"] +
                        str(d["doc_id"])).encode("utf-8"))
            d["_id"] = md5.hexdigest()
            d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
            d["create_timestamp_flt"] = datetime.now().timestamp()
            if not d.get("image"):
                docs.append(d)
                continue

            output_buffer = BytesIO()
            if isinstance(d["image"], bytes):
                output_buffer = BytesIO(d["image"])
            else:
                d["image"].save(output_buffer, format='JPEG')

            STORAGE_IMPL.put(kb.id, d["_id"], output_buffer.getvalue())
            d["img_id"] = "{}-{}".format(kb.id, d["_id"])
            del d["image"]
            docs.append(d)

    parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
    docids = [d["id"] for d, _ in files]
    chunk_counts = {id: 0 for id in docids}
    token_counts = {id: 0 for id in docids}
    es_bulk_size = 64
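
    # embedding() encodes chunk texts in batches and tallies per-document chunk
    # and token counts as a side effect; those counters feed
    # increment_chunk_num() at the end of this function.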
    def embedding(doc_id, cnts, batch_size=16):
        nonlocal embd_mdl, chunk_counts, token_counts
        vects = []
        for i in range(0, len(cnts), batch_size):
            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
            token_counts[doc_id] += c
        return vects

    _, tenant = TenantService.get_by_id(kb.tenant_id)
    llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
    for doc_id in docids:
        cks = [c for c in docs if c["doc_id"] == doc_id]

        if parser_ids[doc_id] != ParserType.PICTURE.value:
            mindmap = MindMapExtractor(llm_bdl)
            try:
                mind_map = json.dumps(
                    mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
                    ensure_ascii=False, indent=2)
                if len(mind_map) < 32:
                    raise Exception("Few content: " + mind_map)
                cks.append({
                    "id": get_uuid(),
                    "doc_id": doc_id,
                    "kb_id": [kb.id],
                    "docnm_kwd": doc_nm[doc_id],
                    "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
                    "content_ltks": "",
                    "content_with_weight": mind_map,
                    "knowledge_graph_kwd": "mind_map"
                })
            except Exception as e:
                stat_logger.error("Mind map generation error: " + traceback.format_exc())

        vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
        assert len(cks) == len(vects)
        for i, d in enumerate(cks):
            v = vects[i]
            d["q_%d_vec" % len(v)] = v
        for b in range(0, len(cks), es_bulk_size):
            ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)
        DocumentService.increment_chunk_num(
            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

    return [d["id"] for d, _ in files]
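

# A minimal usage sketch (illustrative only, not part of this module; the
# kb_id value and the "create_time" ordering column below are assumptions):
#
#   docs, total = DocumentService.get_list(
#       kb_id="<kb-uuid>", page_number=1, items_per_page=15,
#       orderby="create_time", desc=True, keywords="", id=None)
#
#   # A scheduler would then periodically roll task progress up to documents:
#   DocumentService.update_progress()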