Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

knowledgebase_service.py 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from datetime import datetime
  17. from peewee import fn
  18. from api.db import StatusEnum, TenantPermission
  19. from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant
  20. from api.db.services.common_service import CommonService
  21. from api.utils import current_timestamp, datetime_format
class KnowledgebaseService(CommonService):
    """Service class for managing knowledge base operations.

    This class extends CommonService to provide specialized functionality for knowledge base
    management, including document parsing status tracking, access control, and configuration
    management. It handles operations such as listing, creating, updating, and deleting
    knowledge bases, as well as managing their associated documents and permissions.

    The class implements a comprehensive set of methods for:
        - Document parsing status verification
        - Knowledge base access control
        - Parser configuration management
        - Tenant-based knowledge base organization

    Attributes:
        model: The Knowledgebase model class for database operations.
    """

    model = Knowledgebase
  37. @classmethod
  38. @DB.connection_context()
  39. def accessible4deletion(cls, kb_id, user_id):
  40. """Check if a knowledge base can be deleted by a specific user.
  41. This method verifies whether a user has permission to delete a knowledge base
  42. by checking if they are the creator of that knowledge base.
  43. Args:
  44. kb_id (str): The unique identifier of the knowledge base to check.
  45. user_id (str): The unique identifier of the user attempting the deletion.
  46. Returns:
  47. bool: True if the user has permission to delete the knowledge base,
  48. False if the user doesn't have permission or the knowledge base doesn't exist.
  49. Example:
  50. >>> KnowledgebaseService.accessible4deletion("kb123", "user456")
  51. True
  52. Note:
  53. - This method only checks creator permissions
  54. - A return value of False can mean either:
  55. 1. The knowledge base doesn't exist
  56. 2. The user is not the creator of the knowledge base
  57. """
  58. # Check if a knowledge base can be deleted by a user
  59. docs = cls.model.select(
  60. cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
  61. docs = docs.dicts()
  62. if not docs:
  63. return False
  64. return True
  65. @classmethod
  66. @DB.connection_context()
  67. def is_parsed_done(cls, kb_id):
  68. # Check if all documents in the knowledge base have completed parsing
  69. #
  70. # Args:
  71. # kb_id: Knowledge base ID
  72. #
  73. # Returns:
  74. # If all documents are parsed successfully, returns (True, None)
  75. # If any document is not fully parsed, returns (False, error_message)
  76. from api.db import TaskStatus
  77. from api.db.services.document_service import DocumentService
  78. # Get knowledge base information
  79. kbs = cls.query(id=kb_id)
  80. if not kbs:
  81. return False, "Knowledge base not found"
  82. kb = kbs[0]
  83. # Get all documents in the knowledge base
  84. docs, _ = DocumentService.get_by_kb_id(kb_id, 1, 1000, "create_time", True, "", [], [])
  85. # Check parsing status of each document
  86. for doc in docs:
  87. # If document is being parsed, don't allow chat creation
  88. if doc['run'] == TaskStatus.RUNNING.value or doc['run'] == TaskStatus.CANCEL.value or doc['run'] == TaskStatus.FAIL.value:
  89. return False, f"Document '{doc['name']}' in dataset '{kb.name}' is still being parsed. Please wait until all documents are parsed before starting a chat."
  90. # If document is not yet parsed and has no chunks, don't allow chat creation
  91. if doc['run'] == TaskStatus.UNSTART.value and doc['chunk_num'] == 0:
  92. return False, f"Document '{doc['name']}' in dataset '{kb.name}' has not been parsed yet. Please parse all documents before starting a chat."
  93. return True, None
  94. @classmethod
  95. @DB.connection_context()
  96. def list_documents_by_ids(cls, kb_ids):
  97. # Get document IDs associated with given knowledge base IDs
  98. # Args:
  99. # kb_ids: List of knowledge base IDs
  100. # Returns:
  101. # List of document IDs
  102. doc_ids = cls.model.select(Document.id.alias("document_id")).join(Document, on=(cls.model.id == Document.kb_id)).where(
  103. cls.model.id.in_(kb_ids)
  104. )
  105. doc_ids = list(doc_ids.dicts())
  106. doc_ids = [doc["document_id"] for doc in doc_ids]
  107. return doc_ids
  108. @classmethod
  109. @DB.connection_context()
  110. def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
  111. page_number, items_per_page,
  112. orderby, desc, keywords,
  113. parser_id=None
  114. ):
  115. # Get knowledge bases by tenant IDs with pagination and filtering
  116. # Args:
  117. # joined_tenant_ids: List of tenant IDs
  118. # user_id: Current user ID
  119. # page_number: Page number for pagination
  120. # items_per_page: Number of items per page
  121. # orderby: Field to order by
  122. # desc: Boolean indicating descending order
  123. # keywords: Search keywords
  124. # parser_id: Optional parser ID filter
  125. # Returns:
  126. # Tuple of (knowledge_base_list, total_count)
  127. fields = [
  128. cls.model.id,
  129. cls.model.avatar,
  130. cls.model.name,
  131. cls.model.language,
  132. cls.model.description,
  133. cls.model.tenant_id,
  134. cls.model.permission,
  135. cls.model.doc_num,
  136. cls.model.token_num,
  137. cls.model.chunk_num,
  138. cls.model.parser_id,
  139. cls.model.embd_id,
  140. User.nickname,
  141. User.avatar.alias('tenant_avatar'),
  142. cls.model.update_time
  143. ]
  144. if keywords:
  145. kbs = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
  146. ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
  147. TenantPermission.TEAM.value)) | (
  148. cls.model.tenant_id == user_id))
  149. & (cls.model.status == StatusEnum.VALID.value),
  150. (fn.LOWER(cls.model.name).contains(keywords.lower()))
  151. )
  152. else:
  153. kbs = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
  154. ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
  155. TenantPermission.TEAM.value)) | (
  156. cls.model.tenant_id == user_id))
  157. & (cls.model.status == StatusEnum.VALID.value)
  158. )
  159. if parser_id:
  160. kbs = kbs.where(cls.model.parser_id == parser_id)
  161. if desc:
  162. kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
  163. else:
  164. kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
  165. count = kbs.count()
  166. if page_number and items_per_page:
  167. kbs = kbs.paginate(page_number, items_per_page)
  168. return list(kbs.dicts()), count
  169. @classmethod
  170. @DB.connection_context()
  171. def get_kb_ids(cls, tenant_id):
  172. # Get all knowledge base IDs for a tenant
  173. # Args:
  174. # tenant_id: Tenant ID
  175. # Returns:
  176. # List of knowledge base IDs
  177. fields = [
  178. cls.model.id,
  179. ]
  180. kbs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
  181. kb_ids = [kb.id for kb in kbs]
  182. return kb_ids
  183. @classmethod
  184. @DB.connection_context()
  185. def get_detail(cls, kb_id):
  186. # Get detailed information about a knowledge base
  187. # Args:
  188. # kb_id: Knowledge base ID
  189. # Returns:
  190. # Dictionary containing knowledge base details
  191. fields = [
  192. cls.model.id,
  193. cls.model.embd_id,
  194. cls.model.avatar,
  195. cls.model.name,
  196. cls.model.language,
  197. cls.model.description,
  198. cls.model.permission,
  199. cls.model.doc_num,
  200. cls.model.token_num,
  201. cls.model.chunk_num,
  202. cls.model.parser_id,
  203. cls.model.parser_config,
  204. cls.model.pagerank,
  205. cls.model.create_time,
  206. cls.model.update_time
  207. ]
  208. kbs = cls.model.select(*fields).join(Tenant, on=(
  209. (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
  210. (cls.model.id == kb_id),
  211. (cls.model.status == StatusEnum.VALID.value)
  212. )
  213. if not kbs:
  214. return
  215. d = kbs[0].to_dict()
  216. return d
  217. @classmethod
  218. @DB.connection_context()
  219. def update_parser_config(cls, id, config):
  220. # Update parser configuration for a knowledge base
  221. # Args:
  222. # id: Knowledge base ID
  223. # config: New parser configuration
  224. e, m = cls.get_by_id(id)
  225. if not e:
  226. raise LookupError(f"knowledgebase({id}) not found.")
  227. def dfs_update(old, new):
  228. # Deep update of nested configuration
  229. for k, v in new.items():
  230. if k not in old:
  231. old[k] = v
  232. continue
  233. if isinstance(v, dict):
  234. assert isinstance(old[k], dict)
  235. dfs_update(old[k], v)
  236. elif isinstance(v, list):
  237. assert isinstance(old[k], list)
  238. old[k] = list(set(old[k] + v))
  239. else:
  240. old[k] = v
  241. dfs_update(m.parser_config, config)
  242. cls.update_by_id(id, {"parser_config": m.parser_config})
  243. @classmethod
  244. @DB.connection_context()
  245. def delete_field_map(cls, id):
  246. e, m = cls.get_by_id(id)
  247. if not e:
  248. raise LookupError(f"knowledgebase({id}) not found.")
  249. m.parser_config.pop("field_map", None)
  250. cls.update_by_id(id, {"parser_config": m.parser_config})
  251. @classmethod
  252. @DB.connection_context()
  253. def get_field_map(cls, ids):
  254. # Get field mappings for knowledge bases
  255. # Args:
  256. # ids: List of knowledge base IDs
  257. # Returns:
  258. # Dictionary of field mappings
  259. conf = {}
  260. for k in cls.get_by_ids(ids):
  261. if k.parser_config and "field_map" in k.parser_config:
  262. conf.update(k.parser_config["field_map"])
  263. return conf
  264. @classmethod
  265. @DB.connection_context()
  266. def get_by_name(cls, kb_name, tenant_id):
  267. # Get knowledge base by name and tenant ID
  268. # Args:
  269. # kb_name: Knowledge base name
  270. # tenant_id: Tenant ID
  271. # Returns:
  272. # Tuple of (exists, knowledge_base)
  273. kb = cls.model.select().where(
  274. (cls.model.name == kb_name)
  275. & (cls.model.tenant_id == tenant_id)
  276. & (cls.model.status == StatusEnum.VALID.value)
  277. )
  278. if kb:
  279. return True, kb[0]
  280. return False, None
  281. @classmethod
  282. @DB.connection_context()
  283. def get_all_ids(cls):
  284. # Get all knowledge base IDs
  285. # Returns:
  286. # List of all knowledge base IDs
  287. return [m["id"] for m in cls.model.select(cls.model.id).dicts()]
  288. @classmethod
  289. @DB.connection_context()
  290. def get_list(cls, joined_tenant_ids, user_id,
  291. page_number, items_per_page, orderby, desc, id, name):
  292. # Get list of knowledge bases with filtering and pagination
  293. # Args:
  294. # joined_tenant_ids: List of tenant IDs
  295. # user_id: Current user ID
  296. # page_number: Page number for pagination
  297. # items_per_page: Number of items per page
  298. # orderby: Field to order by
  299. # desc: Boolean indicating descending order
  300. # id: Optional ID filter
  301. # name: Optional name filter
  302. # Returns:
  303. # List of knowledge bases
  304. kbs = cls.model.select()
  305. if id:
  306. kbs = kbs.where(cls.model.id == id)
  307. if name:
  308. kbs = kbs.where(cls.model.name == name)
  309. kbs = kbs.where(
  310. ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
  311. TenantPermission.TEAM.value)) | (
  312. cls.model.tenant_id == user_id))
  313. & (cls.model.status == StatusEnum.VALID.value)
  314. )
  315. if desc:
  316. kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
  317. else:
  318. kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
  319. kbs = kbs.paginate(page_number, items_per_page)
  320. return list(kbs.dicts())
  321. @classmethod
  322. @DB.connection_context()
  323. def accessible(cls, kb_id, user_id):
  324. # Check if a knowledge base is accessible by a user
  325. # Args:
  326. # kb_id: Knowledge base ID
  327. # user_id: User ID
  328. # Returns:
  329. # Boolean indicating accessibility
  330. docs = cls.model.select(
  331. cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
  332. ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
  333. docs = docs.dicts()
  334. if not docs:
  335. return False
  336. return True
  337. @classmethod
  338. @DB.connection_context()
  339. def get_kb_by_id(cls, kb_id, user_id):
  340. # Get knowledge base by ID and user ID
  341. # Args:
  342. # kb_id: Knowledge base ID
  343. # user_id: User ID
  344. # Returns:
  345. # List containing knowledge base information
  346. kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
  347. ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
  348. kbs = kbs.dicts()
  349. return list(kbs)
  350. @classmethod
  351. @DB.connection_context()
  352. def get_kb_by_name(cls, kb_name, user_id):
  353. # Get knowledge base by name and user ID
  354. # Args:
  355. # kb_name: Knowledge base name
  356. # user_id: User ID
  357. # Returns:
  358. # List containing knowledge base information
  359. kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
  360. ).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1)
  361. kbs = kbs.dicts()
  362. return list(kbs)
  363. @classmethod
  364. @DB.connection_context()
  365. def atomic_increase_doc_num_by_id(cls, kb_id):
  366. data = {}
  367. data["update_time"] = current_timestamp()
  368. data["update_date"] = datetime_format(datetime.now())
  369. data["doc_num"] = cls.model.doc_num + 1
  370. num = cls.model.update(data).where(cls.model.id == kb_id).execute()
  371. return num
  372. @classmethod
  373. @DB.connection_context()
  374. def update_document_number_in_init(cls, kb_id, doc_num):
  375. """
  376. Only use this function when init system
  377. """
  378. ok, kb = cls.get_by_id(kb_id)
  379. if not ok:
  380. return
  381. kb.doc_num = doc_num
  382. dirty_fields = kb.dirty_fields
  383. if cls.model._meta.combined.get("update_time") in dirty_fields:
  384. dirty_fields.remove(cls.model._meta.combined["update_time"])
  385. if cls.model._meta.combined.get("update_date") in dirty_fields:
  386. dirty_fields.remove(cls.model._meta.combined["update_date"])
  387. try:
  388. kb.save(only=dirty_fields)
  389. except ValueError as e:
  390. if str(e) == "no data to save!":
  391. pass # that's OK
  392. else:
  393. raise e