
# document_service.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import hashlib
import json
import random
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO

from peewee import fn

from api.db.db_utils import bulk_insert_into_db
from api import settings
from api.utils import current_timestamp, get_format_time, get_uuid
from graphrag.mind_map_extractor import MindMapExtractor
from rag.settings import SVR_QUEUE_NAME
from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search, rag_tokenizer
from api.db import FileType, TaskStatus, ParserType, LLMType
from api.db.db_models import DB, Knowledgebase, Tenant, Task, UserTenant
from api.db.db_models import Document
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db import StatusEnum
from rag.utils.redis_conn import REDIS_CONN


class DocumentService(CommonService):
    model = Document

    @classmethod
    @DB.connection_context()
    def get_list(cls, kb_id, page_number, items_per_page,
                 orderby, desc, keywords, id, name):
        docs = cls.model.select().where(cls.model.kb_id == kb_id)
        if id:
            docs = docs.where(cls.model.id == id)
        if name:
            docs = docs.where(cls.model.name == name)
        if keywords:
            docs = docs.where(
                fn.LOWER(cls.model.name).contains(keywords.lower()))
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        # Count before paginating; counting afterwards would return at most
        # one page's worth of rows instead of the total number of matches.
        count = docs.count()
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count
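
    # Illustrative usage sketch (the kb id and field values below are
    # hypothetical placeholders, not taken from this codebase):
    #
    #     docs, total = DocumentService.get_list(
    #         kb_id="kb-123", page_number=1, items_per_page=20,
    #         orderby="create_time", desc=True, keywords="",
    #         id=None, name=None)
    #     # `total` counts every match; `docs` holds only the requested page.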

    @classmethod
    @DB.connection_context()
    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
                     orderby, desc, keywords):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == kb_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)
        count = docs.count()
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def insert(cls, doc):
        if not cls.save(**doc):
            raise RuntimeError("Database error (Document)!")
        e, doc = cls.get_by_id(doc["id"])
        if not e:
            raise RuntimeError("Database error (Document retrieval)!")
        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        if not KnowledgebaseService.update_by_id(
                kb.id, {"doc_num": kb.doc_num + 1}):
            raise RuntimeError("Database error (Knowledgebase)!")
        return doc
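
    # A minimal insertion sketch (assuming the dict carries the columns the
    # Document model requires; the names and values here are placeholders):
    #
    #     doc = DocumentService.insert({
    #         "id": get_uuid(),
    #         "kb_id": kb.id,
    #         "parser_id": kb.parser_id,
    #         "parser_config": kb.parser_config,
    #         "name": "report.pdf",
    #         "location": "report.pdf",
    #         "size": 1024,
    #     })
    #     # On success the knowledgebase's doc_num is bumped by one.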

    @classmethod
    @DB.connection_context()
    def remove_document(cls, doc, tenant_id):
        # Drop the chunks from the doc store, roll back the knowledgebase
        # counters, then delete the document row itself.
        settings.docStoreConn.delete(
            {"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
        cls.clear_chunk_num(doc.id)
        return cls.delete_by_id(doc.id)

    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls):
        fields = [
            cls.model.id,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.name,
            cls.model.type,
            cls.model.location,
            cls.model.size,
            Knowledgebase.tenant_id,
            Tenant.embd_id,
            Tenant.img2txt_id,
            Tenant.asr_id,
            cls.model.update_time]
        docs = cls.model.select(*fields) \
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress == 0,
                # Only documents touched within the last ten minutes.
                cls.model.update_time >= current_timestamp() - 1000 * 600,
                cls.model.run == TaskStatus.RUNNING.value) \
            .order_by(cls.model.update_time.asc())
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def get_unfinished_docs(cls):
        fields = [cls.model.id, cls.model.process_begin_at,
                  cls.model.parser_config, cls.model.progress_msg,
                  cls.model.run]
        docs = cls.model.select(*fields) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress < 1,
                cls.model.progress > 0)
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
        # Note: the Document model spells this column "process_duation";
        # the field name is kept to match the database schema.
        num = cls.model.update(
            token_num=cls.model.token_num + token_num,
            chunk_num=cls.model.chunk_num + chunk_num,
            process_duation=cls.model.process_duation + duration).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num + token_num,
            chunk_num=Knowledgebase.chunk_num + chunk_num).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
        num = cls.model.update(
            token_num=cls.model.token_num - token_num,
            chunk_num=cls.model.chunk_num - chunk_num,
            process_duation=cls.model.process_duation + duration).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - token_num,
            chunk_num=Knowledgebase.chunk_num - chunk_num
        ).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def clear_chunk_num(cls, doc_id):
        doc = cls.model.get_by_id(doc_id)
        assert doc, "Can't find document in database."
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - doc.token_num,
            chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
            doc_num=Knowledgebase.doc_num - 1
        ).where(
            Knowledgebase.id == doc.kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id,
            Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_knowledgebase_id(cls, doc_id):
        docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["kb_id"]

    @classmethod
    @DB.connection_context()
    def get_tenant_id_by_name(cls, name):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.name == name,
            Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def accessible(cls, doc_id, user_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).join(
            UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
        ).where(cls.model.id == doc_id,
                UserTenant.user_id == user_id).paginate(0, 1)
        docs = docs.dicts()
        return bool(docs)

    @classmethod
    @DB.connection_context()
    def accessible4deletion(cls, doc_id, user_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).where(cls.model.id == doc_id,
                Knowledgebase.created_by == user_id).paginate(0, 1)
        docs = docs.dicts()
        return bool(docs)

    @classmethod
    @DB.connection_context()
    def get_embd_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.embd_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id,
            Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["embd_id"]

    @classmethod
    @DB.connection_context()
    def get_doc_id_by_doc_name(cls, doc_name):
        fields = [cls.model.id]
        doc_id = cls.model.select(*fields) \
            .where(cls.model.name == doc_name)
        doc_id = doc_id.dicts()
        if not doc_id:
            return
        return doc_id[0]["id"]

    @classmethod
    @DB.connection_context()
    def get_thumbnails(cls, docids):
        fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
        return list(cls.model.select(
            *fields).where(cls.model.id.in_(docids)).dicts())

    @classmethod
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        e, d = cls.get_by_id(id)
        if not e:
            raise LookupError(f"Document({id}) not found.")

        def dfs_update(old, new):
            # Recursively merge `new` into `old`: nested dicts are merged
            # key by key, while scalar values in `new` overwrite `old`.
            for k, v in new.items():
                if k not in old:
                    old[k] = v
                    continue
                if isinstance(v, dict):
                    assert isinstance(old[k], dict)
                    dfs_update(old[k], v)
                else:
                    old[k] = v

        dfs_update(d.parser_config, config)
        cls.update_by_id(id, {"parser_config": d.parser_config})
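
    # dfs_update merges recursively rather than replacing wholesale, e.g.:
    #
    #     old = {"raptor": {"use_raptor": False}, "chunk_token_num": 128}
    #     dfs_update(old, {"raptor": {"use_raptor": True}})
    #     # old == {"raptor": {"use_raptor": True}, "chunk_token_num": 128}
    #
    # so updating one nested key leaves the document's other settings intact.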

    @classmethod
    @DB.connection_context()
    def get_doc_count(cls, tenant_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            Knowledgebase.tenant_id == tenant_id)
        return len(docs)

    @classmethod
    @DB.connection_context()
    def begin2parse(cls, docid):
        cls.update_by_id(
            docid, {
                # A tiny non-zero progress (< 1%) marks the document as started.
                "progress": random.random() * 1 / 100.,
                "progress_msg": "Task is queued...",
                "process_begin_at": get_format_time()
            })

    @classmethod
    @DB.connection_context()
    def update_progress(cls):
        docs = cls.get_unfinished_docs()
        for d in docs:
            try:
                tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
                if not tsks:
                    continue
                msg = []
                prg = 0
                finished = True
                bad = 0
                e, doc = DocumentService.get_by_id(d["id"])
                status = doc.run  # TaskStatus.RUNNING.value
                for t in tsks:
                    if 0 <= t.progress < 1:
                        finished = False
                    prg += t.progress if t.progress >= 0 else 0
                    if t.progress_msg not in msg:
                        msg.append(t.progress_msg)
                    if t.progress == -1:
                        bad += 1
                # Overall progress is the mean over the document's tasks.
                prg /= len(tsks)
                if finished and bad:
                    prg = -1
                    status = TaskStatus.FAIL.value
                elif finished:
                    if d["parser_config"].get("raptor", {}).get("use_raptor") \
                            and d["progress_msg"].lower().find(" raptor") < 0:
                        # All chunking tasks are done; queue a RAPTOR pass and
                        # hold progress just below 1 until it completes.
                        queue_raptor_tasks(d)
                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                        msg.append("------ RAPTOR -------")
                    else:
                        status = TaskStatus.DONE.value

                msg = "\n".join(msg)
                info = {
                    "process_duation": datetime.timestamp(datetime.now()) -
                        d["process_begin_at"].timestamp(),
                    "run": status}
                if prg != 0:
                    info["progress"] = prg
                if msg:
                    info["progress_msg"] = msg
                cls.update_by_id(d["id"], info)
            except Exception as e:
                if str(e).find("'0'") < 0:
                    logging.exception("fetch task exception")
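
    # Worked example of the aggregation above: three tasks at progress 1.0,
    # 0.5 and -1 give prg = (1.0 + 0.5 + 0) / 3 = 0.5, finished stays False
    # (0 <= 0.5 < 1), and bad = 1. Only once every task reaches 1 or -1 does
    # the document flip to DONE, or to FAIL with prg = -1 if any task failed.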

    @classmethod
    @DB.connection_context()
    def get_kb_doc_count(cls, kb_id):
        return len(cls.model.select(cls.model.id).where(
            cls.model.kb_id == kb_id).dicts())

    @classmethod
    @DB.connection_context()
    def do_cancel(cls, doc_id):
        try:
            _, doc = DocumentService.get_by_id(doc_id)
            return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
        except Exception:
            pass
        return False


def queue_raptor_tasks(doc):
    def new_task():
        nonlocal doc
        return {
            "id": get_uuid(),
            "doc_id": doc["id"],
            "from_page": 0,
            "to_page": -1,
            "progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval)."
        }

    task = new_task()
    bulk_insert_into_db(Task, [task], True)
    task["type"] = "raptor"
    assert REDIS_CONN.queue_product(
        SVR_QUEUE_NAME, message=task), \
        "Can't access Redis. Please check Redis status."
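
# The message pushed to SVR_QUEUE_NAME ends up shaped like this (the id is a
# fresh uuid; "type" is attached only after the row has been persisted):
#
#     {
#         "id": "<uuid>",
#         "doc_id": doc["id"],
#         "from_page": 0,
#         "to_page": -1,
#         "progress_msg": "Start to do RAPTOR (...)",
#         "type": "raptor",
#     }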


def doc_upload_and_parse(conversation_id, file_objs, user_id):
    from rag.app import presentation, picture, naive, audio, email
    from api.db.services.dialog_service import ConversationService, DialogService
    from api.db.services.file_service import FileService
    from api.db.services.llm_service import LLMBundle
    from api.db.services.user_service import TenantService
    from api.db.services.api_service import API4ConversationService

    e, conv = ConversationService.get_by_id(conversation_id)
    if not e:
        e, conv = API4ConversationService.get_by_id(conversation_id)
    assert e, "Conversation not found!"

    e, dia = DialogService.get_by_id(conv.dialog_id)
    kb_id = dia.kb_ids[0]
    e, kb = KnowledgebaseService.get_by_id(kb_id)
    if not e:
        raise LookupError("Can't find this knowledgebase!")

    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING,
                         llm_name=kb.embd_id, lang=kb.language)

    err, files = FileService.upload_document(kb, file_objs, user_id)
    assert not err, "\n".join(err)

    def dummy(prog=None, msg=""):
        pass

    FACTORY = {
        ParserType.PRESENTATION.value: presentation,
        ParserType.PICTURE.value: picture,
        ParserType.AUDIO.value: audio,
        ParserType.EMAIL.value: email
    }
    parser_config = {"chunk_token_num": 4096,
                     "delimiter": "\n!?;。;!?", "layout_recognize": False}
    exe = ThreadPoolExecutor(max_workers=12)
    threads = []
    doc_nm = {}
    for d, blob in files:
        doc_nm[d["id"]] = d["name"]
    for d, blob in files:
        kwargs = {
            "callback": dummy,
            "parser_config": parser_config,
            "from_page": 0,
            "to_page": 100000,
            "tenant_id": kb.tenant_id,
            "lang": kb.language
        }
        threads.append(
            exe.submit(FACTORY.get(d["parser_id"], naive).chunk,
                       d["name"], blob, **kwargs))

    # Accumulate chunks across every uploaded file; resetting this list inside
    # the loop would silently drop all but the last file's chunks downstream.
    docs = []
    for (docinfo, _), th in zip(files, threads):
        doc = {
            "doc_id": docinfo["id"],
            "kb_id": [kb.id]
        }
        for ck in th.result():
            d = deepcopy(doc)
            d.update(ck)
            # Deterministic chunk id: md5 of the content plus the doc id.
            md5 = hashlib.md5()
            md5.update((ck["content_with_weight"] +
                        str(d["doc_id"])).encode("utf-8"))
            d["id"] = md5.hexdigest()
            d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
            d["create_timestamp_flt"] = datetime.now().timestamp()
            if not d.get("image"):
                docs.append(d)
                continue

            output_buffer = BytesIO()
            if isinstance(d["image"], bytes):
                output_buffer = BytesIO(d["image"])
            else:
                d["image"].save(output_buffer, format='JPEG')

            STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
            d["img_id"] = "{}-{}".format(kb.id, d["id"])
            del d["image"]
            docs.append(d)

    parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
    docids = [d["id"] for d, _ in files]
    chunk_counts = {id: 0 for id in docids}
    token_counts = {id: 0 for id in docids}
    es_bulk_size = 64

    def embedding(doc_id, cnts, batch_size=16):
        # Encode the chunk contents in batches, tracking per-document chunk
        # and token counts as a side effect.
        nonlocal embd_mdl, chunk_counts, token_counts
        vects = []
        for i in range(0, len(cnts), batch_size):
            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
            token_counts[doc_id] += c
        return vects
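
    # Sketch of what `embedding` produces (hypothetical contents): given two
    # chunks, it returns one vector per chunk and, as a side effect, adds 2 to
    # chunk_counts[doc_id] and the reported token usage to token_counts[doc_id]:
    #
    #     vects = embedding(doc_id, ["chunk one", "chunk two"])
    #     assert len(vects) == 2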

    idxnm = search.index_name(kb.tenant_id)
    try_create_idx = True

    _, tenant = TenantService.get_by_id(kb.tenant_id)
    llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
    for doc_id in docids:
        cks = [c for c in docs if c["doc_id"] == doc_id]

        if parser_ids[doc_id] != ParserType.PICTURE.value:
            mindmap = MindMapExtractor(llm_bdl)
            try:
                mind_map = json.dumps(
                    mindmap([c["content_with_weight"] for c in docs
                             if c["doc_id"] == doc_id]).output,
                    ensure_ascii=False, indent=2)
                if len(mind_map) < 32:
                    raise Exception("Few content: " + mind_map)
                cks.append({
                    "id": get_uuid(),
                    "doc_id": doc_id,
                    "kb_id": [kb.id],
                    "docnm_kwd": doc_nm[doc_id],
                    "title_tks": rag_tokenizer.tokenize(
                        re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
                    "content_ltks": "",
                    "content_with_weight": mind_map,
                    "knowledge_graph_kwd": "mind_map"
                })
            except Exception:
                logging.exception("Mind map generation error")

        vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
        assert len(cks) == len(vects)
        for i, d in enumerate(cks):
            v = vects[i]
            # The vector field name encodes its dimension, e.g. "q_1024_vec".
            d["q_%d_vec" % len(v)] = v
        for b in range(0, len(cks), es_bulk_size):
            if try_create_idx:
                if not settings.docStoreConn.indexExist(idxnm, kb_id):
                    settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
                try_create_idx = False
            settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)

        DocumentService.increment_chunk_num(
            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

    return [d["id"] for d, _ in files]
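
# A minimal end-to-end sketch (assuming a Flask-style upload handler; the
# conversation id, file objects and user id below are placeholders):
#
#     doc_ids = doc_upload_and_parse(
#         conversation_id="conv-123",
#         file_objs=request.files.getlist("file"),
#         user_id=current_user.id)
#     # Returns the ids of the uploaded documents once chunking, optional
#     # mind-map extraction, embedding and indexing have all completed.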