
document_service.py 24KB

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import xxhash
import json
import random
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO

import trio
from peewee import fn

from api.db.db_utils import bulk_insert_into_db
from api import settings
from api.utils import current_timestamp, get_format_time, get_uuid
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.settings import SVR_QUEUE_NAME
from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search, rag_tokenizer
from api.db import FileType, TaskStatus, ParserType, LLMType
from api.db.db_models import DB, Knowledgebase, Tenant, Task, UserTenant
from api.db.db_models import Document
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db import StatusEnum
from rag.utils.redis_conn import REDIS_CONN

class DocumentService(CommonService):
    """CRUD and bookkeeping helpers for Document rows: chunk/token counters,
    parsing progress, and access checks against the owning knowledge base."""
    model = Document

    @classmethod
    @DB.connection_context()
    def get_list(cls, kb_id, page_number, items_per_page,
                 orderby, desc, keywords, id, name):
        docs = cls.model.select().where(cls.model.kb_id == kb_id)
        if id:
            docs = docs.where(cls.model.id == id)
        if name:
            docs = docs.where(cls.model.name == name)
        if keywords:
            docs = docs.where(
                fn.LOWER(cls.model.name).contains(keywords.lower())
            )
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        count = docs.count()
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
                     orderby, desc, keywords):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == kb_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)
        count = docs.count()
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def insert(cls, doc):
        if not cls.save(**doc):
            raise RuntimeError("Database error (Document)!")
        e, kb = KnowledgebaseService.get_by_id(doc["kb_id"])
        if not KnowledgebaseService.update_by_id(
                kb.id, {"doc_num": kb.doc_num + 1}):
            raise RuntimeError("Database error (Knowledgebase)!")
        return Document(**doc)

    @classmethod
    @DB.connection_context()
    def remove_document(cls, doc, tenant_id):
        cls.clear_chunk_num(doc.id)
        try:
            settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
            settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
                                         {"remove": {"source_id": doc.id}},
                                         search.index_name(tenant_id), doc.kb_id)
            settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
                                         {"removed_kwd": "Y"},
                                         search.index_name(tenant_id), doc.kb_id)
            settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}},
                                         search.index_name(tenant_id), doc.kb_id)
        except Exception:
            pass
        return cls.delete_by_id(doc.id)

    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls):
        fields = [
            cls.model.id,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.name,
            cls.model.type,
            cls.model.location,
            cls.model.size,
            Knowledgebase.tenant_id,
            Tenant.embd_id,
            Tenant.img2txt_id,
            Tenant.asr_id,
            cls.model.update_time]
        docs = cls.model.select(*fields) \
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress == 0,
                cls.model.update_time >= current_timestamp() - 1000 * 600,
                cls.model.run == TaskStatus.RUNNING.value) \
            .order_by(cls.model.update_time.asc())
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def get_unfinished_docs(cls):
        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
                  cls.model.run, cls.model.parser_id]
        docs = cls.model.select(*fields) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress < 1,
                cls.model.progress > 0)
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
        num = cls.model.update(
            token_num=cls.model.token_num + token_num,
            chunk_num=cls.model.chunk_num + chunk_num,
            process_duation=cls.model.process_duation + duation
        ).where(cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num + token_num,
            chunk_num=Knowledgebase.chunk_num + chunk_num
        ).where(Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
        num = cls.model.update(
            token_num=cls.model.token_num - token_num,
            chunk_num=cls.model.chunk_num - chunk_num,
            process_duation=cls.model.process_duation + duation
        ).where(cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - token_num,
            chunk_num=Knowledgebase.chunk_num - chunk_num
        ).where(Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def clear_chunk_num(cls, doc_id):
        doc = cls.model.get_by_id(doc_id)
        assert doc, "Can't find document in database."
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - doc.token_num,
            chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
            doc_num=Knowledgebase.doc_num - 1
        ).where(Knowledgebase.id == doc.kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        docs = cls.model.select(Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_knowledgebase_id(cls, doc_id):
        docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["kb_id"]

    @classmethod
    @DB.connection_context()
    def get_tenant_id_by_name(cls, name):
        docs = cls.model.select(Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).where(cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def accessible(cls, doc_id, user_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).join(
            UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
        ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
        docs = docs.dicts()
        if not docs:
            return False
        return True

    @classmethod
    @DB.connection_context()
    def accessible4deletion(cls, doc_id, user_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
        docs = docs.dicts()
        if not docs:
            return False
        return True

    @classmethod
    @DB.connection_context()
    def get_embd_id(cls, doc_id):
        docs = cls.model.select(Knowledgebase.embd_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["embd_id"]

    @classmethod
    @DB.connection_context()
    def get_chunking_config(cls, doc_id):
        configs = (
            cls.model.select(
                cls.model.id,
                cls.model.kb_id,
                cls.model.parser_id,
                cls.model.parser_config,
                Knowledgebase.language,
                Knowledgebase.embd_id,
                Tenant.id.alias("tenant_id"),
                Tenant.img2txt_id,
                Tenant.asr_id,
                Tenant.llm_id,
            )
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
            .where(cls.model.id == doc_id)
        )
        configs = configs.dicts()
        if not configs:
            return None
        return configs[0]

    @classmethod
    @DB.connection_context()
    def get_doc_id_by_doc_name(cls, doc_name):
        fields = [cls.model.id]
        doc_id = cls.model.select(*fields) \
            .where(cls.model.name == doc_name)
        doc_id = doc_id.dicts()
        if not doc_id:
            return
        return doc_id[0]["id"]

    @classmethod
    @DB.connection_context()
    def get_thumbnails(cls, docids):
        fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
        return list(cls.model.select(
            *fields).where(cls.model.id.in_(docids)).dicts())

    @classmethod
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        e, d = cls.get_by_id(id)
        if not e:
            raise LookupError(f"Document({id}) not found.")

        def dfs_update(old, new):
            for k, v in new.items():
                if k not in old:
                    old[k] = v
                    continue
                if isinstance(v, dict):
                    assert isinstance(old[k], dict)
                    dfs_update(old[k], v)
                else:
                    old[k] = v

        dfs_update(d.parser_config, config)
        if not config.get("raptor") and d.parser_config.get("raptor"):
            del d.parser_config["raptor"]
        cls.update_by_id(id, {"parser_config": d.parser_config})

    @classmethod
    @DB.connection_context()
    def get_doc_count(cls, tenant_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).where(Knowledgebase.tenant_id == tenant_id)
        return len(docs)

    @classmethod
    @DB.connection_context()
    def begin2parse(cls, docid):
        cls.update_by_id(
            docid, {"progress": random.random() * 1 / 100.,
                    "progress_msg": "Task is queued...",
                    "process_begin_at": get_format_time()})

    @classmethod
    @DB.connection_context()
    def update_meta_fields(cls, doc_id, meta_fields):
        return cls.update_by_id(doc_id, {"meta_fields": meta_fields})

    @classmethod
    @DB.connection_context()
    def update_progress(cls):
        """Roll per-task progress up into each unfinished document and, once the
        base chunking tasks finish, queue follow-up RAPTOR/GraphRAG tasks."""
        MSG = {
            "raptor": "Start RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval).",
            "graphrag": "Entities",
            "graph_resolution": "Resolution",
            "graph_community": "Communities"
        }
        docs = cls.get_unfinished_docs()
        for d in docs:
            try:
                tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
                if not tsks:
                    continue
                msg = []
                prg = 0
                finished = True
                bad = 0
                e, doc = DocumentService.get_by_id(d["id"])
                status = doc.run  # TaskStatus.RUNNING.value
                for t in tsks:
                    if 0 <= t.progress < 1:
                        finished = False
                    prg += t.progress if t.progress >= 0 else 0
                    if t.progress_msg not in msg:
                        msg.append(t.progress_msg)
                    if t.progress == -1:
                        bad += 1
                prg /= len(tsks)
                if finished and bad:
                    prg = -1
                    status = TaskStatus.FAIL.value
                elif finished:
                    m = "\n".join(sorted(msg))
                    if d["parser_config"].get("raptor", {}).get("use_raptor") and m.find(MSG["raptor"]) < 0:
                        queue_raptor_o_graphrag_tasks(d, "raptor", MSG["raptor"])
                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                    elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and m.find(MSG["graphrag"]) < 0:
                        queue_raptor_o_graphrag_tasks(d, "graphrag", MSG["graphrag"])
                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                    elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
                            and d["parser_config"].get("graphrag", {}).get("resolution") \
                            and m.find(MSG["graph_resolution"]) < 0:
                        queue_raptor_o_graphrag_tasks(d, "graph_resolution", MSG["graph_resolution"])
                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                    elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
                            and d["parser_config"].get("graphrag", {}).get("community") \
                            and m.find(MSG["graph_community"]) < 0:
                        queue_raptor_o_graphrag_tasks(d, "graph_community", MSG["graph_community"])
                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                    else:
                        status = TaskStatus.DONE.value

                msg = "\n".join(sorted(msg))
                info = {
                    "process_duation": datetime.timestamp(datetime.now()) - d["process_begin_at"].timestamp(),
                    "run": status}
                if prg != 0:
                    info["progress"] = prg
                if msg:
                    info["progress_msg"] = msg
                cls.update_by_id(d["id"], info)
            except Exception as e:
                if str(e).find("'0'") < 0:
                    logging.exception("fetch task exception")

    @classmethod
    @DB.connection_context()
    def get_kb_doc_count(cls, kb_id):
        return len(cls.model.select(cls.model.id).where(
            cls.model.kb_id == kb_id).dicts())

    @classmethod
    @DB.connection_context()
    def do_cancel(cls, doc_id):
        try:
            _, doc = DocumentService.get_by_id(doc_id)
            return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
        except Exception:
            pass
        return False

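# Illustrative usage sketch, not part of the original file: how a caller might
# page through a knowledge base with DocumentService.get_list(). It assumes an
# active DB connection context and an existing knowledge base; the kb_id value
# is a placeholder, and "update_time" is a Document column used elsewhere in
# this module.
def _example_list_documents(kb_id):
    docs, total = DocumentService.get_list(
        kb_id, page_number=1, items_per_page=15,
        orderby="update_time", desc=True, keywords="", id=None, name=None)
    return docs, total
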
def queue_raptor_o_graphrag_tasks(doc, ty, msg):
    """Create a follow-up task (RAPTOR, GraphRAG, entity resolution or community
    report) for the document, persist it, and push it onto the Redis queue."""
    chunking_config = DocumentService.get_chunking_config(doc["id"])
    hasher = xxhash.xxh64()
    for field in sorted(chunking_config.keys()):
        hasher.update(str(chunking_config[field]).encode("utf-8"))

    def new_task():
        nonlocal doc
        return {
            "id": get_uuid(),
            "doc_id": doc["id"],
            "from_page": 100000000,
            "to_page": 100000000,
            "progress_msg": datetime.now().strftime("%H:%M:%S") + " " + msg
        }

    task = new_task()
    for field in ["doc_id", "from_page", "to_page"]:
        hasher.update(str(task.get(field, "")).encode("utf-8"))
    hasher.update(ty.encode("utf-8"))
    task["digest"] = hasher.hexdigest()
    bulk_insert_into_db(Task, [task], True)
    task["task_type"] = ty
    assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."

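# Illustrative sketch, not part of the original file: a tiny polling loop of the
# kind a scheduler might run to drive progress aggregation via
# DocumentService.update_progress(). The 6-second interval is an assumption,
# not a value taken from this module.
def _example_progress_loop(poll_seconds=6):
    import time
    while True:
        DocumentService.update_progress()
        time.sleep(poll_seconds)
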
def doc_upload_and_parse(conversation_id, file_objs, user_id):
    """Upload files into the conversation's first knowledge base, chunk and
    embed them, add a mind-map chunk for non-picture documents, and index the
    chunks into the document store. Returns the new document ids."""
    from rag.app import presentation, picture, naive, audio, email
    from api.db.services.dialog_service import DialogService
    from api.db.services.file_service import FileService
    from api.db.services.llm_service import LLMBundle
    from api.db.services.user_service import TenantService
    from api.db.services.api_service import API4ConversationService
    from api.db.services.conversation_service import ConversationService

    e, conv = ConversationService.get_by_id(conversation_id)
    if not e:
        e, conv = API4ConversationService.get_by_id(conversation_id)
    assert e, "Conversation not found!"

    e, dia = DialogService.get_by_id(conv.dialog_id)
    if not dia.kb_ids:
        raise LookupError("No knowledge base associated with this conversation. "
                          "Please add a knowledge base before uploading documents")
    kb_id = dia.kb_ids[0]
    e, kb = KnowledgebaseService.get_by_id(kb_id)
    if not e:
        raise LookupError("Can't find this knowledgebase!")

    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)

    err, files = FileService.upload_document(kb, file_objs, user_id)
    assert not err, "\n".join(err)

    def dummy(prog=None, msg=""):
        pass

    FACTORY = {
        ParserType.PRESENTATION.value: presentation,
        ParserType.PICTURE.value: picture,
        ParserType.AUDIO.value: audio,
        ParserType.EMAIL.value: email
    }
    parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
    exe = ThreadPoolExecutor(max_workers=12)
    threads = []
    doc_nm = {}
    for d, blob in files:
        doc_nm[d["id"]] = d["name"]
    for d, blob in files:
        kwargs = {
            "callback": dummy,
            "parser_config": parser_config,
            "from_page": 0,
            "to_page": 100000,
            "tenant_id": kb.tenant_id,
            "lang": kb.language
        }
        threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))

    for (docinfo, _), th in zip(files, threads):
        docs = []
        doc = {
            "doc_id": docinfo["id"],
            "kb_id": [kb.id]
        }
        for ck in th.result():
            d = deepcopy(doc)
            d.update(ck)
            d["id"] = xxhash.xxh64((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest()
            d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
            d["create_timestamp_flt"] = datetime.now().timestamp()
            if not d.get("image"):
                docs.append(d)
                continue

            output_buffer = BytesIO()
            if isinstance(d["image"], bytes):
                output_buffer = BytesIO(d["image"])
            else:
                d["image"].save(output_buffer, format='JPEG')

            STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
            d["img_id"] = "{}-{}".format(kb.id, d["id"])
            d.pop("image", None)
            docs.append(d)

    parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
    docids = [d["id"] for d, _ in files]
    chunk_counts = {id: 0 for id in docids}
    token_counts = {id: 0 for id in docids}
    es_bulk_size = 64

    def embedding(doc_id, cnts, batch_size=16):
        nonlocal embd_mdl, chunk_counts, token_counts
        vects = []
        for i in range(0, len(cnts), batch_size):
            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
            token_counts[doc_id] += c
        return vects

    idxnm = search.index_name(kb.tenant_id)
    try_create_idx = True

    _, tenant = TenantService.get_by_id(kb.tenant_id)
    llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
    for doc_id in docids:
        cks = [c for c in docs if c["doc_id"] == doc_id]

        if parser_ids[doc_id] != ParserType.PICTURE.value:
            mindmap = MindMapExtractor(llm_bdl)
            try:
                mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
                mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
                if len(mind_map) < 32:
                    raise Exception("Few content: " + mind_map)
                cks.append({
                    "id": get_uuid(),
                    "doc_id": doc_id,
                    "kb_id": [kb.id],
                    "docnm_kwd": doc_nm[doc_id],
                    "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
                    "content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
                    "content_with_weight": mind_map,
                    "knowledge_graph_kwd": "mind_map"
                })
            except Exception as e:
                logging.exception("Mind map generation error")

        vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
        assert len(cks) == len(vects)
        for i, d in enumerate(cks):
            v = vects[i]
            d["q_%d_vec" % len(v)] = v
        for b in range(0, len(cks), es_bulk_size):
            if try_create_idx:
                if not settings.docStoreConn.indexExist(idxnm, kb_id):
                    settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
                try_create_idx = False
            settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)

        DocumentService.increment_chunk_num(
            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

    return [d["id"] for d, _ in files]