
# document_service.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import random
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO

import trio
import xxhash
from peewee import fn

from api import settings
from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus, UserTenantRole
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils import current_timestamp, get_format_time, get_uuid
from rag.nlp import rag_tokenizer, search
from rag.settings import get_svr_queue_name
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.doc_store_conn import OrderByExpr

class DocumentService(CommonService):
    model = Document

    @classmethod
    @DB.connection_context()
    def get_list(cls, kb_id, page_number, items_per_page,
                 orderby, desc, keywords, id, name):
        docs = cls.model.select().where(cls.model.kb_id == kb_id)
        if id:
            docs = docs.where(cls.model.id == id)
        if name:
            docs = docs.where(cls.model.name == name)
        if keywords:
            docs = docs.where(
                fn.LOWER(cls.model.name).contains(keywords.lower())
            )
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        count = docs.count()
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count
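    # Usage sketch (hypothetical IDs/values, assuming a populated knowledge
    # base): fetch the second page of 20 documents, newest first.
    #
    #   docs, total = DocumentService.get_list(
    #       "kb_id_example", 2, 20, "create_time", True, "", None, None)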
    @classmethod
    @DB.connection_context()
    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
                     orderby, desc, keywords, run_status, types):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == kb_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)
        if run_status:
            docs = docs.where(cls.model.run.in_(run_status))
        if types:
            docs = docs.where(cls.model.type.in_(types))
        count = docs.count()
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        if page_number and items_per_page:
            docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def count_by_kb_id(cls, kb_id, keywords, run_status, types):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == kb_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)
        if run_status:
            docs = docs.where(cls.model.run.in_(run_status))
        if types:
            docs = docs.where(cls.model.type.in_(types))
        return docs.count()

    @classmethod
    @DB.connection_context()
    def insert(cls, doc):
        if not cls.save(**doc):
            raise RuntimeError("Database error (Document)!")
        if not KnowledgebaseService.atomic_increase_doc_num_by_id(doc["kb_id"]):
            raise RuntimeError("Database error (Knowledgebase)!")
        return Document(**doc)
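    # insert() writes the Document row and bumps the parent knowledge base's
    # doc_num in one pass; if either write fails it raises, so callers can
    # treat success as "both tables updated". A minimal sketch (field names
    # and values are hypothetical, not the full required set):
    #
    #   doc = DocumentService.insert({
    #       "id": get_uuid(), "kb_id": "kb_id_example", "name": "a.pdf",
    #       "parser_id": ParserType.NAIVE.value, "created_by": "user_id_example",
    #   })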
    @classmethod
    @DB.connection_context()
    def remove_document(cls, doc, tenant_id):
        cls.clear_chunk_num(doc.id)
        try:
            settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
            graph_source = settings.docStoreConn.getFields(
                settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
                                             [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]),
                ["source_id"]
            )
            if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
                settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
                                             {"remove": {"source_id": doc.id}},
                                             search.index_name(tenant_id), doc.kb_id)
                settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
                                             {"removed_kwd": "Y"},
                                             search.index_name(tenant_id), doc.kb_id)
                settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
                                             search.index_name(tenant_id), doc.kb_id)
        except Exception:
            pass
        return cls.delete_by_id(doc.id)
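    # remove_document() deletes in three layers: the doc-store chunks, any
    # knowledge-graph entries that cite this document as a source, and finally
    # the DB row itself. Doc-store failures are swallowed on purpose so a
    # missing or unreachable index never blocks deletion of the database record.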
    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls):
        fields = [
            cls.model.id,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.name,
            cls.model.type,
            cls.model.location,
            cls.model.size,
            Knowledgebase.tenant_id,
            Tenant.embd_id,
            Tenant.img2txt_id,
            Tenant.asr_id,
            cls.model.update_time]
        docs = cls.model.select(*fields) \
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress == 0,
                cls.model.update_time >= current_timestamp() - 1000 * 600,
                cls.model.run == TaskStatus.RUNNING.value) \
            .order_by(cls.model.update_time.asc())
        return list(docs.dicts())
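    # The `current_timestamp() - 1000 * 600` bound keeps only documents
    # updated within the last 10 minutes (timestamps are in milliseconds,
    # so 1000 ms * 600 = 600 s), filtering out stale uploads that were
    # queued but never picked up.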
    @classmethod
    @DB.connection_context()
    def get_unfinished_docs(cls):
        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
                  cls.model.run, cls.model.parser_id]
        docs = cls.model.select(*fields) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress < 1,
                cls.model.progress > 0)
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
        # NB: "process_duation" is the column's spelling in api.db.db_models.
        num = cls.model.update(token_num=cls.model.token_num + token_num,
                               chunk_num=cls.model.chunk_num + chunk_num,
                               process_duation=cls.model.process_duation + duration).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num + token_num,
            chunk_num=Knowledgebase.chunk_num + chunk_num).where(
            Knowledgebase.id == kb_id).execute()
        return num
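    # increment_chunk_num()/decrement_chunk_num() keep the per-document and
    # per-knowledge-base counters in step. A usage sketch (hypothetical
    # values), as called after indexing a batch of chunks:
    #
    #   DocumentService.increment_chunk_num(
    #       "doc_id_example", "kb_id_example", token_num=1024, chunk_num=8,
    #       duration=2.5)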
    @classmethod
    @DB.connection_context()
    def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
        num = cls.model.update(token_num=cls.model.token_num - token_num,
                               chunk_num=cls.model.chunk_num - chunk_num,
                               process_duation=cls.model.process_duation + duration).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - token_num,
            chunk_num=Knowledgebase.chunk_num - chunk_num
        ).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def clear_chunk_num(cls, doc_id):
        doc = cls.model.get_by_id(doc_id)
        assert doc, "Can't find document in database."
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - doc.token_num,
            chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
            doc_num=Knowledgebase.doc_num - 1
        ).where(
            Knowledgebase.id == doc.kb_id).execute()
        return num
    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_knowledgebase_id(cls, doc_id):
        docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["kb_id"]

    @classmethod
    @DB.connection_context()
    def get_tenant_id_by_name(cls, name):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def accessible(cls, doc_id, user_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).join(
            UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
        ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
        return bool(docs.dicts())
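    # accessible() answers "can this user read this document?" by walking
    # document -> knowledge base -> tenant membership. A sketch (hypothetical
    # IDs; `current_user` and `error_response` stand in for whatever auth
    # context the calling API layer provides):
    #
    #   if not DocumentService.accessible("doc_id_example", current_user.id):
    #       return error_response("No authorization.")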
    @classmethod
    @DB.connection_context()
    def accessible4deletion(cls, doc_id, user_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).join(
            UserTenant, on=((UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id))
        ).where(
            cls.model.id == doc_id,
            UserTenant.status == StatusEnum.VALID.value,
            ((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER))
        ).paginate(0, 1)
        return bool(docs.dicts())

    @classmethod
    @DB.connection_context()
    def get_embd_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.embd_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["embd_id"]
    @classmethod
    @DB.connection_context()
    def get_chunking_config(cls, doc_id):
        configs = (
            cls.model.select(
                cls.model.id,
                cls.model.kb_id,
                cls.model.parser_id,
                cls.model.parser_config,
                Knowledgebase.language,
                Knowledgebase.embd_id,
                Tenant.id.alias("tenant_id"),
                Tenant.img2txt_id,
                Tenant.asr_id,
                Tenant.llm_id,
            )
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
            .where(cls.model.id == doc_id)
        )
        configs = configs.dicts()
        if not configs:
            return None
        return configs[0]
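    # get_chunking_config() flattens a three-table join into one dict; the
    # keys come from the select above (values here are hypothetical):
    #
    #   {"id": "...", "kb_id": "...", "parser_id": "naive",
    #    "parser_config": {...}, "language": "English", "embd_id": "...",
    #    "tenant_id": "...", "img2txt_id": "...", "asr_id": "...",
    #    "llm_id": "..."}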
    @classmethod
    @DB.connection_context()
    def get_doc_id_by_doc_name(cls, doc_name):
        # Returns the first match only; document names are not guaranteed unique.
        fields = [cls.model.id]
        doc_id = cls.model.select(*fields) \
            .where(cls.model.name == doc_name)
        doc_id = doc_id.dicts()
        if not doc_id:
            return
        return doc_id[0]["id"]

    @classmethod
    @DB.connection_context()
    def get_thumbnails(cls, docids):
        fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
        return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())
    @classmethod
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        if not config:
            return
        e, d = cls.get_by_id(id)
        if not e:
            raise LookupError(f"Document({id}) not found.")

        def dfs_update(old, new):
            for k, v in new.items():
                if k not in old:
                    old[k] = v
                    continue
                if isinstance(v, dict):
                    assert isinstance(old[k], dict)
                    dfs_update(old[k], v)
                else:
                    old[k] = v

        dfs_update(d.parser_config, config)
        if not config.get("raptor") and d.parser_config.get("raptor"):
            del d.parser_config["raptor"]
        cls.update_by_id(id, {"parser_config": d.parser_config})
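    # dfs_update() is a recursive dict merge: nested dicts are merged key by
    # key, everything else is overwritten. For example (hypothetical config):
    #
    #   old = {"chunk_token_num": 128, "raptor": {"use_raptor": False}}
    #   new = {"raptor": {"use_raptor": True}}
    #   -> {"chunk_token_num": 128, "raptor": {"use_raptor": True}}
    #
    # The trailing "raptor" check then drops the key entirely when the caller
    # did not send one, so stale raptor settings cannot linger.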
    @classmethod
    @DB.connection_context()
    def get_doc_count(cls, tenant_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            Knowledgebase.tenant_id == tenant_id)
        return len(docs)

    @classmethod
    @DB.connection_context()
    def begin2parse(cls, docid):
        cls.update_by_id(
            docid, {"progress": random.random() * 1 / 100.,
                    "progress_msg": "Task is queued...",
                    "process_begin_at": get_format_time()})

    @classmethod
    @DB.connection_context()
    def update_meta_fields(cls, doc_id, meta_fields):
        return cls.update_by_id(doc_id, {"meta_fields": meta_fields})
    @classmethod
    @DB.connection_context()
    def update_progress(cls):
        docs = cls.get_unfinished_docs()
        for d in docs:
            try:
                tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
                if not tsks:
                    continue
                msg = []
                prg = 0
                finished = True
                bad = 0
                has_raptor = False
                has_graphrag = False
                e, doc = DocumentService.get_by_id(d["id"])
                status = doc.run  # TaskStatus.RUNNING.value
                priority = 0
                for t in tsks:
                    if 0 <= t.progress < 1:
                        finished = False
                    if t.progress == -1:
                        bad += 1
                    prg += t.progress if t.progress >= 0 else 0
                    msg.append(t.progress_msg)
                    if t.task_type == "raptor":
                        has_raptor = True
                    elif t.task_type == "graphrag":
                        has_graphrag = True
                    priority = max(priority, t.priority)
                prg /= len(tsks)
                if finished and bad:
                    prg = -1
                    status = TaskStatus.FAIL.value
                elif finished:
                    if d["parser_config"].get("raptor", {}).get("use_raptor") and not has_raptor:
                        queue_raptor_o_graphrag_tasks(d, "raptor", priority)
                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                    elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and not has_graphrag:
                        queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                    else:
                        status = TaskStatus.DONE.value

                msg = "\n".join(sorted(msg))
                info = {
                    "process_duation": datetime.timestamp(datetime.now()) - d["process_begin_at"].timestamp(),
                    "run": status}
                if prg != 0:
                    info["progress"] = prg
                if msg:
                    info["progress_msg"] = msg
                cls.update_by_id(d["id"], info)
            except Exception as e:
                if str(e).find("'0'") < 0:
                    logging.exception("fetch task exception")
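    # Progress math: a document's progress is the mean of its tasks' progress
    # values (failed tasks, progress == -1, contribute 0 to the sum). When all
    # base tasks finish but a raptor/graphrag follow-up still has to run,
    # progress is pinned to 0.98 * n / (n + 1), so the bar approaches but
    # never reaches 100% until the extra task completes.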
    @classmethod
    @DB.connection_context()
    def get_kb_doc_count(cls, kb_id):
        return len(cls.model.select(cls.model.id).where(
            cls.model.kb_id == kb_id).dicts())

    @classmethod
    @DB.connection_context()
    def do_cancel(cls, doc_id):
        try:
            _, doc = DocumentService.get_by_id(doc_id)
            return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
        except Exception:
            pass
        return False


def queue_raptor_o_graphrag_tasks(doc, ty, priority):
    chunking_config = DocumentService.get_chunking_config(doc["id"])
    hasher = xxhash.xxh64()
    for field in sorted(chunking_config.keys()):
        hasher.update(str(chunking_config[field]).encode("utf-8"))

    def new_task():
        nonlocal doc
        return {
            "id": get_uuid(),
            "doc_id": doc["id"],
            "from_page": 100000000,
            "to_page": 100000000,
            "task_type": ty,
            "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
        }

    task = new_task()
    for field in ["doc_id", "from_page", "to_page"]:
        hasher.update(str(task.get(field, "")).encode("utf-8"))
    hasher.update(ty.encode("utf-8"))
    task["digest"] = hasher.hexdigest()
    bulk_insert_into_db(Task, [task], True)
    assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), \
        "Can't access Redis. Please check the Redis' status."
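# The digest hashes the chunking config plus the task's identifying fields,
# giving consumers a stable fingerprint for this unit of work (e.g. to detect
# unchanged or duplicate tasks; the exact reuse policy lives downstream).
# `from_page`/`to_page` are set to the 100000000 sentinel to mark
# whole-document tasks like raptor/graphrag.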

def doc_upload_and_parse(conversation_id, file_objs, user_id):
    from api.db.services.api_service import API4ConversationService
    from api.db.services.conversation_service import ConversationService
    from api.db.services.dialog_service import DialogService
    from api.db.services.file_service import FileService
    from api.db.services.llm_service import LLMBundle
    from api.db.services.user_service import TenantService
    from rag.app import audio, email, naive, picture, presentation

    e, conv = ConversationService.get_by_id(conversation_id)
    if not e:
        e, conv = API4ConversationService.get_by_id(conversation_id)
    assert e, "Conversation not found!"

    e, dia = DialogService.get_by_id(conv.dialog_id)
    if not dia.kb_ids:
        raise LookupError("No knowledge base associated with this conversation. "
                          "Please add a knowledge base before uploading documents")
    kb_id = dia.kb_ids[0]
    e, kb = KnowledgebaseService.get_by_id(kb_id)
    if not e:
        raise LookupError("Can't find this knowledgebase!")

    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)

    err, files = FileService.upload_document(kb, file_objs, user_id)
    assert not err, "\n".join(err)

    def dummy(prog=None, msg=""):
        pass

    FACTORY = {
        ParserType.PRESENTATION.value: presentation,
        ParserType.PICTURE.value: picture,
        ParserType.AUDIO.value: audio,
        ParserType.EMAIL.value: email
    }
    parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
    exe = ThreadPoolExecutor(max_workers=12)
    threads = []
    doc_nm = {}
    for d, blob in files:
        doc_nm[d["id"]] = d["name"]
    for d, blob in files:
        kwargs = {
            "callback": dummy,
            "parser_config": parser_config,
            "from_page": 0,
            "to_page": 100000,
            "tenant_id": kb.tenant_id,
            "lang": kb.language
        }
        threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))

    docs = []  # accumulate chunks across all uploaded files
    for (docinfo, _), th in zip(files, threads):
        doc = {
            "doc_id": docinfo["id"],
            "kb_id": [kb.id]
        }
        for ck in th.result():
            d = deepcopy(doc)
            d.update(ck)
            d["id"] = xxhash.xxh64((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest()
            d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
            d["create_timestamp_flt"] = datetime.now().timestamp()
            if not d.get("image"):
                docs.append(d)
                continue
            output_buffer = BytesIO()
            if isinstance(d["image"], bytes):
                output_buffer = BytesIO(d["image"])
            else:
                d["image"].save(output_buffer, format='JPEG')
            STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
            d["img_id"] = "{}-{}".format(kb.id, d["id"])
            d.pop("image", None)
            docs.append(d)

    parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
    docids = [d["id"] for d, _ in files]
    chunk_counts = {id: 0 for id in docids}
    token_counts = {id: 0 for id in docids}
    es_bulk_size = 64

    def embedding(doc_id, cnts, batch_size=16):
        nonlocal embd_mdl, chunk_counts, token_counts
        vects = []
        for i in range(0, len(cnts), batch_size):
            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
            token_counts[doc_id] += c
        return vects
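    # embedding() batches the chunk texts through the embedding model 16 at a
    # time, collecting vectors while tallying per-document chunk and token
    # counts for the increment_chunk_num() bookkeeping write at the end.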
    idxnm = search.index_name(kb.tenant_id)
    try_create_idx = True

    _, tenant = TenantService.get_by_id(kb.tenant_id)
    llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
    for doc_id in docids:
        cks = [c for c in docs if c["doc_id"] == doc_id]

        if parser_ids[doc_id] != ParserType.PICTURE.value:
            from graphrag.general.mind_map_extractor import MindMapExtractor
            mindmap = MindMapExtractor(llm_bdl)
            try:
                mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
                mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
                if len(mind_map) < 32:
                    raise Exception("Few content: " + mind_map)
                cks.append({
                    "id": get_uuid(),
                    "doc_id": doc_id,
                    "kb_id": [kb.id],
                    "docnm_kwd": doc_nm[doc_id],
                    "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
                    "content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
                    "content_with_weight": mind_map,
                    "knowledge_graph_kwd": "mind_map"
                })
            except Exception:
                logging.exception("Mind map generation error")

        vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
        assert len(cks) == len(vects)
        for i, d in enumerate(cks):
            v = vects[i]
            d["q_%d_vec" % len(v)] = v
        for b in range(0, len(cks), es_bulk_size):
            if try_create_idx:
                if not settings.docStoreConn.indexExist(idxnm, kb_id):
                    settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
                try_create_idx = False
            settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)

        DocumentService.increment_chunk_num(
            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

    return [d["id"] for d, _ in files]
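# doc_upload_and_parse() is the synchronous "upload inside a chat" path: it
# chunks every file in a thread pool, stores extracted images, optionally adds
# an LLM-generated mind-map chunk per document, embeds everything, and
# bulk-inserts into the doc store in batches of es_bulk_size. A usage sketch
# (hypothetical IDs and file objects):
#
#   doc_ids = doc_upload_and_parse("conversation_id_example",
#                                  uploaded_file_objs, "user_id_example")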