
document_service.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
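"""Service layer for Document rows: listing and CRUD, chunk/token bookkeeping
against the owning Knowledgebase, parsing-progress tracking, RAPTOR task
queuing, and a conversation-scoped upload-and-parse pipeline."""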
import hashlib
import json
import os
import random
import re
import traceback
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO

from elasticsearch_dsl import Q
from peewee import fn

from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import stat_logger
from api.utils import current_timestamp, get_format_time, get_uuid
from api.utils.file_utils import get_project_base_directory
from graphrag.mind_map_extractor import MindMapExtractor
from rag.nlp import rag_tokenizer, search
from rag.settings import SVR_QUEUE_NAME
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL


class DocumentService(CommonService):
    model = Document

    @classmethod
    @DB.connection_context()
    def get_list(cls, kb_id, page_number, items_per_page,
                 orderby, desc, keywords, id, name):
        docs = cls.model.select().where(cls.model.kb_id == kb_id)
        if id:
            docs = docs.where(cls.model.id == id)
        if name:
            docs = docs.where(cls.model.name == name)
        if keywords:
            docs = docs.where(
                fn.LOWER(cls.model.name).contains(keywords.lower()))
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        # Count before paginating so the total reflects the whole result set,
        # not just the requested page (matching get_by_kb_id below).
        count = docs.count()
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count
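
    # Usage sketch for get_list (hypothetical values): fetch the first page of
    # documents in a knowledge base, newest first.
    #
    #   docs, total = DocumentService.get_list(
    #       kb_id="kb0", page_number=1, items_per_page=15,
    #       orderby="create_time", desc=True, keywords="", id=None, name=None)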

    @classmethod
    @DB.connection_context()
    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
                     orderby, desc, keywords):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == kb_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)
        count = docs.count()
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def insert(cls, doc):
        if not cls.save(**doc):
            raise RuntimeError("Database error (Document)!")
        e, doc = cls.get_by_id(doc["id"])
        if not e:
            raise RuntimeError("Database error (Document retrieval)!")
        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        if not KnowledgebaseService.update_by_id(
                kb.id, {"doc_num": kb.doc_num + 1}):
            raise RuntimeError("Database error (Knowledgebase)!")
        return doc
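
    # Usage sketch (illustrative only; the full column set is defined by
    # api.db.db_models.Document, and "id" must be present):
    #
    #   doc = DocumentService.insert(
    #       {"id": get_uuid(), "kb_id": kb.id, "name": "report.pdf", ...})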

    @classmethod
    @DB.connection_context()
    def remove_document(cls, doc, tenant_id):
        # Purge the document's chunks from Elasticsearch, roll back the
        # knowledgebase counters, then drop the row itself.
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
        cls.clear_chunk_num(doc.id)
        return cls.delete_by_id(doc.id)

    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls):
        fields = [
            cls.model.id,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.name,
            cls.model.type,
            cls.model.location,
            cls.model.size,
            Knowledgebase.tenant_id,
            Tenant.embd_id,
            Tenant.img2txt_id,
            Tenant.asr_id,
            cls.model.update_time]
        docs = cls.model.select(*fields) \
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress == 0,
                cls.model.update_time >= current_timestamp() - 1000 * 600,
                cls.model.run == TaskStatus.RUNNING.value) \
            .order_by(cls.model.update_time.asc())
        return list(docs.dicts())
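
    # Note the freshness window: update_time appears to be stored in
    # milliseconds, so current_timestamp() - 1000 * 600 restricts the scan to
    # documents touched within the last ten minutes.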

    @classmethod
    @DB.connection_context()
    def get_unfinished_docs(cls):
        fields = [cls.model.id, cls.model.process_begin_at,
                  cls.model.parser_config, cls.model.progress_msg,
                  cls.model.run]
        docs = cls.model.select(*fields) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress < 1,
                cls.model.progress > 0)
        return list(docs.dicts())
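
    # Progress convention used throughout this service: 0 means queued, values
    # in (0, 1) mean parsing is underway, 1 means finished, and -1 marks a
    # failed document (see update_progress below).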

    @classmethod
    @DB.connection_context()
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
        # NOTE: "process_duation" (sic) is the column name used in db_models.
        num = cls.model.update(
            token_num=cls.model.token_num + token_num,
            chunk_num=cls.model.chunk_num + chunk_num,
            process_duation=cls.model.process_duation + duration
        ).where(cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num + token_num,
            chunk_num=Knowledgebase.chunk_num + chunk_num
        ).where(Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
        num = cls.model.update(
            token_num=cls.model.token_num - token_num,
            chunk_num=cls.model.chunk_num - chunk_num,
            process_duation=cls.model.process_duation + duration
        ).where(cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - token_num,
            chunk_num=Knowledgebase.chunk_num - chunk_num
        ).where(Knowledgebase.id == kb_id).execute()
        return num
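
    # These two methods keep the per-document and per-knowledgebase token and
    # chunk counters in lockstep; callers pass the deltas for one parsing run,
    # e.g. (hypothetical values):
    #
    #   DocumentService.increment_chunk_num(
    #       doc_id, kb_id, token_num=1024, chunk_num=8, duration=2.5)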

    @classmethod
    @DB.connection_context()
    def clear_chunk_num(cls, doc_id):
        doc = cls.model.get_by_id(doc_id)
        assert doc, "Can't find document in database."
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - doc.token_num,
            chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
            doc_num=Knowledgebase.doc_num - 1
        ).where(Knowledgebase.id == doc.kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        docs = cls.model.select(Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).where(
            cls.model.id == doc_id,
            Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_tenant_id_by_name(cls, name):
        docs = cls.model.select(Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).where(
            cls.model.name == name,
            Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def accessible(cls, doc_id, user_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).join(
            UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
        ).where(
            cls.model.id == doc_id, UserTenant.user_id == user_id
        ).paginate(0, 1)
        docs = docs.dicts()
        if not docs:
            return False
        return True

    @classmethod
    @DB.connection_context()
    def accessible4deletion(cls, doc_id, user_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).where(
            cls.model.id == doc_id, Knowledgebase.created_by == user_id
        ).paginate(0, 1)
        docs = docs.dicts()
        if not docs:
            return False
        return True
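
    # accessible() checks tenant membership via UserTenant (any user attached
    # to the owning tenant may read), while accessible4deletion() is stricter:
    # only the user who created the knowledgebase may delete its documents.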

    @classmethod
    @DB.connection_context()
    def get_embd_id(cls, doc_id):
        docs = cls.model.select(Knowledgebase.embd_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).where(
            cls.model.id == doc_id,
            Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["embd_id"]

    @classmethod
    @DB.connection_context()
    def get_doc_id_by_doc_name(cls, doc_name):
        fields = [cls.model.id]
        doc_id = cls.model.select(*fields) \
            .where(cls.model.name == doc_name)
        doc_id = doc_id.dicts()
        if not doc_id:
            return
        return doc_id[0]["id"]

    @classmethod
    @DB.connection_context()
    def get_thumbnails(cls, docids):
        fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
        return list(cls.model.select(
            *fields).where(cls.model.id.in_(docids)).dicts())

    @classmethod
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        e, d = cls.get_by_id(id)
        if not e:
            raise LookupError(f"Document({id}) not found.")

        def dfs_update(old, new):
            # Merge `new` into `old` recursively, so nested parser settings
            # can be overridden without clobbering sibling keys.
            for k, v in new.items():
                if k not in old:
                    old[k] = v
                    continue
                if isinstance(v, dict):
                    assert isinstance(old[k], dict)
                    dfs_update(old[k], v)
                else:
                    old[k] = v

        dfs_update(d.parser_config, config)
        cls.update_by_id(id, {"parser_config": d.parser_config})

    @classmethod
    @DB.connection_context()
    def get_doc_count(cls, tenant_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).where(Knowledgebase.tenant_id == tenant_id)
        return len(docs)

    @classmethod
    @DB.connection_context()
    def begin2parse(cls, docid):
        cls.update_by_id(
            docid, {"progress": random.random() * 1 / 100.,
                    "progress_msg": "Task dispatched...",
                    "process_begin_at": get_format_time()
                    })
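
    # begin2parse() seeds progress with a tiny random value (< 0.01) rather
    # than 0, which moves the document out of the "queued" state the moment a
    # task is dispatched.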

    @classmethod
    @DB.connection_context()
    def update_progress(cls):
        docs = cls.get_unfinished_docs()
        for d in docs:
            try:
                tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
                if not tsks:
                    continue
                msg = []
                prg = 0
                finished = True
                bad = 0
                e, doc = DocumentService.get_by_id(d["id"])
                status = doc.run  # TaskStatus.RUNNING.value
                for t in tsks:
                    if 0 <= t.progress < 1:
                        finished = False
                    prg += t.progress if t.progress >= 0 else 0
                    if t.progress_msg not in msg:
                        msg.append(t.progress_msg)
                    if t.progress == -1:
                        bad += 1
                prg /= len(tsks)
                if finished and bad:
                    prg = -1
                    status = TaskStatus.FAIL.value
                elif finished:
                    if d["parser_config"].get("raptor", {}).get("use_raptor") and \
                            d["progress_msg"].lower().find(" raptor") < 0:
                        queue_raptor_tasks(d)
                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                        msg.append("------ RAPTOR -------")
                    else:
                        status = TaskStatus.DONE.value

                msg = "\n".join(msg)
                info = {
                    "process_duation": datetime.timestamp(datetime.now()) -
                    d["process_begin_at"].timestamp(),
                    "run": status}
                if prg != 0:
                    info["progress"] = prg
                if msg:
                    info["progress_msg"] = msg
                cls.update_by_id(d["id"], info)
            except Exception as e:
                if str(e).find("'0'") < 0:
                    stat_logger.error("fetch task exception:" + str(e))

    @classmethod
    @DB.connection_context()
    def get_kb_doc_count(cls, kb_id):
        return len(cls.model.select(cls.model.id).where(
            cls.model.kb_id == kb_id).dicts())

    @classmethod
    @DB.connection_context()
    def do_cancel(cls, doc_id):
        try:
            _, doc = DocumentService.get_by_id(doc_id)
            return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
        except Exception:
            pass
        return False
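
    # A document counts as cancelled if its run flag was set to CANCEL or its
    # progress was forced negative (the failure/cancel marker).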


def queue_raptor_tasks(doc):
    def new_task():
        nonlocal doc
        return {
            "id": get_uuid(),
            "doc_id": doc["id"],
            "from_page": 0,
            "to_page": -1,
            "progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval)."
        }

    task = new_task()
    bulk_insert_into_db(Task, [task], True)
    task["type"] = "raptor"
    assert REDIS_CONN.queue_product(
        SVR_QUEUE_NAME, message=task
    ), "Can't access Redis. Please check the Redis' status."
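
# queue_raptor_tasks() persists the Task row first (the record that progress
# tracking reads), then enqueues a copy tagged "type": "raptor" onto the Redis
# work queue (SVR_QUEUE_NAME) for the task executor to pick up.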


def doc_upload_and_parse(conversation_id, file_objs, user_id):
    from rag.app import presentation, picture, naive, audio, email
    from api.db.services.dialog_service import ConversationService, DialogService
    from api.db.services.file_service import FileService
    from api.db.services.llm_service import LLMBundle
    from api.db.services.user_service import TenantService
    from api.db.services.api_service import API4ConversationService

    e, conv = ConversationService.get_by_id(conversation_id)
    if not e:
        e, conv = API4ConversationService.get_by_id(conversation_id)
    assert e, "Conversation not found!"

    e, dia = DialogService.get_by_id(conv.dialog_id)
    kb_id = dia.kb_ids[0]
    e, kb = KnowledgebaseService.get_by_id(kb_id)
    if not e:
        raise LookupError("Can't find this knowledgebase!")

    idxnm = search.index_name(kb.tenant_id)
    if not ELASTICSEARCH.indexExist(idxnm):
        # Create the tenant's index from the bundled mapping on first use.
        with open(os.path.join(get_project_base_directory(),
                               "conf", "mapping.json"), "r") as f:
            ELASTICSEARCH.createIdx(idxnm, json.load(f))

    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING,
                         llm_name=kb.embd_id, lang=kb.language)

    err, files = FileService.upload_document(kb, file_objs, user_id)
    assert not err, "\n".join(err)

    def dummy(prog=None, msg=""):
        pass

    FACTORY = {
        ParserType.PRESENTATION.value: presentation,
        ParserType.PICTURE.value: picture,
        ParserType.AUDIO.value: audio,
        ParserType.EMAIL.value: email
    }
    parser_config = {"chunk_token_num": 4096,
                     "delimiter": "\n!?;。;!?", "layout_recognize": False}
    exe = ThreadPoolExecutor(max_workers=12)
    threads = []
    doc_nm = {}
    for d, blob in files:
        doc_nm[d["id"]] = d["name"]
    for d, blob in files:
        kwargs = {
            "callback": dummy,
            "parser_config": parser_config,
            "from_page": 0,
            "to_page": 100000,
            "tenant_id": kb.tenant_id,
            "lang": kb.language
        }
        # Pick the chunker by parser_id, defaulting to the generic "naive" one.
        threads.append(exe.submit(
            FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))

    for (docinfo, _), th in zip(files, threads):
        docs = []
        doc = {
            "doc_id": docinfo["id"],
            "kb_id": [kb.id]
        }
        for ck in th.result():
            d = deepcopy(doc)
            d.update(ck)
            md5 = hashlib.md5()
            md5.update((ck["content_with_weight"] +
                        str(d["doc_id"])).encode("utf-8"))
            d["_id"] = md5.hexdigest()
            d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
            d["create_timestamp_flt"] = datetime.now().timestamp()
            if not d.get("image"):
                docs.append(d)
                continue

            output_buffer = BytesIO()
            if isinstance(d["image"], bytes):
                output_buffer = BytesIO(d["image"])
            else:
                d["image"].save(output_buffer, format='JPEG')

            STORAGE_IMPL.put(kb.id, d["_id"], output_buffer.getvalue())
            d["img_id"] = "{}-{}".format(kb.id, d["_id"])
            del d["image"]
            docs.append(d)

    parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
    docids = [d["id"] for d, _ in files]
    chunk_counts = {id: 0 for id in docids}
    token_counts = {id: 0 for id in docids}
    es_bulk_size = 64

    def embedding(doc_id, cnts, batch_size=16):
        nonlocal embd_mdl, chunk_counts, token_counts
        vects = []
        for i in range(0, len(cnts), batch_size):
            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
            token_counts[doc_id] += c
        return vects
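
    # embedding() pushes the chunk texts through the embedding model in
    # batches of 16, accumulating per-document chunk and token tallies as a
    # side effect; those tallies feed increment_chunk_num() below.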

    _, tenant = TenantService.get_by_id(kb.tenant_id)
    llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
    for doc_id in docids:
        cks = [c for c in docs if c["doc_id"] == doc_id]

        if parser_ids[doc_id] != ParserType.PICTURE.value:
            mindmap = MindMapExtractor(llm_bdl)
            try:
                mind_map = json.dumps(
                    mindmap([c["content_with_weight"]
                             for c in docs if c["doc_id"] == doc_id]).output,
                    ensure_ascii=False, indent=2)
                if len(mind_map) < 32:
                    raise Exception("Few content: " + mind_map)
                cks.append({
                    "id": get_uuid(),
                    "doc_id": doc_id,
                    "kb_id": [kb.id],
                    "docnm_kwd": doc_nm[doc_id],
                    "title_tks": rag_tokenizer.tokenize(
                        re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
                    "content_ltks": "",
                    "content_with_weight": mind_map,
                    "knowledge_graph_kwd": "mind_map"
                })
            except Exception:
                stat_logger.error(
                    "Mind map generation error: " + traceback.format_exc())

        vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
        assert len(cks) == len(vects)
        for i, d in enumerate(cks):
            v = vects[i]
            d["q_%d_vec" % len(v)] = v
        for b in range(0, len(cks), es_bulk_size):
            ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)

        DocumentService.increment_chunk_num(
            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

    return [d["id"] for d, _ in files]
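
# Usage sketch for doc_upload_and_parse (hypothetical IDs and file objects;
# file_objs must be whatever FileService.upload_document accepts, e.g. the
# uploads from an HTTP request in the web layer):
#
#   doc_ids = doc_upload_and_parse(
#       conversation_id="conv0", file_objs=uploaded_files, user_id=user.id)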