### What problem does this PR solve?

Resolve the concurrent document upload issue. #6039

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

tags/v0.18.0
| @@ -13,33 +13,30 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import logging | |||
| import xxhash | |||
| import json | |||
| import logging | |||
| import random | |||
| import re | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from copy import deepcopy | |||
| from datetime import datetime | |||
| from io import BytesIO | |||
| import trio | |||
| import trio | |||
| import xxhash | |||
| from peewee import fn | |||
| from api.db.db_utils import bulk_insert_into_db | |||
| from api import settings | |||
| from api.utils import current_timestamp, get_format_time, get_uuid | |||
| from rag.settings import SVR_QUEUE_NAME | |||
| from rag.utils.storage_factory import STORAGE_IMPL | |||
| from rag.nlp import search, rag_tokenizer | |||
| from api.db import FileType, TaskStatus, ParserType, LLMType | |||
| from api.db.db_models import DB, Knowledgebase, Tenant, Task, UserTenant | |||
| from api.db.db_models import Document | |||
| from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus | |||
| from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant | |||
| from api.db.db_utils import bulk_insert_into_db | |||
| from api.db.services.common_service import CommonService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db import StatusEnum | |||
| from api.utils import current_timestamp, get_format_time, get_uuid | |||
| from rag.nlp import rag_tokenizer, search | |||
| from rag.settings import SVR_QUEUE_NAME | |||
| from rag.utils.redis_conn import REDIS_CONN | |||
| from rag.utils.storage_factory import STORAGE_IMPL | |||
| class DocumentService(CommonService): | |||
| @@ -96,9 +93,7 @@ class DocumentService(CommonService): | |||
def insert(cls, doc):
    """Persist a new document and bump its knowledge base's document count.

    Args:
        doc (dict): Field values for the new document. Must contain
            ``"kb_id"``, identifying the knowledge base that owns it.

    Returns:
        Document: A model instance built from ``doc``.

    Raises:
        RuntimeError: If saving the document reports failure, or if the
            knowledge base counter update returns a falsy result.
    """
    if not cls.save(**doc):
        raise RuntimeError("Database error (Document)!")
    # Bump doc_num through one UPDATE expression instead of reading
    # kb.doc_num and writing back kb.doc_num + 1; the read-then-write form
    # can lose increments when several uploads run at the same time.
    if not KnowledgebaseService.atomic_increase_doc_num_by_id(doc["kb_id"]):
        raise RuntimeError("Database error (Knowledgebase)!")
    return Document(**doc)
| @@ -174,9 +169,9 @@ class DocumentService(CommonService): | |||
| "Document not found which is supposed to be there") | |||
| num = Knowledgebase.update( | |||
| token_num=Knowledgebase.token_num + | |||
| token_num, | |||
| token_num, | |||
| chunk_num=Knowledgebase.chunk_num + | |||
| chunk_num).where( | |||
| chunk_num).where( | |||
| Knowledgebase.id == kb_id).execute() | |||
| return num | |||
| @@ -192,9 +187,9 @@ class DocumentService(CommonService): | |||
| "Document not found which is supposed to be there") | |||
| num = Knowledgebase.update( | |||
| token_num=Knowledgebase.token_num - | |||
| token_num, | |||
| token_num, | |||
| chunk_num=Knowledgebase.chunk_num - | |||
| chunk_num | |||
| chunk_num | |||
| ).where( | |||
| Knowledgebase.id == kb_id).execute() | |||
| return num | |||
| @@ -207,9 +202,9 @@ class DocumentService(CommonService): | |||
| num = Knowledgebase.update( | |||
| token_num=Knowledgebase.token_num - | |||
| doc.token_num, | |||
| doc.token_num, | |||
| chunk_num=Knowledgebase.chunk_num - | |||
| doc.chunk_num, | |||
| doc.chunk_num, | |||
| doc_num=Knowledgebase.doc_num - 1 | |||
| ).where( | |||
| Knowledgebase.id == doc.kb_id).execute() | |||
| @@ -221,7 +216,7 @@ class DocumentService(CommonService): | |||
| docs = cls.model.select( | |||
| Knowledgebase.tenant_id).join( | |||
| Knowledgebase, on=( | |||
| Knowledgebase.id == cls.model.kb_id)).where( | |||
| Knowledgebase.id == cls.model.kb_id)).where( | |||
| cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) | |||
| docs = docs.dicts() | |||
| if not docs: | |||
| @@ -243,7 +238,7 @@ class DocumentService(CommonService): | |||
| docs = cls.model.select( | |||
| Knowledgebase.tenant_id).join( | |||
| Knowledgebase, on=( | |||
| Knowledgebase.id == cls.model.kb_id)).where( | |||
| Knowledgebase.id == cls.model.kb_id)).where( | |||
| cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value) | |||
| docs = docs.dicts() | |||
| if not docs: | |||
| @@ -256,7 +251,7 @@ class DocumentService(CommonService): | |||
| docs = cls.model.select( | |||
| cls.model.id).join( | |||
| Knowledgebase, on=( | |||
| Knowledgebase.id == cls.model.kb_id) | |||
| Knowledgebase.id == cls.model.kb_id) | |||
| ).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) | |||
| ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1) | |||
| docs = docs.dicts() | |||
| @@ -270,7 +265,7 @@ class DocumentService(CommonService): | |||
| docs = cls.model.select( | |||
| cls.model.id).join( | |||
| Knowledgebase, on=( | |||
| Knowledgebase.id == cls.model.kb_id) | |||
| Knowledgebase.id == cls.model.kb_id) | |||
| ).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1) | |||
| docs = docs.dicts() | |||
| if not docs: | |||
| @@ -283,7 +278,7 @@ class DocumentService(CommonService): | |||
| docs = cls.model.select( | |||
| Knowledgebase.embd_id).join( | |||
| Knowledgebase, on=( | |||
| Knowledgebase.id == cls.model.kb_id)).where( | |||
| Knowledgebase.id == cls.model.kb_id)).where( | |||
| cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) | |||
| docs = docs.dicts() | |||
| if not docs: | |||
| @@ -306,9 +301,9 @@ class DocumentService(CommonService): | |||
| Tenant.asr_id, | |||
| Tenant.llm_id, | |||
| ) | |||
| .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) | |||
| .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) | |||
| .where(cls.model.id == doc_id) | |||
| .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) | |||
| .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) | |||
| .where(cls.model.id == doc_id) | |||
| ) | |||
| configs = configs.dicts() | |||
| if not configs: | |||
| @@ -374,6 +369,7 @@ class DocumentService(CommonService): | |||
| "progress_msg": "Task is queued...", | |||
| "process_begin_at": get_format_time() | |||
| }) | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def update_meta_fields(cls, doc_id, meta_fields): | |||
| @@ -425,7 +421,7 @@ class DocumentService(CommonService): | |||
| info = { | |||
| "process_duation": datetime.timestamp( | |||
| datetime.now()) - | |||
| d["process_begin_at"].timestamp(), | |||
| d["process_begin_at"].timestamp(), | |||
| "run": status} | |||
| if prg != 0: | |||
| info["progress"] = prg | |||
| @@ -480,13 +476,13 @@ def queue_raptor_o_graphrag_tasks(doc, ty): | |||
| def doc_upload_and_parse(conversation_id, file_objs, user_id): | |||
| from rag.app import presentation, picture, naive, audio, email | |||
| from api.db.services.api_service import API4ConversationService | |||
| from api.db.services.conversation_service import ConversationService | |||
| from api.db.services.dialog_service import DialogService | |||
| from api.db.services.file_service import FileService | |||
| from api.db.services.llm_service import LLMBundle | |||
| from api.db.services.user_service import TenantService | |||
| from api.db.services.api_service import API4ConversationService | |||
| from api.db.services.conversation_service import ConversationService | |||
| from rag.app import audio, email, naive, picture, presentation | |||
| e, conv = ConversationService.get_by_id(conversation_id) | |||
| if not e: | |||
| @@ -13,26 +13,30 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from datetime import datetime | |||
| from peewee import fn | |||
| from api.db import StatusEnum, TenantPermission | |||
| from api.db.db_models import Knowledgebase, DB, Tenant, User, UserTenant,Document | |||
| from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant | |||
| from api.db.services.common_service import CommonService | |||
| from peewee import fn | |||
| from api.utils import current_timestamp, datetime_format | |||
| class KnowledgebaseService(CommonService): | |||
| """Service class for managing knowledge base operations. | |||
| This class extends CommonService to provide specialized functionality for knowledge base | |||
| management, including document parsing status tracking, access control, and configuration | |||
| management. It handles operations such as listing, creating, updating, and deleting | |||
| knowledge bases, as well as managing their associated documents and permissions. | |||
| The class implements a comprehensive set of methods for: | |||
| - Document parsing status verification | |||
| - Knowledge base access control | |||
| - Parser configuration management | |||
| - Tenant-based knowledge base organization | |||
| Attributes: | |||
| model: The Knowledgebase model class for database operations. | |||
| """ | |||
| @@ -42,22 +46,22 @@ class KnowledgebaseService(CommonService): | |||
| @DB.connection_context() | |||
| def accessible4deletion(cls, kb_id, user_id): | |||
| """Check if a knowledge base can be deleted by a specific user. | |||
| This method verifies whether a user has permission to delete a knowledge base | |||
| by checking if they are the creator of that knowledge base. | |||
| Args: | |||
| kb_id (str): The unique identifier of the knowledge base to check. | |||
| user_id (str): The unique identifier of the user attempting the deletion. | |||
| Returns: | |||
| bool: True if the user has permission to delete the knowledge base, | |||
| False if the user doesn't have permission or the knowledge base doesn't exist. | |||
| Example: | |||
| >>> KnowledgebaseService.accessible4deletion("kb123", "user456") | |||
| True | |||
| Note: | |||
| - This method only checks creator permissions | |||
| - A return value of False can mean either: | |||
| @@ -76,25 +80,25 @@ class KnowledgebaseService(CommonService): | |||
| @DB.connection_context() | |||
| def is_parsed_done(cls, kb_id): | |||
| # Check if all documents in the knowledge base have completed parsing | |||
| # | |||
| # | |||
| # Args: | |||
| # kb_id: Knowledge base ID | |||
| # | |||
| # | |||
| # Returns: | |||
| # If all documents are parsed successfully, returns (True, None) | |||
| # If any document is not fully parsed, returns (False, error_message) | |||
| from api.db import TaskStatus | |||
| from api.db.services.document_service import DocumentService | |||
| # Get knowledge base information | |||
| kbs = cls.query(id=kb_id) | |||
| if not kbs: | |||
| return False, "Knowledge base not found" | |||
| kb = kbs[0] | |||
| # Get all documents in the knowledge base | |||
| docs, _ = DocumentService.get_by_kb_id(kb_id, 1, 1000, "create_time", True, "") | |||
| # Check parsing status of each document | |||
| for doc in docs: | |||
| # If document is being parsed, don't allow chat creation | |||
| @@ -103,21 +107,21 @@ class KnowledgebaseService(CommonService): | |||
| # If document is not yet parsed and has no chunks, don't allow chat creation | |||
| if doc['run'] == TaskStatus.UNSTART.value and doc['chunk_num'] == 0: | |||
| return False, f"Document '{doc['name']}' in dataset '{kb.name}' has not been parsed yet. Please parse all documents before starting a chat." | |||
| return True, None | |||
@classmethod
@DB.connection_context()
def list_documents_by_ids(cls, kb_ids):
    """Return the IDs of the documents that belong to the given knowledge bases.

    Args:
        kb_ids: Knowledge base IDs to look up.

    Returns:
        list: Document IDs, one per joined document row.
    """
    # Join each knowledge base to its documents and select only the document
    # ID, aliased so it can be read back by key from the dict rows.
    query = (
        cls.model.select(Document.id.alias("document_id"))
        .join(Document, on=(cls.model.id == Document.kb_id))
        .where(cls.model.id.in_(kb_ids))
    )
    return [row["document_id"] for row in query.dicts()]
| @@ -222,7 +226,7 @@ class KnowledgebaseService(CommonService): | |||
| cls.model.parser_config, | |||
| cls.model.pagerank] | |||
| kbs = cls.model.select(*fields).join(Tenant, on=( | |||
| (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where( | |||
| (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where( | |||
| (cls.model.id == kb_id), | |||
| (cls.model.status == StatusEnum.VALID.value) | |||
| ) | |||
| @@ -324,7 +328,7 @@ class KnowledgebaseService(CommonService): | |||
| kbs = kbs.where( | |||
| ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == | |||
| TenantPermission.TEAM.value)) | ( | |||
| cls.model.tenant_id == user_id)) | |||
| cls.model.tenant_id == user_id)) | |||
| & (cls.model.status == StatusEnum.VALID.value) | |||
| ) | |||
| if desc: | |||
| @@ -347,7 +351,7 @@ class KnowledgebaseService(CommonService): | |||
| # Boolean indicating accessibility | |||
| docs = cls.model.select( | |||
| cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) | |||
| ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) | |||
| ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) | |||
| docs = docs.dicts() | |||
| if not docs: | |||
| return False | |||
| @@ -363,7 +367,7 @@ class KnowledgebaseService(CommonService): | |||
| # Returns: | |||
| # List containing knowledge base information | |||
| kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) | |||
| ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) | |||
| ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) | |||
| kbs = kbs.dicts() | |||
| return list(kbs) | |||
| @@ -377,7 +381,16 @@ class KnowledgebaseService(CommonService): | |||
| # Returns: | |||
| # List containing knowledge base information | |||
| kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) | |||
| ).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1) | |||
| ).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1) | |||
| kbs = kbs.dicts() | |||
| return list(kbs) | |||
@classmethod
@DB.connection_context()
def atomic_increase_doc_num_by_id(cls, kb_id):
    """Add one to a knowledge base's document count in a single UPDATE.

    The new count is written as the field expression ``doc_num + 1`` rather
    than a value read earlier in Python. The row's ``update_time`` and
    ``update_date`` are refreshed in the same statement.

    Args:
        kb_id: ID of the knowledge base whose counter is incremented.

    Returns:
        The result of executing the UPDATE query (in peewee, the number of
        rows modified); falsy when no row matched ``kb_id``.
    """
    changes = {
        "update_time": current_timestamp(),
        "update_date": datetime_format(datetime.now()),
        "doc_num": cls.model.doc_num + 1,
    }
    return cls.model.update(changes).where(cls.model.id == kb_id).execute()