
document_service.py 32KB

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import random
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO

import trio
import xxhash
from peewee import fn

from api import settings
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus, UserTenantRole
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils import current_timestamp, get_format_time, get_uuid
from rag.nlp import rag_tokenizer, search
from rag.settings import get_svr_queue_name, SVR_CONSUMER_GROUP_NAME
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.doc_store_conn import OrderByExpr

class DocumentService(CommonService):
    model = Document

    @classmethod
    def get_cls_model_fields(cls):
        return [
            cls.model.id,
            cls.model.thumbnail,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.source_type,
            cls.model.type,
            cls.model.created_by,
            cls.model.name,
            cls.model.location,
            cls.model.size,
            cls.model.token_num,
            cls.model.chunk_num,
            cls.model.progress,
            cls.model.progress_msg,
            cls.model.process_begin_at,
            cls.model.process_duration,
            cls.model.meta_fields,
            cls.model.suffix,
            cls.model.run,
            cls.model.status,
            cls.model.create_time,
            cls.model.create_date,
            cls.model.update_time,
            cls.model.update_date,
        ]

    @classmethod
    @DB.connection_context()
    def get_list(cls, kb_id, page_number, items_per_page,
                 orderby, desc, keywords, id, name):
        fields = cls.get_cls_model_fields()
        docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
        if id:
            docs = docs.where(cls.model.id == id)
        if name:
            docs = docs.where(cls.model.name == name)
        if keywords:
            docs = docs.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())
        count = docs.count()
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def check_doc_health(cls, tenant_id: str, filename):
        import os
        MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
        if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(tenant_id) >= MAX_FILE_NUM_PER_USER:
            raise RuntimeError("Exceed the maximum file number of a free user!")
        if len(filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
            raise RuntimeError("Exceed the maximum length of file name!")
        return True

    @classmethod
    @DB.connection_context()
    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
                     orderby, desc, keywords, run_status, types, suffix):
        fields = cls.get_cls_model_fields()
        if keywords:
            docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
                (cls.model.kb_id == kb_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)

        if run_status:
            docs = docs.where(cls.model.run.in_(run_status))
        if types:
            docs = docs.where(cls.model.type.in_(types))
        if suffix:
            docs = docs.where(cls.model.suffix.in_(suffix))

        count = docs.count()

        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())

        if page_number and items_per_page:
            docs = docs.paginate(page_number, items_per_page)

        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def get_filter_by_kb_id(cls, kb_id, keywords, run_status, types, suffix):
        """
        returns:
        {
            "suffix": {
                "ppt": 1,
                "docx": 2
            },
            "run_status": {
                "1": 2,
                "2": 2
            }
        }, total
        where "1" => RUNNING, "2" => CANCEL
        """
        fields = cls.get_cls_model_fields()
        if keywords:
            query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
                (cls.model.kb_id == kb_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)

        if run_status:
            query = query.where(cls.model.run.in_(run_status))
        if types:
            query = query.where(cls.model.type.in_(types))
        if suffix:
            query = query.where(cls.model.suffix.in_(suffix))

        rows = query.select(cls.model.run, cls.model.suffix)
        total = rows.count()

        suffix_counter = {}
        run_status_counter = {}
        for row in rows:
            suffix_counter[row.suffix] = suffix_counter.get(row.suffix, 0) + 1
            run_status_counter[str(row.run)] = run_status_counter.get(str(row.run), 0) + 1

        return {
            "suffix": suffix_counter,
            "run_status": run_status_counter
        }, total

    @classmethod
    @DB.connection_context()
    def count_by_kb_id(cls, kb_id, keywords, run_status, types):
        if keywords:
            docs = cls.model.select().where(
                (cls.model.kb_id == kb_id),
                (fn.LOWER(cls.model.name).contains(keywords.lower()))
            )
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)

        if run_status:
            docs = docs.where(cls.model.run.in_(run_status))
        if types:
            docs = docs.where(cls.model.type.in_(types))

        count = docs.count()
        return count

    @classmethod
    @DB.connection_context()
    def get_total_size_by_kb_id(cls, kb_id, keywords="", run_status=[], types=[]):
        query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where(
            cls.model.kb_id == kb_id
        )
        if keywords:
            query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
        if run_status:
            query = query.where(cls.model.run.in_(run_status))
        if types:
            query = query.where(cls.model.type.in_(types))
        return int(query.scalar()) or 0

    @classmethod
    @DB.connection_context()
    def insert(cls, doc):
        if not cls.save(**doc):
            raise RuntimeError("Database error (Document)!")
        if not KnowledgebaseService.atomic_increase_doc_num_by_id(doc["kb_id"]):
            raise RuntimeError("Database error (Knowledgebase)!")
        return Document(**doc)

    @classmethod
    @DB.connection_context()
    def remove_document(cls, doc, tenant_id):
        from api.db.services.task_service import TaskService
        cls.clear_chunk_num(doc.id)
        try:
            TaskService.filter_delete([Task.doc_id == doc.id])
            # Collect every chunk id of this document so the associated images can be removed from object storage.
            page = 0
            page_size = 1000
            all_chunk_ids = []
            while True:
                chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
                                                      page * page_size, page_size, search.index_name(tenant_id),
                                                      [doc.kb_id])
                chunk_ids = settings.docStoreConn.getChunkIds(chunks)
                if not chunk_ids:
                    break
                all_chunk_ids.extend(chunk_ids)
                page += 1
            for cid in all_chunk_ids:
                if STORAGE_IMPL.obj_exist(doc.kb_id, cid):
                    STORAGE_IMPL.rm(doc.kb_id, cid)
            if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
                if STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
                    STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
            settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)

            # Detach this document from any knowledge-graph artifacts that were built from it.
            graph_source = settings.docStoreConn.getFields(
                settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
            )
            if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
                settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
                                             {"remove": {"source_id": doc.id}},
                                             search.index_name(tenant_id), doc.kb_id)
                settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
                                             {"removed_kwd": "Y"},
                                             search.index_name(tenant_id), doc.kb_id)
                settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
                                             search.index_name(tenant_id), doc.kb_id)
        except Exception:
            pass
        return cls.delete_by_id(doc.id)

    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls):
        fields = [
            cls.model.id,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.name,
            cls.model.type,
            cls.model.location,
            cls.model.size,
            Knowledgebase.tenant_id,
            Tenant.embd_id,
            Tenant.img2txt_id,
            Tenant.asr_id,
            cls.model.update_time]
        docs = cls.model.select(*fields) \
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress == 0,
                cls.model.update_time >= current_timestamp() - 1000 * 600,
                cls.model.run == TaskStatus.RUNNING.value) \
            .order_by(cls.model.update_time.asc())
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def get_unfinished_docs(cls):
        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
                  cls.model.run, cls.model.parser_id]
        docs = cls.model.select(*fields) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress < 1,
                cls.model.progress > 0)
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
        num = cls.model.update(token_num=cls.model.token_num + token_num,
                               chunk_num=cls.model.chunk_num + chunk_num,
                               process_duration=cls.model.process_duration + duration).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num + token_num,
            chunk_num=Knowledgebase.chunk_num + chunk_num).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
        num = cls.model.update(token_num=cls.model.token_num - token_num,
                               chunk_num=cls.model.chunk_num - chunk_num,
                               process_duration=cls.model.process_duration + duration).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - token_num,
            chunk_num=Knowledgebase.chunk_num - chunk_num
        ).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def clear_chunk_num(cls, doc_id):
        doc = cls.model.get_by_id(doc_id)
        assert doc, "Can't find document in database."
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num - doc.token_num,
            chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
            doc_num=Knowledgebase.doc_num - 1
        ).where(
            Knowledgebase.id == doc.kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def clear_chunk_num_when_rerun(cls, doc_id):
        doc = cls.model.get_by_id(doc_id)
        assert doc, "Can't find document in database."
        num = (
            Knowledgebase.update(
                token_num=Knowledgebase.token_num - doc.token_num,
                chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
            )
            .where(Knowledgebase.id == doc.kb_id)
            .execute()
        )
        return num

    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_knowledgebase_id(cls, doc_id):
        docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["kb_id"]

    @classmethod
    @DB.connection_context()
    def get_tenant_id_by_name(cls, name):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def accessible(cls, doc_id, user_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).join(
            UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
        ).where(
            cls.model.id == doc_id, UserTenant.user_id == user_id
        ).paginate(0, 1)
        docs = docs.dicts()
        if not docs:
            return False
        return True

    @classmethod
    @DB.connection_context()
    def accessible4deletion(cls, doc_id, user_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).join(
            UserTenant, on=(
                (UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id))
        ).where(
            cls.model.id == doc_id,
            UserTenant.status == StatusEnum.VALID.value,
            ((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER))
        ).paginate(0, 1)
        docs = docs.dicts()
        if not docs:
            return False
        return True

    @classmethod
    @DB.connection_context()
    def get_embd_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.embd_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["embd_id"]

    @classmethod
    @DB.connection_context()
    def get_chunking_config(cls, doc_id):
        configs = (
            cls.model.select(
                cls.model.id,
                cls.model.kb_id,
                cls.model.parser_id,
                cls.model.parser_config,
                Knowledgebase.language,
                Knowledgebase.embd_id,
                Tenant.id.alias("tenant_id"),
                Tenant.img2txt_id,
                Tenant.asr_id,
                Tenant.llm_id,
            )
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
            .where(cls.model.id == doc_id)
        )
        configs = configs.dicts()
        if not configs:
            return None
        return configs[0]

    @classmethod
    @DB.connection_context()
    def get_doc_id_by_doc_name(cls, doc_name):
        fields = [cls.model.id]
        doc_id = cls.model.select(*fields) \
            .where(cls.model.name == doc_name)
        doc_id = doc_id.dicts()
        if not doc_id:
            return
        return doc_id[0]["id"]

    @classmethod
    @DB.connection_context()
    def get_doc_ids_by_doc_names(cls, doc_names):
        if not doc_names:
            return []
        query = cls.model.select(cls.model.id).where(cls.model.name.in_(doc_names))
        return list(query.scalars().iterator())

    @classmethod
    @DB.connection_context()
    def get_thumbnails(cls, docids):
        fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
        return list(cls.model.select(
            *fields).where(cls.model.id.in_(docids)).dicts())

    @classmethod
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        if not config:
            return
        e, d = cls.get_by_id(id)
        if not e:
            raise LookupError(f"Document({id}) not found.")

        def dfs_update(old, new):
            for k, v in new.items():
                if k not in old:
                    old[k] = v
                    continue
                if isinstance(v, dict):
                    assert isinstance(old[k], dict)
                    dfs_update(old[k], v)
                else:
                    old[k] = v

        dfs_update(d.parser_config, config)
        if not config.get("raptor") and d.parser_config.get("raptor"):
            del d.parser_config["raptor"]
        cls.update_by_id(id, {"parser_config": d.parser_config})

    @classmethod
    @DB.connection_context()
    def get_doc_count(cls, tenant_id):
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            Knowledgebase.tenant_id == tenant_id)
        return len(docs)

    @classmethod
    @DB.connection_context()
    def begin2parse(cls, docid):
        cls.update_by_id(
            docid, {"progress": random.random() * 1 / 100.,
                    "progress_msg": "Task is queued...",
                    "process_begin_at": get_format_time()
                    })

    @classmethod
    @DB.connection_context()
    def update_meta_fields(cls, doc_id, meta_fields):
        return cls.update_by_id(doc_id, {"meta_fields": meta_fields})

    @classmethod
    @DB.connection_context()
    def get_meta_by_kbs(cls, kb_ids):
        fields = [
            cls.model.id,
            cls.model.meta_fields,
        ]
        meta = {}
        for r in cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids)):
            doc_id = r.id
            for k, v in r.meta_fields.items():
                if k not in meta:
                    meta[k] = {}
                v = str(v)
                if v not in meta[k]:
                    meta[k][v] = []
                meta[k][v].append(doc_id)
        return meta

    @classmethod
    @DB.connection_context()
    def update_progress(cls):
        docs = cls.get_unfinished_docs()
        for d in docs:
            try:
                tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
                if not tsks:
                    continue
                msg = []
                prg = 0
                finished = True
                bad = 0
                has_raptor = False
                has_graphrag = False
                e, doc = DocumentService.get_by_id(d["id"])
                status = doc.run  # TaskStatus.RUNNING.value
                priority = 0
                # Aggregate the progress of every task that belongs to this document.
                for t in tsks:
                    if 0 <= t.progress < 1:
                        finished = False
                    if t.progress == -1:
                        bad += 1
                    prg += t.progress if t.progress >= 0 else 0
                    if t.progress_msg.strip():
                        msg.append(t.progress_msg)
                    if t.task_type == "raptor":
                        has_raptor = True
                    elif t.task_type == "graphrag":
                        has_graphrag = True
                    priority = max(priority, t.priority)
                prg /= len(tsks)
                if finished and bad:
                    prg = -1
                    status = TaskStatus.FAIL.value
                elif finished:
                    # All chunking tasks finished: queue RAPTOR/GraphRAG follow-up tasks if configured,
                    # otherwise mark the document as done.
                    if (d["parser_config"].get("raptor") or {}).get("use_raptor") and not has_raptor:
                        queue_raptor_o_graphrag_tasks(d, "raptor", priority)
                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                    elif (d["parser_config"].get("graphrag") or {}).get("use_graphrag") and not has_graphrag:
                        queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                    else:
                        status = TaskStatus.DONE.value

                msg = "\n".join(sorted(msg))
                info = {
                    "process_duration": datetime.timestamp(datetime.now()) - d["process_begin_at"].timestamp(),
                    "run": status}
                if prg != 0:
                    info["progress"] = prg
                if msg:
                    info["progress_msg"] = msg
                    if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
                        info["progress_msg"] += "\n%d tasks are ahead in the queue..." % get_queue_length(priority)
                else:
                    info["progress_msg"] = "%d tasks are ahead in the queue..." % get_queue_length(priority)
                cls.update_by_id(d["id"], info)
            except Exception as e:
                if str(e).find("'0'") < 0:
                    logging.exception("fetch task exception")

    @classmethod
    @DB.connection_context()
    def get_kb_doc_count(cls, kb_id):
        return len(cls.model.select(cls.model.id).where(
            cls.model.kb_id == kb_id).dicts())

    @classmethod
    @DB.connection_context()
    def do_cancel(cls, doc_id):
        try:
            _, doc = DocumentService.get_by_id(doc_id)
            return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
        except Exception:
            pass
        return False

def queue_raptor_o_graphrag_tasks(doc, ty, priority):
    # Fingerprint the task: hash the document's chunking config plus the task's identifying fields into a digest.
    chunking_config = DocumentService.get_chunking_config(doc["id"])
    hasher = xxhash.xxh64()
    for field in sorted(chunking_config.keys()):
        hasher.update(str(chunking_config[field]).encode("utf-8"))

    def new_task():
        nonlocal doc
        return {
            "id": get_uuid(),
            "doc_id": doc["id"],
            "from_page": 100000000,
            "to_page": 100000000,
            "task_type": ty,
            "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
        }

    task = new_task()
    for field in ["doc_id", "from_page", "to_page"]:
        hasher.update(str(task.get(field, "")).encode("utf-8"))
    hasher.update(ty.encode("utf-8"))
    task["digest"] = hasher.hexdigest()
    bulk_insert_into_db(Task, [task], True)
    assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."


def get_queue_length(priority):
    group_info = REDIS_CONN.queue_info(get_svr_queue_name(priority), SVR_CONSUMER_GROUP_NAME)
    return int(group_info.get("lag", 0) or 0)

def doc_upload_and_parse(conversation_id, file_objs, user_id):
    from api.db.services.api_service import API4ConversationService
    from api.db.services.conversation_service import ConversationService
    from api.db.services.dialog_service import DialogService
    from api.db.services.file_service import FileService
    from api.db.services.llm_service import LLMBundle
    from api.db.services.user_service import TenantService
    from rag.app import audio, email, naive, picture, presentation

    e, conv = ConversationService.get_by_id(conversation_id)
    if not e:
        e, conv = API4ConversationService.get_by_id(conversation_id)
    assert e, "Conversation not found!"

    e, dia = DialogService.get_by_id(conv.dialog_id)
    if not dia.kb_ids:
        raise LookupError("No knowledge base associated with this conversation. "
                          "Please add a knowledge base before uploading documents")
    kb_id = dia.kb_ids[0]
    e, kb = KnowledgebaseService.get_by_id(kb_id)
    if not e:
        raise LookupError("Can't find this knowledgebase!")

    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)

    err, files = FileService.upload_document(kb, file_objs, user_id)
    assert not err, "\n".join(err)

    def dummy(prog=None, msg=""):
        pass

    FACTORY = {
        ParserType.PRESENTATION.value: presentation,
        ParserType.PICTURE.value: picture,
        ParserType.AUDIO.value: audio,
        ParserType.EMAIL.value: email
    }
    parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
    exe = ThreadPoolExecutor(max_workers=12)
    threads = []
    doc_nm = {}
    for d, blob in files:
        doc_nm[d["id"]] = d["name"]
    # Chunk each uploaded file in parallel, using the parser that matches its parser_id (falling back to naive).
    for d, blob in files:
        kwargs = {
            "callback": dummy,
            "parser_config": parser_config,
            "from_page": 0,
            "to_page": 100000,
            "tenant_id": kb.tenant_id,
            "lang": kb.language
        }
        threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))

    for (docinfo, _), th in zip(files, threads):
        docs = []
        doc = {
            "doc_id": docinfo["id"],
            "kb_id": [kb.id]
        }
        for ck in th.result():
            d = deepcopy(doc)
            d.update(ck)
            d["id"] = xxhash.xxh64((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest()
            d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
            d["create_timestamp_flt"] = datetime.now().timestamp()
            if not d.get("image"):
                docs.append(d)
                continue

            # Persist the chunk image to object storage and keep only its id on the chunk.
            output_buffer = BytesIO()
            if isinstance(d["image"], bytes):
                output_buffer = BytesIO(d["image"])
            else:
                d["image"].save(output_buffer, format='JPEG')

            STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
            d["img_id"] = "{}-{}".format(kb.id, d["id"])
            d.pop("image", None)
            docs.append(d)

    parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
    docids = [d["id"] for d, _ in files]
    chunk_counts = {id: 0 for id in docids}
    token_counts = {id: 0 for id in docids}
    es_bulk_size = 64

    def embedding(doc_id, cnts, batch_size=16):
        nonlocal embd_mdl, chunk_counts, token_counts
        vects = []
        for i in range(0, len(cnts), batch_size):
            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
            token_counts[doc_id] += c
        return vects

    idxnm = search.index_name(kb.tenant_id)
    try_create_idx = True

    _, tenant = TenantService.get_by_id(kb.tenant_id)
    llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
    for doc_id in docids:
        cks = [c for c in docs if c["doc_id"] == doc_id]

        if parser_ids[doc_id] != ParserType.PICTURE.value:
            # For non-image documents, generate a mind-map chunk summarizing the content via the chat model.
            from graphrag.general.mind_map_extractor import MindMapExtractor
            mindmap = MindMapExtractor(llm_bdl)
            try:
                mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
                mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
                if len(mind_map) < 32:
                    raise Exception("Few content: " + mind_map)
                cks.append({
                    "id": get_uuid(),
                    "doc_id": doc_id,
                    "kb_id": [kb.id],
                    "docnm_kwd": doc_nm[doc_id],
                    "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
                    "content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
                    "content_with_weight": mind_map,
                    "knowledge_graph_kwd": "mind_map"
                })
            except Exception as e:
                logging.exception("Mind map generation error")

        vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
        assert len(cks) == len(vects)
        for i, d in enumerate(cks):
            v = vects[i]
            d["q_%d_vec" % len(v)] = v
        for b in range(0, len(cks), es_bulk_size):
            if try_create_idx:
                if not settings.docStoreConn.indexExist(idxnm, kb_id):
                    settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
                try_create_idx = False
            settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)

        DocumentService.increment_chunk_num(
            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

    return [d["id"] for d, _ in files]