
document_service.py 23KB

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import random
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO

import trio
import xxhash
from peewee import fn

from api import settings
from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus, UserTenantRole
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils import current_timestamp, get_format_time, get_uuid
from rag.nlp import rag_tokenizer, search
from rag.settings import get_svr_queue_name
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL


class DocumentService(CommonService):
    model = Document

    @classmethod
    @DB.connection_context()
    def get_list(cls, kb_id, page_number, items_per_page,
                 orderby, desc, keywords, id, name):
        docs = cls.model.select().where(cls.model.kb_id == kb_id)
        if id:
            docs = docs.where(cls.model.id == id)
        if name:
            docs = docs.where(cls.model.name == name)
        if keywords:
            docs = docs.where(
                fn.LOWER(cls.model.name).contains(keywords.lower())
            )
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        count = docs.count()
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
                     orderby, desc, keywords):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == kb_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)
        count = docs.count()
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def insert(cls, doc):
        if not cls.save(**doc):
            raise RuntimeError("Database error (Document)!")
        if not KnowledgebaseService.atomic_increase_doc_num_by_id(doc["kb_id"]):
            raise RuntimeError("Database error (Knowledgebase)!")
        return Document(**doc)
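
    # remove_document: purge the document's chunks from the doc store, detach the
    # document from knowledge-graph entries that list it as a source, mark the stored
    # graph with removed_kwd="Y", drop orphaned graph entries, then delete the row.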

    @classmethod
    @DB.connection_context()
    def remove_document(cls, doc, tenant_id):
        cls.clear_chunk_num(doc.id)
        try:
            settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
            settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
                                         {"remove": {"source_id": doc.id}},
                                         search.index_name(tenant_id), doc.kb_id)
            settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
                                         {"removed_kwd": "Y"},
                                         search.index_name(tenant_id), doc.kb_id)
            settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
                                         search.index_name(tenant_id), doc.kb_id)
        except Exception:
            pass
        return cls.delete_by_id(doc.id)

    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls):
        fields = [
            cls.model.id,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.name,
            cls.model.type,
            cls.model.location,
            cls.model.size,
            Knowledgebase.tenant_id,
            Tenant.embd_id,
            Tenant.img2txt_id,
            Tenant.asr_id,
            cls.model.update_time]
        docs = cls.model.select(*fields) \
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress == 0,
                cls.model.update_time >= current_timestamp() - 1000 * 600,
                cls.model.run == TaskStatus.RUNNING.value) \
            .order_by(cls.model.update_time.asc())
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def get_unfinished_docs(cls):
        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
                  cls.model.run, cls.model.parser_id]
        docs = cls.model.select(*fields) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress < 1,
                cls.model.progress > 0)
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
        num = cls.model.update(token_num=cls.model.token_num + token_num,
                               chunk_num=cls.model.chunk_num + chunk_num,
                               process_duation=cls.model.process_duation + duation).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num + token_num,
            chunk_num=Knowledgebase.chunk_num + chunk_num).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
        num = cls.model.update(token_num=cls.model.token_num - token_num,
                               chunk_num=cls.model.chunk_num - chunk_num,
                               process_duation=cls.model.process_duation + duation).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - token_num,
            chunk_num=Knowledgebase.chunk_num - chunk_num
        ).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def clear_chunk_num(cls, doc_id):
        doc = cls.model.get_by_id(doc_id)
        assert doc, "Can't find document in database."
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - doc.token_num,
            chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
            doc_num=Knowledgebase.doc_num - 1
        ).where(
            Knowledgebase.id == doc.kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_knowledgebase_id(cls, doc_id):
        docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["kb_id"]

    @classmethod
    @DB.connection_context()
    def get_tenant_id_by_name(cls, name):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def accessible(cls, doc_id, user_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).join(
            UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
        ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
        docs = docs.dicts()
        if not docs:
            return False
        return True

    @classmethod
    @DB.connection_context()
    def accessible4deletion(cls, doc_id, user_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).join(
            UserTenant, on=(
                (UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id))
        ).where(
            cls.model.id == doc_id,
            UserTenant.status == StatusEnum.VALID.value,
            ((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER))
        ).paginate(0, 1)
        docs = docs.dicts()
        if not docs:
            return False
        return True

    @classmethod
    @DB.connection_context()
    def get_embd_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.embd_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["embd_id"]

    @classmethod
    @DB.connection_context()
    def get_chunking_config(cls, doc_id):
        configs = (
            cls.model.select(
                cls.model.id,
                cls.model.kb_id,
                cls.model.parser_id,
                cls.model.parser_config,
                Knowledgebase.language,
                Knowledgebase.embd_id,
                Tenant.id.alias("tenant_id"),
                Tenant.img2txt_id,
                Tenant.asr_id,
                Tenant.llm_id,
            )
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
            .where(cls.model.id == doc_id)
        )
        configs = configs.dicts()
        if not configs:
            return None
        return configs[0]

    @classmethod
    @DB.connection_context()
    def get_doc_id_by_doc_name(cls, doc_name):
        fields = [cls.model.id]
        doc_id = cls.model.select(*fields) \
            .where(cls.model.name == doc_name)
        doc_id = doc_id.dicts()
        if not doc_id:
            return
        return doc_id[0]["id"]

    @classmethod
    @DB.connection_context()
    def get_thumbnails(cls, docids):
        fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
        return list(cls.model.select(
            *fields).where(cls.model.id.in_(docids)).dicts())

    @classmethod
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        if not config:
            return
        e, d = cls.get_by_id(id)
        if not e:
            raise LookupError(f"Document({id}) not found.")

        def dfs_update(old, new):
            for k, v in new.items():
                if k not in old:
                    old[k] = v
                    continue
                if isinstance(v, dict):
                    assert isinstance(old[k], dict)
                    dfs_update(old[k], v)
                else:
                    old[k] = v

        dfs_update(d.parser_config, config)
        if not config.get("raptor") and d.parser_config.get("raptor"):
            del d.parser_config["raptor"]
        cls.update_by_id(id, {"parser_config": d.parser_config})

    @classmethod
    @DB.connection_context()
    def get_doc_count(cls, tenant_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            Knowledgebase.tenant_id == tenant_id)
        return len(docs)

    @classmethod
    @DB.connection_context()
    def begin2parse(cls, docid):
        cls.update_by_id(
            docid, {"progress": random.random() * 1 / 100.,
                    "progress_msg": "Task is queued...",
                    "process_begin_at": get_format_time()
                    })

    @classmethod
    @DB.connection_context()
    def update_meta_fields(cls, doc_id, meta_fields):
        return cls.update_by_id(doc_id, {"meta_fields": meta_fields})
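
    # update_progress: for every unfinished document, average the progress of its
    # tasks; any failed task marks the whole document as FAIL. Once all tasks finish,
    # RAPTOR / GraphRAG follow-up tasks are queued (at most once each) if enabled in
    # parser_config, otherwise the document is marked DONE.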

    @classmethod
    @DB.connection_context()
    def update_progress(cls):
        docs = cls.get_unfinished_docs()
        for d in docs:
            try:
                tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
                if not tsks:
                    continue
                msg = []
                prg = 0
                finished = True
                bad = 0
                has_raptor = False
                has_graphrag = False
                e, doc = DocumentService.get_by_id(d["id"])
                status = doc.run  # TaskStatus.RUNNING.value
                priority = 0
                for t in tsks:
                    if 0 <= t.progress < 1:
                        finished = False
                    if t.progress == -1:
                        bad += 1
                    prg += t.progress if t.progress >= 0 else 0
                    msg.append(t.progress_msg)
                    if t.task_type == "raptor":
                        has_raptor = True
                    elif t.task_type == "graphrag":
                        has_graphrag = True
                    priority = max(priority, t.priority)
                prg /= len(tsks)
                if finished and bad:
                    prg = -1
                    status = TaskStatus.FAIL.value
                elif finished:
                    if d["parser_config"].get("raptor", {}).get("use_raptor") and not has_raptor:
                        queue_raptor_o_graphrag_tasks(d, "raptor", priority)
                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                    elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and not has_graphrag:
                        queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                    else:
                        status = TaskStatus.DONE.value

                msg = "\n".join(sorted(msg))
                info = {
                    "process_duation": datetime.timestamp(datetime.now()) - d["process_begin_at"].timestamp(),
                    "run": status}
                if prg != 0:
                    info["progress"] = prg
                if msg:
                    info["progress_msg"] = msg
                cls.update_by_id(d["id"], info)
            except Exception as e:
                if str(e).find("'0'") < 0:
                    logging.exception("fetch task exception")

    @classmethod
    @DB.connection_context()
    def get_kb_doc_count(cls, kb_id):
        return len(cls.model.select(cls.model.id).where(
            cls.model.kb_id == kb_id).dicts())

    @classmethod
    @DB.connection_context()
    def do_cancel(cls, doc_id):
        try:
            _, doc = DocumentService.get_by_id(doc_id)
            return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
        except Exception:
            pass
        return False
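

# queue_raptor_o_graphrag_tasks: build a follow-up RAPTOR or GraphRAG task for a
# document, fingerprint it with xxhash over the chunking config and key task fields
# (so duplicate tasks can be recognized), persist it, and push it onto the Redis
# task queue at the given priority.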

def queue_raptor_o_graphrag_tasks(doc, ty, priority):
    chunking_config = DocumentService.get_chunking_config(doc["id"])
    hasher = xxhash.xxh64()
    for field in sorted(chunking_config.keys()):
        hasher.update(str(chunking_config[field]).encode("utf-8"))

    def new_task():
        nonlocal doc
        return {
            "id": get_uuid(),
            "doc_id": doc["id"],
            "from_page": 100000000,
            "to_page": 100000000,
            "task_type": ty,
            "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
        }

    task = new_task()
    for field in ["doc_id", "from_page", "to_page"]:
        hasher.update(str(task.get(field, "")).encode("utf-8"))
    hasher.update(ty.encode("utf-8"))
    task["digest"] = hasher.hexdigest()
    bulk_insert_into_db(Task, [task], True)
    assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
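

# doc_upload_and_parse: upload files into the knowledge base bound to a conversation,
# chunk each file in a thread pool using the parser that matches its type, optionally
# add a mind-map summary chunk per document, embed all chunks, bulk-insert them into
# the doc store, and update the per-document chunk/token counters.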

def doc_upload_and_parse(conversation_id, file_objs, user_id):
    from api.db.services.api_service import API4ConversationService
    from api.db.services.conversation_service import ConversationService
    from api.db.services.dialog_service import DialogService
    from api.db.services.file_service import FileService
    from api.db.services.llm_service import LLMBundle
    from api.db.services.user_service import TenantService
    from rag.app import audio, email, naive, picture, presentation

    e, conv = ConversationService.get_by_id(conversation_id)
    if not e:
        e, conv = API4ConversationService.get_by_id(conversation_id)
    assert e, "Conversation not found!"

    e, dia = DialogService.get_by_id(conv.dialog_id)
    if not dia.kb_ids:
        raise LookupError("No knowledge base associated with this conversation. "
                          "Please add a knowledge base before uploading documents")
    kb_id = dia.kb_ids[0]
    e, kb = KnowledgebaseService.get_by_id(kb_id)
    if not e:
        raise LookupError("Can't find this knowledgebase!")

    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)

    err, files = FileService.upload_document(kb, file_objs, user_id)
    assert not err, "\n".join(err)

    def dummy(prog=None, msg=""):
        pass

    FACTORY = {
        ParserType.PRESENTATION.value: presentation,
        ParserType.PICTURE.value: picture,
        ParserType.AUDIO.value: audio,
        ParserType.EMAIL.value: email
    }
    parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
    exe = ThreadPoolExecutor(max_workers=12)
    threads = []
    doc_nm = {}
    for d, blob in files:
        doc_nm[d["id"]] = d["name"]
    for d, blob in files:
        kwargs = {
            "callback": dummy,
            "parser_config": parser_config,
            "from_page": 0,
            "to_page": 100000,
            "tenant_id": kb.tenant_id,
            "lang": kb.language
        }
        threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))

    for (docinfo, _), th in zip(files, threads):
        docs = []
        doc = {
            "doc_id": docinfo["id"],
            "kb_id": [kb.id]
        }
        for ck in th.result():
            d = deepcopy(doc)
            d.update(ck)
            d["id"] = xxhash.xxh64((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest()
            d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
            d["create_timestamp_flt"] = datetime.now().timestamp()
            if not d.get("image"):
                docs.append(d)
                continue

            output_buffer = BytesIO()
            if isinstance(d["image"], bytes):
                output_buffer = BytesIO(d["image"])
            else:
                d["image"].save(output_buffer, format='JPEG')

            STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
            d["img_id"] = "{}-{}".format(kb.id, d["id"])
            d.pop("image", None)
            docs.append(d)

    parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
    docids = [d["id"] for d, _ in files]
    chunk_counts = {id: 0 for id in docids}
    token_counts = {id: 0 for id in docids}
    es_bulk_size = 64

    def embedding(doc_id, cnts, batch_size=16):
        nonlocal embd_mdl, chunk_counts, token_counts
        vects = []
        for i in range(0, len(cnts), batch_size):
            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
            token_counts[doc_id] += c
        return vects

    idxnm = search.index_name(kb.tenant_id)
    try_create_idx = True

    _, tenant = TenantService.get_by_id(kb.tenant_id)
    llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
    for doc_id in docids:
        cks = [c for c in docs if c["doc_id"] == doc_id]

        if parser_ids[doc_id] != ParserType.PICTURE.value:
            from graphrag.general.mind_map_extractor import MindMapExtractor
            mindmap = MindMapExtractor(llm_bdl)
            try:
                mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
                mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
                if len(mind_map) < 32:
                    raise Exception("Few content: " + mind_map)
                cks.append({
                    "id": get_uuid(),
                    "doc_id": doc_id,
                    "kb_id": [kb.id],
                    "docnm_kwd": doc_nm[doc_id],
                    "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
                    "content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
                    "content_with_weight": mind_map,
                    "knowledge_graph_kwd": "mind_map"
                })
            except Exception as e:
                logging.exception("Mind map generation error")

        vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
        assert len(cks) == len(vects)
        for i, d in enumerate(cks):
            v = vects[i]
            d["q_%d_vec" % len(v)] = v
        for b in range(0, len(cks), es_bulk_size):
            if try_create_idx:
                if not settings.docStoreConn.indexExist(idxnm, kb_id):
                    settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
                try_create_idx = False
            settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)

        DocumentService.increment_chunk_num(
            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

    return [d["id"] for d, _ in files]
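

# Illustrative usage sketch (kept as a comment). Assumes an initialized application
# context and an existing knowledge base; the kb_id and pagination values below are
# hypothetical placeholders.
#
#     docs, total = DocumentService.get_by_kb_id(
#         kb_id="<kb-id>", page_number=1, items_per_page=20,
#         orderby="create_time", desc=True, keywords="")
#     print(total, [d["name"] for d in docs])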