import copy
import datetime
import json
import logging
import secrets
import time
import uuid
from collections import Counter
from typing import Any, Optional
from flask_login import current_user
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from libs.datetime_utils import naive_utc_now
from models.account import Account, TenantAccountRole
from models.dataset import (
    AppDatasetJoin,
    ChildChunk,
    Dataset,
    DatasetAutoDisableLog,
    DatasetCollectionBinding,
    DatasetPermission,
    DatasetPermissionEnum,
    DatasetProcessRule,
    DatasetQuery,
    Document,
    DocumentSegment,
    ExternalKnowledgeBindings,
)
from models.model import UploadFile
from models.source import DataSourceOauthBinding
from services.entities.knowledge_entities.knowledge_entities import (
    ChildChunkUpdateArgs,
    KnowledgeConfig,
    RerankingModel,
    RetrievalModel,
    SegmentUpdateArgs,
)
from services.errors.account import InvalidActionError, NoPermissionError
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
from services.tag_service import TagService
from services.vector_service import VectorService
from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.batch_clean_document_task import batch_clean_document_task
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.remove_document_from_index_task import remove_document_from_index_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
class DatasetService:
    @staticmethod
    def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
        """
        Return one page of the tenant's datasets that are visible to the given user.
        Args:
            page: Page number handed to the paginator
            per_page: Page size handed to the paginator (max_per_page is 100)
            tenant_id: Tenant whose datasets are listed
            user: Account used for permission filtering; without a user only ALL_TEAM datasets are listed
            search: Optional text matched case-insensitively as a substring of the dataset name
            tag_ids: Optional tag ids, resolved to dataset ids through TagService for the "knowledge" type
            include_all: When true and the user is the workspace owner, no permission filter is applied
        Returns:
            tuple: (datasets on the requested page, total number of matching datasets)
        """
        stmt = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
        if not user:
            # without a user, only datasets shared with all team members are listed
            stmt = stmt.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
        else:
            # datasets this account was explicitly granted access to
            grants = db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
            permitted_ids = {grant.dataset_id for grant in grants} if grants else None
            role = user.current_role
            if role == TenantAccountRole.DATASET_OPERATOR:
                # dataset operators only see datasets they were explicitly granted
                if not permitted_ids:
                    return [], 0
                stmt = stmt.filter(Dataset.id.in_(permitted_ids))
            elif role != TenantAccountRole.OWNER or not include_all:
                # team-wide datasets and the user's own private ones are always visible
                visible = [
                    Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
                    db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id),
                ]
                if permitted_ids:
                    # partially shared datasets are visible only when explicitly granted
                    visible.append(
                        db.and_(
                            Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM,
                            Dataset.id.in_(permitted_ids),
                        )
                    )
                stmt = stmt.filter(db.or_(*visible))
        if search:
            stmt = stmt.filter(Dataset.name.ilike(f"%{search}%"))
        if tag_ids:
            target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids)
            if not target_ids:
                return [], 0
            stmt = stmt.filter(Dataset.id.in_(target_ids))
        page_result = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
        return page_result.items, page_result.total
    @staticmethod
    def get_process_rules(dataset_id):
        """
        Return the most recently created process rule of a dataset.
        Falls back to DocumentService.DEFAULT_RULES when the dataset has no process rule.
        Args:
            dataset_id: Identifier of the dataset whose rule is looked up
        Returns:
            dict: {"mode": ..., "rules": ...}
        """
        # newest rule first, at most one row
        latest_rule = (
            db.session.query(DatasetProcessRule)
            .filter(DatasetProcessRule.dataset_id == dataset_id)
            .order_by(DatasetProcessRule.created_at.desc())
            .limit(1)
            .one_or_none()
        )
        if not latest_rule:
            defaults = DocumentService.DEFAULT_RULES
            return {"mode": defaults["mode"], "rules": defaults["rules"]}
        return {"mode": latest_rule.mode, "rules": latest_rule.rules_dict}
    @staticmethod
    def get_datasets_by_ids(ids, tenant_id):
        stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
        datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
        return datasets.items, datasets.total
    @staticmethod
    def create_empty_dataset(
        tenant_id: str,
        name: str,
        description: Optional[str],
        indexing_technique: Optional[str],
        account: Account,
        permission: Optional[str] = None,
        provider: str = "vendor",
        external_knowledge_api_id: Optional[str] = None,
        external_knowledge_id: Optional[str] = None,
        embedding_model_provider: Optional[str] = None,
        embedding_model_name: Optional[str] = None,
        retrieval_model: Optional[RetrievalModel] = None,
    ):
        """
        Create a dataset without any documents and persist it.
        Args:
            tenant_id: Tenant that owns the new dataset
            name: Dataset name; must be unique within the tenant
            description: Optional description
            indexing_technique: Indexing technique; "high_quality" resolves an embedding model
            account: Account recorded as creator and last updater
            permission: Dataset permission; defaults to ONLY_ME when not given
            provider: Dataset provider; "external" additionally creates an external knowledge binding
            external_knowledge_api_id: External knowledge API id (used when provider is "external")
            external_knowledge_id: External knowledge id stored on the binding
            embedding_model_provider: Embedding model provider; the tenant default model is used when
                provider or model name is missing
            embedding_model_name: Embedding model name
            retrieval_model: Optional retrieval configuration stored on the dataset
        Returns:
            Dataset: The newly created dataset
        Raises:
            DatasetNameDuplicateError: If the tenant already has a dataset with this name
            ValueError: If a configured model is not usable or the external API template is not found
        """
        # check if dataset name already exists
        if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
            raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
        embedding_model = None
        if indexing_technique == "high_quality":
            model_manager = ModelManager()
            if embedding_model_provider and embedding_model_name:
                # check if embedding model setting is valid
                DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model_name)
                embedding_model = model_manager.get_model_instance(
                    tenant_id=tenant_id,
                    provider=embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=embedding_model_name,
                )
            else:
                embedding_model = model_manager.get_default_model_instance(
                    tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
                )
            if retrieval_model and retrieval_model.reranking_model:
                if (
                    retrieval_model.reranking_model.reranking_provider_name
                    and retrieval_model.reranking_model.reranking_model_name
                ):
                    # check if reranking model setting is valid; this must go through the
                    # rerank check (ModelType.RERANK), not the text-embedding one
                    DatasetService.check_reranking_model_setting(
                        tenant_id,
                        retrieval_model.reranking_model.reranking_provider_name,
                        retrieval_model.reranking_model.reranking_model_name,
                    )
        dataset = Dataset(name=name, indexing_technique=indexing_technique)
        dataset.description = description
        dataset.created_by = account.id
        dataset.updated_by = account.id
        dataset.tenant_id = tenant_id
        dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
        dataset.embedding_model = embedding_model.model if embedding_model else None
        dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
        dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
        dataset.provider = provider
        db.session.add(dataset)
        # flush so dataset.id is available for the external knowledge binding below
        db.session.flush()
        if provider == "external" and external_knowledge_api_id:
            external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id)
            if not external_knowledge_api:
                raise ValueError("External API template not found.")
            external_knowledge_binding = ExternalKnowledgeBindings(
                tenant_id=tenant_id,
                dataset_id=dataset.id,
                external_knowledge_api_id=external_knowledge_api_id,
                external_knowledge_id=external_knowledge_id,
                created_by=account.id,
            )
            db.session.add(external_knowledge_binding)
        db.session.commit()
        return dataset
    @staticmethod
    def get_dataset(dataset_id) -> Optional[Dataset]:
        """
        Look up a dataset by its id.
        Args:
            dataset_id: Identifier of the dataset
        Returns:
            Optional[Dataset]: The first matching dataset, or None when there is none
        """
        return db.session.query(Dataset).filter_by(id=dataset_id).first()
    @staticmethod
    def check_dataset_model_setting(dataset):
        if dataset.indexing_technique == "high_quality":
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            except LLMBadRequestError:
                raise ValueError(
                    "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                )
            except ProviderTokenNotInitError as ex:
                raise ValueError(f"The dataset in unavailable, due to: {ex.description}")
    @staticmethod
    def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str):
        """
        Verify that a text-embedding model can be instantiated for a tenant.
        Args:
            tenant_id: Tenant the model is resolved for
            embedding_model_provider: Provider of the embedding model
            embedding_model: Name of the embedding model
        Raises:
            ValueError: If the model manager reports a bad request or an uninitialised provider token
        """
        unavailable_message = (
            "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
        )
        try:
            # the instance itself is not needed; the lookup is the validation
            ModelManager().get_model_instance(
                tenant_id=tenant_id,
                provider=embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=embedding_model,
            )
        except LLMBadRequestError:
            raise ValueError(unavailable_message)
        except ProviderTokenNotInitError as token_error:
            raise ValueError(token_error.description)
    @staticmethod
    def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
        """
        Verify that a rerank model can be instantiated for a tenant.
        Args:
            tenant_id: Tenant the model is resolved for
            reranking_model_provider: Provider of the rerank model
            reranking_model: Name of the rerank model
        Raises:
            ValueError: If the model manager reports a bad request or an uninitialised provider token
        """
        unavailable_message = (
            "No Rerank Model available. Please configure a valid provider in the Settings -> Model Provider."
        )
        try:
            # the instance itself is not needed; the lookup is the validation
            ModelManager().get_model_instance(
                tenant_id=tenant_id,
                provider=reranking_model_provider,
                model_type=ModelType.RERANK,
                model=reranking_model,
            )
        except LLMBadRequestError:
            raise ValueError(unavailable_message)
        except ProviderTokenNotInitError as token_error:
            raise ValueError(token_error.description)
    @staticmethod
    def update_dataset(dataset_id, data, user):
        """
        Update the configuration and settings of a dataset.
        Args:
            dataset_id: The unique identifier of the dataset to update
            data: Dictionary containing the update data
            user: The user performing the update operation
        Returns:
            Dataset: The updated dataset object
        Raises:
            ValueError: If dataset not found or validation fails
            NoPermissionError: If user lacks permission to update the dataset
        """
        target = DatasetService.get_dataset(dataset_id)
        if not target:
            raise ValueError("Dataset not found")
        # The permission check comes before any change is attempted
        DatasetService.check_dataset_permission(target, user)
        # External and internal datasets are updated by different routines
        apply_update = (
            DatasetService._update_external_dataset
            if target.provider == "external"
            else DatasetService._update_internal_dataset
        )
        return apply_update(target, data, user)
    @staticmethod
    def _update_external_dataset(dataset, data, user):
        """
        Update external dataset configuration.
        Args:
            dataset: The dataset object to update
            data: Update data dictionary
            user: User performing the update
        Returns:
            Dataset: Updated dataset object
        """
        # Update retrieval model if provided
        external_retrieval_model = data.get("external_retrieval_model", None)
        if external_retrieval_model:
            dataset.retrieval_model = external_retrieval_model
        # Update basic dataset properties
        dataset.name = data.get("name", dataset.name)
        dataset.description = data.get("description", dataset.description)
        # Update permission if provided
        permission = data.get("permission")
        if permission:
            dataset.permission = permission
        # Validate and update external knowledge configuration
        external_knowledge_id = data.get("external_knowledge_id", None)
        external_knowledge_api_id = data.get("external_knowledge_api_id", None)
        if not external_knowledge_id:
            raise ValueError("External knowledge id is required.")
        if not external_knowledge_api_id:
            raise ValueError("External knowledge api id is required.")
        # Update metadata fields
        dataset.updated_by = user.id if user else None
        dataset.updated_at = datetime.datetime.utcnow()
        db.session.add(dataset)
        # Update external knowledge binding
        DatasetService._update_external_knowledge_binding(dataset.id, external_knowledge_id, external_knowledge_api_id)
        # Commit changes to database
        db.session.commit()
        return dataset
    @staticmethod
    def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id):
        """
        Update external knowledge binding configuration.
        The change is only staged on db.session; committing is left to the caller.
        Args:
            dataset_id: Dataset identifier
            external_knowledge_id: External knowledge identifier
            external_knowledge_api_id: External knowledge API identifier
        Raises:
            ValueError: If the dataset has no external knowledge binding
        """
        # Load the binding through db.session itself. Fetching it in a separate short-lived
        # Session and re-adding the detached object to db.session can fail when db.session
        # already tracks an instance of the same row.
        binding = db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
        if not binding:
            raise ValueError("External knowledge binding not found.")
        # Update binding if values have changed
        if (
            binding.external_knowledge_id != external_knowledge_id
            or binding.external_knowledge_api_id != external_knowledge_api_id
        ):
            binding.external_knowledge_id = external_knowledge_id
            binding.external_knowledge_api_id = external_knowledge_api_id
            db.session.add(binding)
    @staticmethod
    def _update_internal_dataset(dataset, data, user):
        """
        Update internal dataset configuration.
        Note that ``data`` is modified in place: the external-only keys are removed from it.
        Args:
            dataset: The dataset object to update
            data: Update data dictionary
            user: User performing the update
        Returns:
            Dataset: Updated dataset object
        """
        # Remove external-specific fields from update data
        data.pop("partial_member_list", None)
        data.pop("external_knowledge_api_id", None)
        data.pop("external_knowledge_id", None)
        data.pop("external_retrieval_model", None)
        # Filter out None values except for description field
        filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"}
        # Handle indexing technique changes and embedding model updates
        action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data)
        # Add metadata fields
        filtered_data["updated_by"] = user.id
        filtered_data["updated_at"] = naive_utc_now()
        # Update retrieval model only when the request carries one; a partial update
        # without the key must not fail with KeyError
        if "retrieval_model" in data:
            filtered_data["retrieval_model"] = data["retrieval_model"]
        # Update dataset in database
        db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
        db.session.commit()
        # Trigger vector index task if indexing technique changed
        if action:
            deal_dataset_vector_index_task.delay(dataset.id, action)
        return dataset
    @staticmethod
    def _handle_indexing_technique_change(dataset, data, filtered_data):
        """
        Handle changes in indexing technique and configure embedding models accordingly.
        Args:
            dataset: Current dataset object
            data: Update data dictionary
            filtered_data: Filtered update data
        Returns:
            str: Action to perform ('add', 'remove', 'update', or None)
        """
        if dataset.indexing_technique != data["indexing_technique"]:
            if data["indexing_technique"] == "economy":
                # Remove embedding model configuration for economy mode
                filtered_data["embedding_model"] = None
                filtered_data["embedding_model_provider"] = None
                filtered_data["collection_binding_id"] = None
                return "remove"
            elif data["indexing_technique"] == "high_quality":
                # Configure embedding model for high quality mode
                DatasetService._configure_embedding_model_for_high_quality(data, filtered_data)
                return "add"
        else:
            # Handle embedding model updates when indexing technique remains the same
            return DatasetService._handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data)
        return None
    @staticmethod
    def _configure_embedding_model_for_high_quality(data, filtered_data):
        """
        Resolve the requested embedding model and record it for high quality indexing.
        The model is resolved for the current user's tenant. On success ``filtered_data`` receives
        "embedding_model", "embedding_model_provider" and "collection_binding_id".
        Args:
            data: Update data dictionary; read for "embedding_model_provider" and "embedding_model"
            filtered_data: Filtered update data to modify
        Raises:
            ValueError: If the model manager reports a bad request or an uninitialised provider token
        """
        try:
            resolved_model = ModelManager().get_model_instance(
                tenant_id=current_user.current_tenant_id,
                provider=data["embedding_model_provider"],
                model_type=ModelType.TEXT_EMBEDDING,
                model=data["embedding_model"],
            )
            # store the names reported by the resolved instance, not the raw request values
            filtered_data["embedding_model"] = resolved_model.model
            filtered_data["embedding_model_provider"] = resolved_model.provider
            binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                resolved_model.provider, resolved_model.model
            )
            filtered_data["collection_binding_id"] = binding.id
        except LLMBadRequestError:
            raise ValueError(
                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
            )
        except ProviderTokenNotInitError as token_error:
            raise ValueError(token_error.description)
    @staticmethod
    def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data):
        """
        Decide how to treat embedding model fields when the indexing technique is not changing.
        Args:
            dataset: Current dataset object
            data: Update data dictionary
            filtered_data: Filtered update data to modify
        Returns:
            str: Action to perform ('update' or None)
        """
        requested_provider = data.get("embedding_model_provider")
        requested_model = data.get("embedding_model")
        # Both values must be present and non-empty for the request to count as a model change.
        if requested_provider and requested_model:
            return DatasetService._update_embedding_model_settings(dataset, data, filtered_data)
        # Otherwise keep whatever the dataset already has and report no action.
        DatasetService._preserve_existing_embedding_settings(dataset, filtered_data)
        return None
    @staticmethod
    def _preserve_existing_embedding_settings(dataset, filtered_data):
        """
        Carry the dataset's current embedding configuration over into the update payload.
        Args:
            dataset: Current dataset object
            filtered_data: Filtered update data to modify in place
        """
        if dataset.embedding_model_provider and dataset.embedding_model:
            # The dataset is already configured: re-assert its provider and model.
            filtered_data.update(
                embedding_model_provider=dataset.embedding_model_provider,
                embedding_model=dataset.embedding_model,
            )
            existing_binding_id = dataset.collection_binding_id
            if existing_binding_id:
                filtered_data["collection_binding_id"] = existing_binding_id
        # Drop empty/None embedding fields so they are not written back as blanks.
        for field_name in ("embedding_model_provider", "embedding_model"):
            if field_name in filtered_data and not filtered_data[field_name]:
                del filtered_data[field_name]
    @staticmethod
    def _update_embedding_model_settings(dataset, data, filtered_data):
        """
        Apply the requested embedding model when it differs from the dataset's current one.
        Args:
            dataset: Current dataset object
            data: Update data dictionary
            filtered_data: Filtered update data to modify
        Returns:
            str: Action to perform ('update' or None)
        Raises:
            ValueError: If LLMBadRequestError or ProviderTokenNotInitError escapes the comparison
                or the application of the new settings
        """

        def _canonical_provider(provider):
            # Normalise through ModelProviderID so both sides are compared in the same string form.
            return str(ModelProviderID(provider)) if provider else None

        try:
            existing_provider = _canonical_provider(dataset.embedding_model_provider)
            requested_provider = _canonical_provider(data["embedding_model_provider"])
            unchanged = existing_provider == requested_provider and data["embedding_model"] == dataset.embedding_model
            if not unchanged:
                DatasetService._apply_new_embedding_settings(dataset, data, filtered_data)
                return "update"
        except LLMBadRequestError:
            raise ValueError(
                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
            )
        except ProviderTokenNotInitError as ex:
            raise ValueError(ex.description)
        return None
    @staticmethod
    def _apply_new_embedding_settings(dataset, data, filtered_data):
        """
        Write the requested embedding model into the update payload, falling back to the
        dataset's current settings when the model cannot be initialised.
        Args:
            dataset: Current dataset object
            data: Update data dictionary
            filtered_data: Filtered update data to modify
        """

        def _keep_current_settings():
            # Re-assert the dataset's existing provider/model (and binding, when present).
            if dataset.embedding_model_provider and dataset.embedding_model:
                filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
                filtered_data["embedding_model"] = dataset.embedding_model
                if dataset.collection_binding_id:
                    filtered_data["collection_binding_id"] = dataset.collection_binding_id

        manager = ModelManager()
        try:
            resolved_model = manager.get_model_instance(
                tenant_id=current_user.current_tenant_id,
                provider=data["embedding_model_provider"],
                model_type=ModelType.TEXT_EMBEDDING,
                model=data["embedding_model"],
            )
        except ProviderTokenNotInitError:
            # The provider is not usable: log it and leave the embedding configuration as it was.
            logging.warning(
                f"Failed to initialize embedding model {data['embedding_model_provider']}/{data['embedding_model']}, "
                f"preserving existing settings"
            )
            _keep_current_settings()
            return

        model_name, provider_name = resolved_model.model, resolved_model.provider
        filtered_data["embedding_model"] = model_name
        filtered_data["embedding_model_provider"] = provider_name
        binding = DatasetCollectionBindingService.get_dataset_collection_binding(provider_name, model_name)
        filtered_data["collection_binding_id"] = binding.id
    @staticmethod
    def delete_dataset(dataset_id, user):
        """
        Delete a dataset after verifying the user may access it.
        Args:
            dataset_id: Identifier of the dataset to delete
            user: Account performing the deletion
        Returns:
            bool: False when no dataset with that id exists, True once it has been deleted
        """
        target = DatasetService.get_dataset(dataset_id)
        if target is None:
            return False
        # Raises when the user lacks access; nothing is deleted in that case.
        DatasetService.check_dataset_permission(target, user)
        # Signal listeners before the row itself is removed.
        dataset_was_deleted.send(target)
        session = db.session
        session.delete(target)
        session.commit()
        return True
    @staticmethod
    def dataset_use_check(dataset_id) -> bool:
        """
        Report whether any AppDatasetJoin row references the dataset.
        Args:
            dataset_id: Identifier of the dataset to look up
        Returns:
            bool: True when at least one join row exists
        """
        join_count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
        return join_count > 0
    @staticmethod
    def check_dataset_permission(dataset, user):
        """
        Verify that the user may access the dataset.
        Args:
            dataset: Dataset being accessed
            user: Account requesting access
        Raises:
            NoPermissionError: If the dataset belongs to another tenant, or the user is not the
                tenant owner and fails the dataset's ONLY_ME / PARTIAL_TEAM restriction
        """

        def _denied():
            # Log the refusal and build the error raised by every rejection path.
            logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
            return NoPermissionError("You do not have permission to access this dataset.")

        if dataset.tenant_id != user.current_tenant_id:
            raise _denied()
        # Tenant owners bypass the per-dataset permission modes.
        if user.current_role == TenantAccountRole.OWNER:
            return
        if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
            raise _denied()
        if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
            # For partial team permission, user needs explicit permission or be the creator
            if dataset.created_by == user.id:
                return
            explicit_grant = (
                db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
            )
            if not explicit_grant:
                raise _denied()
    @staticmethod
    def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None):
        """
        Verify that a dataset operator may access the dataset.

        Tenant owners always pass. For other roles, an ONLY_ME dataset is restricted to its
        creator and a PARTIAL_TEAM dataset requires an explicit DatasetPermission row for the
        user. Any other permission value passes.
        Args:
            user: Account requesting access
            dataset: Dataset being accessed
        Raises:
            ValueError: If dataset or user is missing
            NoPermissionError: If the user fails the dataset's permission mode
        """
        if not dataset:
            raise ValueError("Dataset not found")
        if not user:
            raise ValueError("User not found")
        if user.current_role == TenantAccountRole.OWNER:
            return
        if dataset.permission == DatasetPermissionEnum.ONLY_ME:
            if dataset.created_by != user.id:
                raise NoPermissionError("You do not have permission to access this dataset.")
        elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
            # Look up the single (dataset, account) grant instead of loading every permission
            # row of the user and scanning them in Python.
            # NOTE(review): unlike check_dataset_permission, the creator is not exempted here;
            # kept as-is, confirm whether that difference is intended.
            explicit_grant = (
                db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
            )
            if not explicit_grant:
                raise NoPermissionError("You do not have permission to access this dataset.")
    @staticmethod
    def get_dataset_queries(dataset_id: str, page: int, per_page: int):
        """
        Return one page of the dataset's DatasetQuery rows, newest first.
        Args:
            dataset_id: Identifier of the dataset
            page: Page number to fetch
            per_page: Requested page size (the paginate call caps it with max_per_page=100)
        Returns:
            tuple: (items on the page, total number of matching rows)
        """
        newest_first = db.desc(DatasetQuery.created_at)
        query_stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(newest_first)
        page_result = db.paginate(select=query_stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
        return page_result.items, page_result.total
    @staticmethod
    def get_related_apps(dataset_id: str):
        """
        Return the AppDatasetJoin rows that reference the dataset, newest first.
        Args:
            dataset_id: Identifier of the dataset
        Returns:
            list: Matching AppDatasetJoin rows ordered by created_at descending
        """
        joins_for_dataset = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id)
        return joins_for_dataset.order_by(db.desc(AppDatasetJoin.created_at)).all()
    @staticmethod
    def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
        """
        Summarise the dataset's auto-disable log entries from the last 30 days.
        Args:
            dataset_id: Identifier of the dataset
        Returns:
            dict: {"document_ids": [...], "count": int}; empty when billing is disabled or the
                tenant is on the "sandbox" plan
        """
        billing = FeatureService.get_features(current_user.current_tenant_id).billing
        if not billing.enabled or billing.subscription.plan == "sandbox":
            return {
                "document_ids": [],
                "count": 0,
            }
        # get recent 30 days auto disable logs
        # NOTE(review): the cutoff is naive local time; confirm created_at is stored in the same zone.
        cutoff = datetime.datetime.now() - datetime.timedelta(days=30)
        recent_logs = (
            db.session.query(DatasetAutoDisableLog)
            .filter(
                DatasetAutoDisableLog.dataset_id == dataset_id,
                DatasetAutoDisableLog.created_at >= cutoff,
            )
            .all()
        )
        # An empty result naturally yields an empty id list and a count of 0.
        return {
            "document_ids": [entry.document_id for entry in recent_logs],
            "count": len(recent_logs),
        }
class DocumentService:
    # Default document processing rule set: "custom" mode with two pre-processing
    # toggles (extra-space removal enabled, URL/e-mail removal disabled) and
    # segmentation on "\n" with max_tokens=1024 and chunk_overlap=50.
    # NOTE(review): presumably served to clients as the initial rule form — confirm against callers.
    DEFAULT_RULES: dict[str, Any] = {
        "mode": "custom",
        "rules": {
            "pre_processing_rules": [
                {"id": "remove_extra_spaces", "enabled": True},
                {"id": "remove_urls_emails", "enabled": False},
            ],
            "segmentation": {"delimiter": "\n", "max_tokens": 1024, "chunk_overlap": 50},
        },
        "limits": {
            # Ceiling for the segmentation token length, read from deployment configuration.
            "indexing_max_segmentation_tokens_length": dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH,
        },
    }
    # Metadata schema per document type: each key is a document type name and each
    # value maps the metadata field names of that type to a Python type (all `str`
    # here). "others" maps directly to `dict` instead of a fixed field list.
    # NOTE(review): presumably used to validate/coerce user-supplied document metadata — confirm against callers.
    DOCUMENT_METADATA_SCHEMA: dict[str, Any] = {
        "book": {
            "title": str,
            "language": str,
            "author": str,
            "publisher": str,
            "publication_date": str,
            "isbn": str,
            "category": str,
        },
        "web_page": {
            "title": str,
            "url": str,
            "language": str,
            "publish_date": str,
            "author/publisher": str,
            "topic/keywords": str,
            "description": str,
        },
        "paper": {
            "title": str,
            "language": str,
            "author": str,
            "publish_date": str,
            "journal/conference_name": str,
            "volume/issue/page_numbers": str,
            "doi": str,
            "topic/keywords": str,
            "abstract": str,
        },
        "social_media_post": {
            "platform": str,
            "author/username": str,
            "publish_date": str,
            "post_url": str,
            "topic/tags": str,
        },
        "wikipedia_entry": {
            "title": str,
            "language": str,
            "web_page_url": str,
            "last_edit_date": str,
            "editor/contributor": str,
            "summary/introduction": str,
        },
        "personal_document": {
            "title": str,
            "author": str,
            "creation_date": str,
            "last_modified_date": str,
            "document_type": str,
            "tags/category": str,
        },
        "business_document": {
            "title": str,
            "author": str,
            "creation_date": str,
            "last_modified_date": str,
            "document_type": str,
            "department/team": str,
        },
        "im_chat_log": {
            "chat_platform": str,
            "chat_participants/group_name": str,
            "start_date": str,
            "end_date": str,
            "summary": str,
        },
        "synced_from_notion": {
            "title": str,
            "language": str,
            "author/creator": str,
            "creation_date": str,
            "last_modified_date": str,
            "notion_page_link": str,
            "category/tags": str,
            "description": str,
        },
        "synced_from_github": {
            "repository_name": str,
            "repository_description": str,
            "repository_owner/organization": str,
            "code_filename": str,
            "code_file_path": str,
            "programming_language": str,
            "github_link": str,
            "open_source_license": str,
            "commit_date": str,
            "commit_author": str,
        },
        # Free-form bucket: no fixed field list is declared for this type.
        "others": dict,
    }
    @staticmethod
    def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
        """
        Fetch a document by id within a specific dataset.

        Returns None without querying when document_id is missing or empty, and None when
        no row matches both the document id and the dataset id.
        """
        if not document_id:
            return None
        return (
            db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
        )
    @staticmethod
    def get_document_by_id(document_id: str) -> Optional[Document]:
        """Fetch a document by id alone (no dataset or tenant filter); None when not found."""
        lookup = db.session.query(Document).filter(Document.id == document_id)
        return lookup.first()
    @staticmethod
    def get_document_by_ids(document_ids: list[str]) -> list[Document]:
        """Return the documents among the given ids that are enabled, fully indexed and not archived."""
        usable_criteria = (
            Document.id.in_(document_ids),
            Document.enabled == True,
            Document.indexing_status == "completed",
            Document.archived == False,
        )
        return db.session.query(Document).filter(*usable_criteria).all()
    @staticmethod
    def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
        """Return every enabled document of the dataset, whatever its indexing or archive state."""
        enabled_in_dataset = (
            Document.dataset_id == dataset_id,
            Document.enabled == True,
        )
        return db.session.query(Document).filter(*enabled_in_dataset).all()
    @staticmethod
    def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
        """Return the dataset's documents that are enabled, fully indexed and not archived."""
        working_criteria = (
            Document.dataset_id == dataset_id,
            Document.enabled == True,
            Document.indexing_status == "completed",
            Document.archived == False,
        )
        return db.session.query(Document).filter(*working_criteria).all()
    @staticmethod
    def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
        """Return the dataset's documents whose indexing status is "error" or "paused"."""
        stalled_statuses = ["error", "paused"]
        return (
            db.session.query(Document)
            .filter(Document.dataset_id == dataset_id, Document.indexing_status.in_(stalled_statuses))
            .all()
        )
    @staticmethod
    def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
        """Return the documents of one batch in the dataset, limited to the current user's tenant."""
        batch_criteria = (
            Document.batch == batch,
            Document.dataset_id == dataset_id,
            Document.tenant_id == current_user.current_tenant_id,
        )
        return db.session.query(Document).filter(*batch_criteria).all()
    @staticmethod
    def get_document_file_detail(file_id: str):
        """Return the UploadFile row with the given id, or None when there is none."""
        return db.session.query(UploadFile).filter(UploadFile.id == file_id).one_or_none()
    @staticmethod
    def check_archived(document):
        """Return True when the document's archived flag is truthy, otherwise False."""
        return bool(document.archived)
    @staticmethod
    def delete_document(document):
        """
        Delete a single document and notify listeners.

        For "upload_file" documents the upload file id (when recorded in the data source
        info) is passed along with the document_was_deleted signal; otherwise file_id is None.
        """
        upload_file_id = None
        if document.data_source_type == "upload_file" and document.data_source_info:
            source_info = document.data_source_info_dict
            if source_info and "upload_file_id" in source_info:
                upload_file_id = source_info["upload_file_id"]
        # trigger document_was_deleted signal
        document_was_deleted.send(
            document.id, dataset_id=document.dataset_id, doc_form=document.doc_form, file_id=upload_file_id
        )
        session = db.session
        session.delete(document)
        session.commit()
    @staticmethod
    def delete_documents(dataset: Dataset, document_ids: list[str]):
        """
        Delete several documents and schedule the asynchronous cleanup of their data.

        Args:
            dataset: Dataset whose id and doc_form are handed to the cleanup task
            document_ids: Ids of the documents to delete; an empty list is a no-op
        """
        # Nothing to delete: skip the empty IN () query and the no-op cleanup task.
        if not document_ids:
            return
        documents = db.session.query(Document).filter(Document.id.in_(document_ids)).all()
        # Collect upload file ids defensively, as delete_document does: the data source
        # info may be missing or lack "upload_file_id".
        file_ids = []
        for document in documents:
            if document.data_source_type != "upload_file":
                continue
            source_info = document.data_source_info_dict
            if source_info and "upload_file_id" in source_info:
                file_ids.append(source_info["upload_file_id"])
        # NOTE(review): the task is queued before the deletions are committed, as in the
        # original ordering — confirm the task tolerates running before the commit lands.
        batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
        for document in documents:
            db.session.delete(document)
        db.session.commit()
    @staticmethod
    def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
        """
        Rename a document, keeping its built-in name metadata in sync.

        Raises:
            ValueError: If the dataset or document does not exist, or the document belongs
                to a tenant other than the current user's
        """
        owning_dataset = DatasetService.get_dataset(dataset_id)
        if not owning_dataset:
            raise ValueError("Dataset not found.")
        target = DocumentService.get_document(dataset_id, document_id)
        if not target:
            raise ValueError("Document not found.")
        if target.tenant_id != current_user.current_tenant_id:
            raise ValueError("No permission.")
        if owning_dataset.built_in_field_enabled and target.doc_metadata:
            # Assign a fresh copy instead of mutating the stored metadata object in place.
            refreshed_metadata = copy.deepcopy(target.doc_metadata)
            refreshed_metadata[BuiltInField.document_name.value] = name
            target.doc_metadata = refreshed_metadata
        target.name = name
        session = db.session
        session.add(target)
        session.commit()
        return target
    @staticmethod
    def pause_document(document):
        """
        Pause a document that is still moving through the indexing pipeline.

        Raises:
            DocumentIndexingError: If the document's indexing status is not one of the
                in-progress states
        """
        in_progress_states = {"waiting", "parsing", "cleaning", "splitting", "indexing"}
        if document.indexing_status not in in_progress_states:
            raise DocumentIndexingError()
        # update document to be paused
        document.is_paused = True
        document.paused_by = current_user.id
        document.paused_at = naive_utc_now()
        session = db.session
        session.add(document)
        session.commit()
        # set document paused flag
        redis_client.setnx(f"document_{document.id}_is_paused", "True")
    @staticmethod
    def recover_document(document):
        """
        Resume indexing of a paused document.

        Raises:
            DocumentIndexingError: If the document is not currently paused
        """
        if not document.is_paused:
            raise DocumentIndexingError()
        # clear the paused state on the document
        document.is_paused = False
        document.paused_by = document.paused_at = None
        session = db.session
        session.add(document)
        session.commit()
        # delete paused flag
        redis_client.delete(f"document_{document.id}_is_paused")
        # trigger async task
        recover_document_indexing_task.delay(document.dataset_id, document.id)
    @staticmethod
    def retry_document(dataset_id: str, documents: list[Document]):
        """
        Re-queue documents for indexing.

        Every document is first checked for an active retry flag; only when none is set are
        the documents reset to "waiting", flagged for 600 seconds, and handed to the retry
        task. Checking up front avoids leaving earlier documents reset and flagged, with no
        task dispatched, when a later document in the list is rejected.
        Args:
            dataset_id: Identifier of the dataset the documents belong to
            documents: Documents to retry
        Raises:
            ValueError: If any of the documents already has a retry flag set
        """
        keyed_documents = [(document, "document_{}_is_retried".format(document.id)) for document in documents]
        # Validation pass: reject the whole request before any state is changed.
        for _, retry_indexing_cache_key in keyed_documents:
            if redis_client.get(retry_indexing_cache_key) is not None:
                raise ValueError("Document is being retried, please try again later")
        # Mutation pass: reset status, persist, then add the retry flag.
        for document, retry_indexing_cache_key in keyed_documents:
            document.indexing_status = "waiting"
            db.session.add(document)
            db.session.commit()
            redis_client.setex(retry_indexing_cache_key, 600, 1)
        # trigger async task
        document_ids = [document.id for document in documents]
        retry_document_indexing_task.delay(dataset_id, document_ids)
    @staticmethod
    def sync_website_document(dataset_id: str, document: Document):
        """
        Re-queue a website document for synchronisation.

        The document is reset to "waiting", its data source info (when present) is switched
        to "scrape" mode, a sync flag is set for 600 seconds, and the sync task is dispatched.
        Args:
            dataset_id: Identifier of the dataset the document belongs to
            document: Document to synchronise
        Raises:
            ValueError: If a sync flag is already set for the document
        """
        # add sync flag
        sync_indexing_cache_key = "document_{}_is_sync".format(document.id)
        cache_result = redis_client.get(sync_indexing_cache_key)
        if cache_result is not None:
            raise ValueError("Document is being synced, please try again later")
        # sync document indexing
        document.indexing_status = "waiting"
        data_source_info = document.data_source_info_dict
        # data_source_info_dict can be None (delete_document guards for it as well); only
        # rewrite the stored info when there is something to update, instead of failing
        # with a TypeError on item assignment.
        if data_source_info is not None:
            data_source_info["mode"] = "scrape"
            document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
        db.session.add(document)
        db.session.commit()
        redis_client.setex(sync_indexing_cache_key, 600, 1)
        sync_website_document_indexing_task.delay(dataset_id, document.id)
    @staticmethod
    def get_documents_position(dataset_id):
        """Return the next free position in the dataset: highest existing position + 1, or 1 when empty."""
        last_document = (
            db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
        )
        return last_document.position + 1 if last_document else 1
    @staticmethod
    def save_document_with_dataset_id(
        dataset: Dataset,
        knowledge_config: KnowledgeConfig,
        account: Account | Any,
        dataset_process_rule: Optional[DatasetProcessRule] = None,
        created_from: str = "web",
    ):
        # check document limit
        features = FeatureService.get_features(current_user.current_tenant_id)
        if features.billing.enabled:
            if not knowledge_config.original_document_id:
                count = 0
                if knowledge_config.data_source:
                    if knowledge_config.data_source.info_list.data_source_type == "upload_file":
                        upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                        count = len(upload_file_list)
                    elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                        notion_info_list = knowledge_config.data_source.info_list.notion_info_list
                        for notion_info in notion_info_list:  # type: ignore
                            count = count + len(notion_info.pages)
                    elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                        website_info = knowledge_config.data_source.info_list.website_info_list
                        count = len(website_info.urls)  # type: ignore
                    batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
                    if features.billing.subscription.plan == "sandbox" and count > 1:
                        raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
                    if count > batch_upload_limit:
                        raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
                    DocumentService.check_documents_upload_quota(count, features)
        # if dataset is empty, update dataset data_source_type
        if not dataset.data_source_type:
            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore
        if not dataset.indexing_technique:
            if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
                raise ValueError("Indexing technique is invalid")
            dataset.indexing_technique = knowledge_config.indexing_technique
            if knowledge_config.indexing_technique == "high_quality":
                model_manager = ModelManager()
                if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
                    dataset_embedding_model = knowledge_config.embedding_model
                    dataset_embedding_model_provider = knowledge_config.embedding_model_provider
                else:
                    embedding_model = model_manager.get_default_model_instance(
                        tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
                    )
                    dataset_embedding_model = embedding_model.model
                    dataset_embedding_model_provider = embedding_model.provider
                dataset.embedding_model = dataset_embedding_model
                dataset.embedding_model_provider = dataset_embedding_model_provider
                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                    dataset_embedding_model_provider, dataset_embedding_model
                )
                dataset.collection_binding_id = dataset_collection_binding.id
                if not dataset.retrieval_model:
                    default_retrieval_model = {
                        "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
                        "reranking_enable": False,
                        "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
                        "top_k": 2,
                        "score_threshold_enabled": False,
                    }
                    dataset.retrieval_model = (
                        knowledge_config.retrieval_model.model_dump()
                        if knowledge_config.retrieval_model
                        else default_retrieval_model
                    )  # type: ignore
        documents = []
        if knowledge_config.original_document_id:
            document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
            documents.append(document)
            batch = document.batch
        else:
            batch = time.strftime("%Y%m%d%H%M%S") + str(100000 + secrets.randbelow(exclusive_upper_bound=900000))
            # save process rule
            if not dataset_process_rule:
                process_rule = knowledge_config.process_rule
                if process_rule:
                    if process_rule.mode in ("custom", "hierarchical"):
                        if process_rule.rules:
                            dataset_process_rule = DatasetProcessRule(
                                dataset_id=dataset.id,
                                mode=process_rule.mode,
                                rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
                                created_by=account.id,
                            )
                        else:
                            dataset_process_rule = dataset.latest_process_rule
                            if not dataset_process_rule:
                                raise ValueError("No process rule found.")
                    elif process_rule.mode == "automatic":
                        dataset_process_rule = DatasetProcessRule(
                            dataset_id=dataset.id,
                            mode=process_rule.mode,
                            rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                            created_by=account.id,
                        )
                    else:
                        logging.warning(
                            f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
                        )
                        return
                    db.session.add(dataset_process_rule)
                    db.session.commit()
            lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
            with redis_client.lock(lock_name, timeout=600):
                position = DocumentService.get_documents_position(dataset.id)
                document_ids = []
                duplicate_document_ids = []
                if knowledge_config.data_source.info_list.data_source_type == "upload_file":  # type: ignore
                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                    for file_id in upload_file_list:
                        file = (
                            db.session.query(UploadFile)
                            .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
                            .first()
                        )
                        # raise error if file not found
                        if not file:
                            raise FileNotExistsError()
                        file_name = file.name
                        data_source_info = {
                            "upload_file_id": file_id,
                        }
                        # check duplicate
                        if knowledge_config.duplicate:
                            document = (
                                db.session.query(Document)
                                .filter_by(
                                    dataset_id=dataset.id,
                                    tenant_id=current_user.current_tenant_id,
                                    data_source_type="upload_file",
                                    enabled=True,
                                    name=file_name,
                                )
                                .first()
                            )
                            if document:
                                document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                                document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                                document.created_from = created_from
                                document.doc_form = knowledge_config.doc_form
                                document.doc_language = knowledge_config.doc_language
                                document.data_source_info = json.dumps(data_source_info)
                                document.batch = batch
                                document.indexing_status = "waiting"
                                db.session.add(document)
                                documents.append(document)
                                duplicate_document_ids.append(document.id)
                                continue
                        document = DocumentService.build_document(
                            dataset,
                            dataset_process_rule.id,  # type: ignore
                            knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                            knowledge_config.doc_form,
                            knowledge_config.doc_language,
                            data_source_info,
                            created_from,
                            position,
                            account,
                            file_name,
                            batch,
                        )
                        db.session.add(document)
                        db.session.flush()
                        document_ids.append(document.id)
                        documents.append(document)
                        position += 1
                elif knowledge_config.data_source.info_list.data_source_type == "notion_import":  # type: ignore
                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
                    if not notion_info_list:
                        raise ValueError("No notion info list found.")
                    exist_page_ids = []
                    exist_document = {}
                    documents = (
                        db.session.query(Document)
                        .filter_by(
                            dataset_id=dataset.id,
                            tenant_id=current_user.current_tenant_id,
                            data_source_type="notion_import",
                            enabled=True,
                        )
                        .all()
                    )
                    if documents:
                        for document in documents:
                            data_source_info = json.loads(document.data_source_info)
                            exist_page_ids.append(data_source_info["notion_page_id"])
                            exist_document[data_source_info["notion_page_id"]] = document.id
                    for notion_info in notion_info_list:
                        workspace_id = notion_info.workspace_id
                        data_source_binding = (
                            db.session.query(DataSourceOauthBinding)
                            .filter(
                                db.and_(
                                    DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                                    DataSourceOauthBinding.provider == "notion",
                                    DataSourceOauthBinding.disabled == False,
                                    DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
                                )
                            )
                            .first()
                        )
                        if not data_source_binding:
                            raise ValueError("Data source binding not found.")
                        for page in notion_info.pages:
                            if page.page_id not in exist_page_ids:
                                data_source_info = {
                                    "notion_workspace_id": workspace_id,
                                    "notion_page_id": page.page_id,
                                    "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
                                    "type": page.type,
                                }
                                # Truncate page name to 255 characters to prevent DB field length errors
                                truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
                                document = DocumentService.build_document(
                                    dataset,
                                    dataset_process_rule.id,  # type: ignore
                                    knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                                    knowledge_config.doc_form,
                                    knowledge_config.doc_language,
                                    data_source_info,
                                    created_from,
                                    position,
                                    account,
                                    truncated_page_name,
                                    batch,
                                )
                                db.session.add(document)
                                db.session.flush()
                                document_ids.append(document.id)
                                documents.append(document)
                                position += 1
                            else:
                                exist_document.pop(page.page_id)
                    # delete not selected documents
                    if len(exist_document) > 0:
                        clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
                elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":  # type: ignore
                    website_info = knowledge_config.data_source.info_list.website_info_list  # type: ignore
                    if not website_info:
                        raise ValueError("No website info list found.")
                    urls = website_info.urls
                    for url in urls:
                        data_source_info = {
                            "url": url,
                            "provider": website_info.provider,
                            "job_id": website_info.job_id,
                            "only_main_content": website_info.only_main_content,
                            "mode": "crawl",
                        }
                        if len(url) > 255:
                            document_name = url[:200] + "..."
                        else:
                            document_name = url
                        document = DocumentService.build_document(
                            dataset,
                            dataset_process_rule.id,  # type: ignore
                            knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                            knowledge_config.doc_form,
                            knowledge_config.doc_language,
                            data_source_info,
                            created_from,
                            position,
                            account,
                            document_name,
                            batch,
                        )
                        db.session.add(document)
                        db.session.flush()
                        document_ids.append(document.id)
                        documents.append(document)
                        position += 1
                db.session.commit()
                # trigger async task
                if document_ids:
                    document_indexing_task.delay(dataset.id, document_ids)
                if duplicate_document_ids:
                    duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
        return documents, batch
    @staticmethod
    def check_documents_upload_quota(count: int, features: FeatureModel):
        """Ensure `count` more documents fit in the remaining upload quota.

        Args:
            count: Number of documents about to be uploaded.
            features: Feature set whose `documents_upload_quota` exposes `limit` and `size`.

        Raises:
            ValueError: If `count` is greater than `limit - size`.
        """
        quota = features.documents_upload_quota
        remaining = quota.limit - quota.size
        if count <= remaining:
            return
        raise ValueError(
            f"You have reached the limit of your subscription. Only {remaining} documents can be uploaded."
        )
    @staticmethod
    def build_document(
        dataset: Dataset,
        process_rule_id: str,
        data_source_type: str,
        document_form: str,
        document_language: str,
        data_source_info: dict,
        created_from: str,
        position: int,
        account: Account,
        name: str,
        batch: str,
    ):
        """Construct a `Document` for `dataset` without adding it to the session.

        Args:
            dataset: Dataset supplying the tenant id, dataset id and built-in field switch.
            process_rule_id: Id stored as the document's `dataset_process_rule_id`.
            data_source_type: Source type stored on the document (and as built-in `source`).
            document_form: Value stored as `doc_form`.
            document_language: Value stored as `doc_language`.
            data_source_info: Source details; serialized to JSON before being stored.
            created_from: Origin marker stored on the document.
            position: Position stored on the document.
            account: Creator; its id is stored and its name is used as built-in `uploader`.
            name: Document name.
            batch: Batch identifier stored on the document.

        Returns:
            The new `Document`. `doc_metadata` is populated with the built-in fields only when
            `dataset.built_in_field_enabled` is truthy.
        """
        document = Document(
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            position=position,
            data_source_type=data_source_type,
            data_source_info=json.dumps(data_source_info),
            dataset_process_rule_id=process_rule_id,
            batch=batch,
            name=name,
            created_from=created_from,
            created_by=account.id,
            doc_form=document_form,
            doc_language=document_language,
        )
        if dataset.built_in_field_enabled:
            # Read the clock once so upload_date and last_update_date can never differ
            # for a freshly built document (two reads may straddle a second boundary).
            timestamp = datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S")
            document.doc_metadata = {
                BuiltInField.document_name: name,
                BuiltInField.uploader: account.name,
                BuiltInField.upload_date: timestamp,
                BuiltInField.last_update_date: timestamp,
                BuiltInField.source: data_source_type,
            }
        return document
    @staticmethod
    def get_tenant_documents_count():
        """Count the current tenant's documents that are completed, enabled and not archived."""
        criteria = (
            Document.completed_at.isnot(None),
            Document.enabled == True,
            Document.archived == False,
            Document.tenant_id == current_user.current_tenant_id,
        )
        return db.session.query(Document).filter(*criteria).count()
    @staticmethod
    def update_document_with_dataset_id(
        dataset: Dataset,
        document_data: KnowledgeConfig,
        account: Account,
        dataset_process_rule: Optional[DatasetProcessRule] = None,
        created_from: str = "web",
    ):
        """Update an existing document's process rule and/or data source and queue it for re-indexing.

        The document identified by `document_data.original_document_id` is loaded, optionally given a
        new process rule and data source, reset to the "waiting" indexing status, and its segments are
        marked "re_segment" before `document_indexing_update_task` is enqueued.

        Args:
            dataset: Dataset that owns the document.
            document_data: Requested changes (process rule, data source, name, doc form).
            account: Account recorded as creator of any new process rule.
            dataset_process_rule: Rule to attach when `document_data.process_rule` is set but its mode
                is not "custom", "hierarchical" or "automatic"; ignored when no process rule is given.
            created_from: Origin marker written to the document.

        Returns:
            The updated document.

        Raises:
            NotFound: If the document does not exist.
            ValueError: If the document is not available, or required source info / binding is missing.
            FileNotExistsError: If a referenced upload file cannot be found for the dataset's tenant.
        """
        DatasetService.check_dataset_model_setting(dataset)
        document = DocumentService.get_document(dataset.id, document_data.original_document_id)
        if document is None:
            raise NotFound("Document not found")
        if document.display_status != "available":
            raise ValueError("Document is not available")
        # Persist a new process rule (if one was requested) and attach it to the document
        if document_data.process_rule:
            process_rule = document_data.process_rule
            if process_rule.mode in {"custom", "hierarchical"}:
                dataset_process_rule = DatasetProcessRule(
                    dataset_id=dataset.id,
                    mode=process_rule.mode,
                    rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
                    created_by=account.id,
                )
            elif process_rule.mode == "automatic":
                dataset_process_rule = DatasetProcessRule(
                    dataset_id=dataset.id,
                    mode=process_rule.mode,
                    rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                    created_by=account.id,
                )
            if dataset_process_rule is not None:
                # Committed immediately so the rule id is available for the document
                db.session.add(dataset_process_rule)
                db.session.commit()
                document.dataset_process_rule_id = dataset_process_rule.id
        # Replace the document's data source
        if document_data.data_source:
            file_name = ""
            data_source_info = {}
            if document_data.data_source.info_list.data_source_type == "upload_file":
                if not document_data.data_source.info_list.file_info_list:
                    raise ValueError("No file info list found.")
                upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
                # Every file id is verified, but only the last one ends up in
                # file_name / data_source_info because each iteration overwrites them
                for file_id in upload_file_list:
                    file = (
                        db.session.query(UploadFile)
                        .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
                        .first()
                    )
                    # raise error if file not found
                    if not file:
                        raise FileNotExistsError()
                    file_name = file.name
                    data_source_info = {
                        "upload_file_id": file_id,
                    }
            elif document_data.data_source.info_list.data_source_type == "notion_import":
                if not document_data.data_source.info_list.notion_info_list:
                    raise ValueError("No notion info list found.")
                notion_info_list = document_data.data_source.info_list.notion_info_list
                for notion_info in notion_info_list:
                    workspace_id = notion_info.workspace_id
                    # The binding is looked up only to confirm it exists; its value is not used below
                    data_source_binding = (
                        db.session.query(DataSourceOauthBinding)
                        .filter(
                            db.and_(
                                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                                DataSourceOauthBinding.provider == "notion",
                                DataSourceOauthBinding.disabled == False,
                                DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
                            )
                        )
                        .first()
                    )
                    if not data_source_binding:
                        raise ValueError("Data source binding not found.")
                    # As with files, the last page visited wins
                    for page in notion_info.pages:
                        data_source_info = {
                            "notion_workspace_id": workspace_id,
                            "notion_page_id": page.page_id,
                            "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,  # type: ignore
                            "type": page.type,
                        }
            elif document_data.data_source.info_list.data_source_type == "website_crawl":
                website_info = document_data.data_source.info_list.website_info_list
                # Unlike the other source types, missing website info is not an error here;
                # data_source_info simply stays empty
                if website_info:
                    urls = website_info.urls
                    for url in urls:
                        data_source_info = {
                            "url": url,
                            "provider": website_info.provider,
                            "job_id": website_info.job_id,
                            "only_main_content": website_info.only_main_content,  # type: ignore
                            "mode": "crawl",
                        }
            document.data_source_type = document_data.data_source.info_list.data_source_type
            document.data_source_info = json.dumps(data_source_info)
            # NOTE(review): file_name is only assigned in the upload_file branch, so for notion_import
            # and website_crawl the name is set to "" unless document_data.name overrides it below.
            # Confirm whether that is intended.
            document.name = file_name
        # An explicitly supplied name takes precedence over the source-derived one
        if document_data.name:
            document.name = document_data.name
        # Reset indexing progress so the document is picked up as "waiting"
        document.indexing_status = "waiting"
        document.completed_at = None
        document.processing_started_at = None
        document.parsing_completed_at = None
        document.cleaning_completed_at = None
        document.splitting_completed_at = None
        document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        document.created_from = created_from
        document.doc_form = document_data.doc_form
        db.session.add(document)
        db.session.commit()
        # Flag every existing segment of the document as "re_segment"
        update_params = {DocumentSegment.status: "re_segment"}
        db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params)
        db.session.commit()
        # Enqueue the async re-indexing task only after both commits above
        document_indexing_update_task.delay(document.dataset_id, document.id)
        return document
    @staticmethod
    def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
        """Create a new dataset for `knowledge_config` and save its documents into it.

        When billing is enabled, the number of requested documents is first checked against the
        sandbox-plan restriction, the batch upload limit and the remaining upload quota. Features are
        looked up for the current user's tenant, while the dataset is created under `tenant_id`.
        The dataset is named after the first saved document.

        Args:
            tenant_id: Tenant that will own the new dataset.
            knowledge_config: Data source, indexing and retrieval settings for the upload.
            account: Account recorded as the dataset creator and passed on to document saving.

        Returns:
            A `(dataset, documents, batch)` tuple.

        Raises:
            ValueError: If one of the billing limits is exceeded.
        """
        features = FeatureService.get_features(current_user.current_tenant_id)
        if features.billing.enabled:
            info_list = knowledge_config.data_source.info_list  # type: ignore
            source_type = info_list.data_source_type
            # Work out how many documents this request would create
            count = 0
            if source_type == "upload_file":
                file_info = info_list.file_info_list
                count = len(file_info.file_ids) if file_info else 0
            elif source_type == "notion_import":
                count = sum(len(notion_info.pages) for notion_info in info_list.notion_info_list or [])
            elif source_type == "website_crawl":
                website_info = info_list.website_info_list
                if website_info:
                    count = len(website_info.urls)
            if features.billing.subscription.plan == "sandbox" and count > 1:
                raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
            batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
            DocumentService.check_documents_upload_quota(count, features)
        # A collection binding is only resolved for high-quality indexing
        dataset_collection_binding_id = None
        if knowledge_config.indexing_technique == "high_quality":
            dataset_collection_binding_id = DatasetCollectionBindingService.get_dataset_collection_binding(
                knowledge_config.embedding_model_provider,  # type: ignore
                knowledge_config.embedding_model,  # type: ignore
            ).id
        # Fall back to plain semantic search when no retrieval model was supplied
        retrieval_model = knowledge_config.retrieval_model or RetrievalModel(
            search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
            reranking_enable=False,
            reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
            top_k=2,
            score_threshold_enabled=False,
        )
        # The dataset starts unnamed; its name is filled in once the documents exist
        dataset = Dataset(
            tenant_id=tenant_id,
            name="",
            data_source_type=knowledge_config.data_source.info_list.data_source_type,  # type: ignore
            indexing_technique=knowledge_config.indexing_technique,
            created_by=account.id,
            embedding_model=knowledge_config.embedding_model,
            embedding_model_provider=knowledge_config.embedding_model_provider,
            collection_binding_id=dataset_collection_binding_id,
            retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
        )
        db.session.add(dataset)  # type: ignore
        db.session.flush()
        documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account)
        # Name and describe the dataset after its first document (name cut to 18 characters)
        first_document_name = documents[0].name
        dataset.name = first_document_name[:18] + "..."
        dataset.description = "useful for when you want to answer queries about the " + first_document_name
        db.session.commit()
        return dataset, documents, batch
    @classmethod
    def document_create_args_validate(cls, knowledge_config: KnowledgeConfig):
        """Validate a document-creation request.

        At least one of `data_source` and `process_rule` must be present; each one that is
        present is passed to its dedicated validator.

        Raises:
            ValueError: If both sections are missing, or a section validator rejects the config.
        """
        if knowledge_config.data_source:
            DocumentService.data_source_args_validate(knowledge_config)
        elif not knowledge_config.process_rule:
            raise ValueError("Data source or Process rule is required")
        if knowledge_config.process_rule:
            DocumentService.process_rule_args_validate(knowledge_config)
    @classmethod
    def data_source_args_validate(cls, knowledge_config: KnowledgeConfig):
        """Validate the `data_source` section of `knowledge_config`.

        Checks, in order: the data source exists, its info list exists, the source type is one
        of `Document.DATA_SOURCES`, and the info block matching that source type is present.

        Raises:
            ValueError: On the first check that fails.
        """
        data_source = knowledge_config.data_source
        if not data_source:
            raise ValueError("Data source is required")
        info_list = data_source.info_list
        # Check presence before reading any attribute of the info list, so a missing
        # info list raises ValueError rather than AttributeError.
        if not info_list:
            raise ValueError("Data source info is required")
        source_type = info_list.data_source_type
        if source_type not in Document.DATA_SOURCES:
            raise ValueError("Data source type is invalid")
        # Source type -> (attribute that must be populated, error raised when it is not)
        required_info = {
            "upload_file": ("file_info_list", "File source info is required"),
            "notion_import": ("notion_info_list", "Notion source info is required"),
            "website_crawl": ("website_info_list", "Website source info is required"),
        }
        requirement = required_info.get(source_type)
        if requirement is not None:
            attribute, message = requirement
            if not getattr(info_list, attribute):
                raise ValueError(message)
    @classmethod
    def process_rule_args_validate(cls, knowledge_config: KnowledgeConfig):
        """Validate and normalize the `process_rule` section of `knowledge_config`.

        Side effects on `knowledge_config.process_rule`: in "automatic" mode `rules` is reset to
        None; otherwise `rules.pre_processing_rules` is de-duplicated by id (last entry wins).

        Raises:
            ValueError: On the first missing or invalid field.
        """
        process_rule = knowledge_config.process_rule
        if not process_rule:
            raise ValueError("Process rule is required")
        mode = process_rule.mode
        if not mode:
            raise ValueError("Process rule mode is required")
        if mode not in DatasetProcessRule.MODES:
            raise ValueError("Process rule mode is invalid")
        if mode == "automatic":
            # Automatic mode discards any supplied rules; nothing else to validate
            process_rule.rules = None
            return
        rules = process_rule.rules
        if not rules:
            raise ValueError("Process rule rules is required")
        if rules.pre_processing_rules is None:
            raise ValueError("Process rule pre_processing_rules is required")
        # Keyed by id so a later duplicate replaces an earlier one
        rules_by_id = {}
        for candidate in rules.pre_processing_rules:
            if not candidate.id:
                raise ValueError("Process rule pre_processing_rules id is required")
            if not isinstance(candidate.enabled, bool):
                raise ValueError("Process rule pre_processing_rules enabled is invalid")
            rules_by_id[candidate.id] = candidate
        rules.pre_processing_rules = list(rules_by_id.values())
        segmentation = rules.segmentation
        if not segmentation:
            raise ValueError("Process rule segmentation is required")
        if not segmentation.separator:
            raise ValueError("Process rule segmentation separator is required")
        if not isinstance(segmentation.separator, str):
            raise ValueError("Process rule segmentation separator is invalid")
        # max_tokens is not validated for hierarchical rules in "full-doc" parent mode
        if mode == "hierarchical" and rules.parent_mode == "full-doc":
            return
        if not segmentation.max_tokens:
            raise ValueError("Process rule segmentation max_tokens is required")
        if not isinstance(segmentation.max_tokens, int):
            raise ValueError("Process rule segmentation max_tokens is invalid")
    @classmethod
    def estimate_args_validate(cls, args: dict):
        """Validate and normalize the raw argument dict of an indexing-estimate request.

        Side effects on `args["process_rule"]`: in "automatic" mode `rules` is replaced by `{}`;
        otherwise `rules["pre_processing_rules"]` is de-duplicated by id (last entry wins).

        Raises:
            ValueError: On the first missing or invalid field.
        """
        if not args.get("info_list"):
            raise ValueError("Data source info is required")
        if not isinstance(args["info_list"], dict):
            raise ValueError("Data info is invalid")
        process_rule = args.get("process_rule")
        if not process_rule:
            raise ValueError("Process rule is required")
        if not isinstance(process_rule, dict):
            raise ValueError("Process rule is invalid")
        mode = process_rule.get("mode")
        if not mode:
            raise ValueError("Process rule mode is required")
        if mode not in DatasetProcessRule.MODES:
            raise ValueError("Process rule mode is invalid")
        if mode == "automatic":
            # Automatic mode discards any supplied rules; nothing else to validate
            process_rule["rules"] = {}
            return
        rules = process_rule.get("rules")
        if not rules:
            raise ValueError("Process rule rules is required")
        if not isinstance(rules, dict):
            raise ValueError("Process rule rules is invalid")
        pre_processing_rules = rules.get("pre_processing_rules")
        if pre_processing_rules is None:
            raise ValueError("Process rule pre_processing_rules is required")
        if not isinstance(pre_processing_rules, list):
            raise ValueError("Process rule pre_processing_rules is invalid")
        # Keyed by id so a later duplicate replaces an earlier one
        rules_by_id = {}
        for candidate in pre_processing_rules:
            if "id" not in candidate or not candidate["id"]:
                raise ValueError("Process rule pre_processing_rules id is required")
            if candidate["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES:
                raise ValueError("Process rule pre_processing_rules id is invalid")
            if "enabled" not in candidate or candidate["enabled"] is None:
                raise ValueError("Process rule pre_processing_rules enabled is required")
            if not isinstance(candidate["enabled"], bool):
                raise ValueError("Process rule pre_processing_rules enabled is invalid")
            rules_by_id[candidate["id"]] = candidate
        rules["pre_processing_rules"] = list(rules_by_id.values())
        segmentation = rules.get("segmentation")
        if segmentation is None:
            raise ValueError("Process rule segmentation is required")
        if not isinstance(segmentation, dict):
            raise ValueError("Process rule segmentation is invalid")
        if not segmentation.get("separator"):
            raise ValueError("Process rule segmentation separator is required")
        if not isinstance(segmentation["separator"], str):
            raise ValueError("Process rule segmentation separator is invalid")
        if not segmentation.get("max_tokens"):
            raise ValueError("Process rule segmentation max_tokens is required")
        if not isinstance(segmentation["max_tokens"], int):
            raise ValueError("Process rule segmentation max_tokens is invalid")
    @staticmethod
    def batch_update_document_status(dataset: Dataset, document_ids: list[str], action: str, user):
        """
        Apply one status action to several documents of a dataset.
        Documents are validated first, all field changes are committed together, and only
        then are the async index tasks submitted and the indexing cache flags set.
        Args:
            dataset (Dataset): The dataset object
            document_ids (list[str]): List of document IDs to update; unknown IDs are skipped
            action (str): Action to perform (enable, disable, archive, un_archive)
            user: Current user performing the action
        Raises:
            DocumentIndexingError: If document is being indexed or not in correct state
            ValueError: If action is invalid
        """
        if not document_ids:
            return
        # Reject unknown actions before touching the database
        valid_actions = ["enable", "disable", "archive", "un_archive"]
        if action not in valid_actions:
            raise ValueError(f"Invalid action: {action}. Must be one of {valid_actions}")
        # Phase 1: validate every document and collect the pending changes
        pending_updates = []
        for document_id in document_ids:
            document = DocumentService.get_document(dataset.id, document_id)
            if not document:
                continue
            # A present cache key means the document is currently being indexed
            if redis_client.get(f"document_{document.id}_indexing") is not None:
                raise DocumentIndexingError(f"Document:{document.name} is being indexed, please try again later")
            pending = DocumentService._prepare_document_status_update(document, action, user)
            if pending:
                pending_updates.append(pending)
        if not pending_updates:
            return
        # Phase 2: apply all field changes and commit them in a single transaction
        try:
            for pending in pending_updates:
                target = pending["document"]
                for field, value in pending["updates"].items():
                    setattr(target, field, value)
                db.session.add(target)
            db.session.commit()
        except Exception:
            # Undo everything on any failure, then let the error propagate
            db.session.rollback()
            raise
        # Phase 3: side effects after the commit; failures here never roll back the database
        propagation_error = None
        for pending in pending_updates:
            try:
                task_info = pending["async_task"]
                if task_info:
                    task_func = task_info["function"]
                    task_args = task_info["args"]
                    task_func.delay(*task_args)
            except Exception as e:
                logging.exception(f"Error executing async task for document {pending['document'].id}")
                # Remember the failure (the latest one wins) and keep processing the rest
                propagation_error = e
            try:
                if pending["set_cache"]:
                    target = pending["document"]
                    redis_client.setex(f"document_{target.id}_indexing", 600, 1)
            except Exception:
                # Cache failures are logged only; they are not re-raised
                logging.exception(f"Error setting cache for document {pending['document'].id}")
        # Surface a task-submission failure once every document has been processed
        if propagation_error:
            raise propagation_error
    @staticmethod
    def _prepare_document_status_update(document, action: str, user):
        """
        Build the update description for one document and one action.
        Args:
            document: Document object to update
            action: Action to perform (enable, disable, archive, un_archive)
            user: Current user; only passed on for the disable and archive actions
        Returns:
            dict: Update information, or None if the action is unknown or no update is needed
        """
        now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        # Each entry defers to the matching helper; all share the same naive-UTC timestamp
        builders = {
            "enable": lambda: DocumentService._prepare_enable_update(document, now),
            "disable": lambda: DocumentService._prepare_disable_update(document, user, now),
            "archive": lambda: DocumentService._prepare_archive_update(document, user, now),
            "un_archive": lambda: DocumentService._prepare_unarchive_update(document, now),
        }
        builder = builders.get(action)
        if builder is None:
            return None
        return builder()
    @staticmethod
    def _prepare_enable_update(document, now):
        """Describe the changes needed to enable a document, or return None if it is already enabled."""
        if document.enabled:
            return None
        field_changes = {"enabled": True, "disabled_at": None, "disabled_by": None, "updated_at": now}
        index_task = {"function": add_document_to_index_task, "args": [document.id]}
        return {
            "document": document,
            "updates": field_changes,
            "async_task": index_task,
            "set_cache": True,
        }
    @staticmethod
    def _prepare_disable_update(document, user, now):
        """Describe the changes needed to disable a document.

        Returns None when the document is already disabled.

        Raises:
            DocumentIndexingError: If the document has no completion time or its indexing
                status is not "completed".
        """
        if not (document.completed_at and document.indexing_status == "completed"):
            raise DocumentIndexingError(f"Document: {document.name} is not completed.")
        if not document.enabled:
            return None
        field_changes = {"enabled": False, "disabled_at": now, "disabled_by": user.id, "updated_at": now}
        index_task = {"function": remove_document_from_index_task, "args": [document.id]}
        return {
            "document": document,
            "updates": field_changes,
            "async_task": index_task,
            "set_cache": True,
        }
    @staticmethod
    def _prepare_archive_update(document, user, now):
        """Describe the changes needed to archive a document, or return None if it is already archived."""
        if document.archived:
            return None
        field_changes = {"archived": True, "archived_at": now, "archived_by": user.id, "updated_at": now}
        if document.enabled:
            # Only enabled documents get an index-removal task and the cache flag
            index_task = {"function": remove_document_from_index_task, "args": [document.id]}
        else:
            index_task = None
        return {
            "document": document,
            "updates": field_changes,
            "async_task": index_task,
            "set_cache": index_task is not None,
        }
    @staticmethod
    def _prepare_unarchive_update(document, now):
        """Describe the changes needed to unarchive a document, or return None if it is not archived."""
        if not document.archived:
            return None
        field_changes = {"archived": False, "archived_at": None, "archived_by": None, "updated_at": now}
        if document.enabled:
            # Only enabled documents get an add-to-index task and the cache flag
            index_task = {"function": add_document_to_index_task, "args": [document.id]}
        else:
            index_task = None
        return {
            "document": document,
            "updates": field_changes,
            "async_task": index_task,
            "set_cache": index_task is not None,
        }
class SegmentService:
    @classmethod
    def segment_create_args_validate(cls, args: dict, document: Document):
        if document.doc_form == "qa_model":
            if "answer" not in args or not args["answer"]:
                raise ValueError("Answer is required")
            if not args["answer"].strip():
                raise ValueError("Answer is empty")
        if "content" not in args or not args["content"] or not args["content"].strip():
            raise ValueError("Content is empty")
    @classmethod
    def create_segment(cls, args: dict, document: Document, dataset: Dataset):
        """Create a single segment for ``document`` and add it to the index.

        The segment row and the increased document word count are committed first,
        then the segment is handed to ``VectorService.create_segments_vector``. If
        that call raises, the segment is kept but flagged as disabled with
        ``status="error"`` and the error text.

        Args:
            args: Segment payload. ``content`` is required, ``answer`` is read for
                ``qa_model`` documents and ``keywords`` is optional.
            document: Document the segment is appended to.
            dataset: Dataset owning the document.

        Returns:
            The segment re-read from the database after the indexing attempt.
        """
        content = args["content"]
        doc_id = str(uuid.uuid4())
        segment_hash = helper.generate_text_hash(content)
        tokens = 0
        if dataset.indexing_technique == "high_quality":
            embedding_model = ModelManager().get_model_instance(
                tenant_id=current_user.current_tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model,
            )
            # calc embedding use tokens
            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]

        # The lock serializes position assignment for segments of the same document.
        lock_name = f"add_segment_lock_document_id_{document.id}"
        with redis_client.lock(lock_name, timeout=600):
            max_position = (
                db.session.query(func.max(DocumentSegment.position))
                .filter(DocumentSegment.document_id == document.id)
                .scalar()
            )
            now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            segment_document = DocumentSegment(
                tenant_id=current_user.current_tenant_id,
                dataset_id=document.dataset_id,
                document_id=document.id,
                index_node_id=doc_id,
                index_node_hash=segment_hash,
                position=max_position + 1 if max_position else 1,
                content=content,
                word_count=len(content),
                tokens=tokens,
                status="completed",
                indexing_at=now,
                completed_at=now,
                created_by=current_user.id,
            )
            if document.doc_form == "qa_model":
                segment_document.word_count += len(args["answer"])
                segment_document.answer = args["answer"]

            db.session.add(segment_document)
            # update document word count
            document.word_count += segment_document.word_count
            db.session.add(document)
            db.session.commit()

            # save vector index
            try:
                # Keywords are optional: a payload without them must not raise KeyError
                # here and be mistaken for an indexing failure.
                VectorService.create_segments_vector(
                    [args.get("keywords")], [segment_document], dataset, document.doc_form
                )
            except Exception as e:
                logging.exception("create segment index failed")
                segment_document.enabled = False
                segment_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                segment_document.status = "error"
                segment_document.error = str(e)
                db.session.commit()

            return db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
    @classmethod
    def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
        """Create a batch of segments for ``document`` and index them together.

        All rows are added to the session while holding a per-document Redis lock and
        committed once, after the vector index call. If indexing raises, every new
        segment is flagged as disabled with ``status="error"`` before the commit.

        Args:
            segments: Payload dicts with ``content``, optional ``keywords`` and, for
                ``qa_model`` documents, ``answer``.
            document: Document receiving the segments; its ``word_count`` grows by the
                combined word count of the new segments.
            dataset: Dataset owning the document.

        Returns:
            The list of created ``DocumentSegment`` objects.
        """
        added_word_count = 0
        with redis_client.lock(f"multi_add_segment_lock_document_id_{document.id}", timeout=600):
            embedding_model = None
            if dataset.indexing_technique == "high_quality":
                embedding_model = ModelManager().get_model_instance(
                    tenant_id=current_user.current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )

            highest_position = (
                db.session.query(func.max(DocumentSegment.position))
                .filter(DocumentSegment.document_id == document.id)
                .scalar()
            )
            first_position = highest_position + 1 if highest_position else 1

            created_segments = []
            keywords_list = []
            for position, segment_item in enumerate(segments, start=first_position):
                content = segment_item["content"]
                index_node_id = str(uuid.uuid4())
                content_hash = helper.generate_text_hash(content)

                tokens = 0
                if dataset.indexing_technique == "high_quality" and embedding_model:
                    # calc embedding use tokens; QA segments count question and answer together
                    if document.doc_form == "qa_model":
                        embedded_text = content + segment_item["answer"]
                    else:
                        embedded_text = content
                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[embedded_text])[0]

                new_segment = DocumentSegment(
                    tenant_id=current_user.current_tenant_id,
                    dataset_id=document.dataset_id,
                    document_id=document.id,
                    index_node_id=index_node_id,
                    index_node_hash=content_hash,
                    position=position,
                    content=content,
                    word_count=len(content),
                    tokens=tokens,
                    keywords=segment_item.get("keywords", []),
                    status="completed",
                    indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                    completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                    created_by=current_user.id,
                )
                if document.doc_form == "qa_model":
                    new_segment.answer = segment_item["answer"]
                    new_segment.word_count += len(segment_item["answer"])

                added_word_count += new_segment.word_count
                db.session.add(new_segment)
                created_segments.append(new_segment)
                # Entries without keywords are represented by None.
                keywords_list.append(segment_item.get("keywords"))

            # update document word count
            document.word_count += added_word_count
            db.session.add(document)

            try:
                # save vector index (the callee gets its own list of the same segments)
                VectorService.create_segments_vector(
                    keywords_list, list(created_segments), dataset, document.doc_form
                )
            except Exception as e:
                logging.exception("create segment index failed")
                for failed_segment in created_segments:
                    failed_segment.enabled = False
                    failed_segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                    failed_segment.status = "error"
                    failed_segment.error = str(e)

            db.session.commit()
            return created_segments
    @classmethod
    def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
        """Apply an update to a segment and refresh its index entries.

        A request that disables an enabled segment is handled on its own: the flag is
        persisted, the indexing cache key is set, the removal task is queued and the
        segment is returned immediately. Otherwise content, answer and keywords are
        updated, the document word count is adjusted by the difference, and the vector
        index (or the child chunks of a parent-child document) is refreshed.

        Args:
            args: Requested changes (``content``, ``answer``, ``keywords``, ``enabled``,
                ``regenerate_child_chunks``).
            segment: Segment to update.
            document: Document owning the segment.
            dataset: Dataset owning the document.

        Returns:
            The segment re-read from the database, or the passed segment when it was
            only disabled.

        Raises:
            ValueError: If the segment is currently being indexed, or if it is disabled
                and the request does not enable it.

        Note:
            Any exception raised while applying the update is logged and recorded on
            the segment (``enabled=False``, ``status="error"``) instead of propagating.
        """
        indexing_cache_key = f"segment_{segment.id}_indexing"
        if redis_client.get(indexing_cache_key) is not None:
            raise ValueError("Segment is indexing, please try again later")

        if args.enabled is not None and segment.enabled != args.enabled and not args.enabled:
            segment.enabled = args.enabled
            segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            segment.disabled_by = current_user.id
            db.session.add(segment)
            db.session.commit()
            # Set cache to prevent indexing the same segment multiple times
            redis_client.setex(indexing_cache_key, 600, 1)
            disable_segment_from_index_task.delay(segment.id)
            return segment

        # A disabled segment may only be touched by a request that re-enables it.
        if not segment.enabled and not args.enabled:
            raise ValueError("Can't update disabled segment")

        try:
            previous_word_count = segment.word_count
            content = args.content or segment.content
            if segment.content == content:
                segment.word_count = len(content)
                if document.doc_form == "qa_model":
                    segment.answer = args.answer
                    segment.word_count += len(args.answer) if args.answer else 0
                word_count_change = segment.word_count - previous_word_count

                keyword_changed = False
                if args.keywords and Counter(segment.keywords) != Counter(args.keywords):
                    segment.keywords = args.keywords
                    keyword_changed = True

                segment.enabled = True
                segment.disabled_at = None
                segment.disabled_by = None
                db.session.add(segment)
                # update document word count in the same commit as the segment, so the
                # adjustment is not left pending in the session
                if word_count_change != 0:
                    document.word_count = max(0, document.word_count + word_count_change)
                    db.session.add(document)
                db.session.commit()

                # update segment index task
                if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
                    cls._regenerate_child_chunks_for_segment(segment, document, dataset)
                elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
                    if args.enabled or keyword_changed:
                        # update segment vector index
                        VectorService.update_segment_vector(args.keywords, segment, dataset)
            else:
                segment_hash = helper.generate_text_hash(content)
                tokens = 0
                if dataset.indexing_technique == "high_quality":
                    embedding_model = ModelManager().get_model_instance(
                        tenant_id=current_user.current_tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model,
                    )
                    # calc embedding use tokens
                    if document.doc_form == "qa_model":
                        segment.answer = args.answer
                        # A missing answer counts as empty text, matching the word count below.
                        tokens = embedding_model.get_text_embedding_num_tokens(
                            texts=[content + (segment.answer or "")]
                        )[0]
                    else:
                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]

                now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                segment.content = content
                segment.index_node_hash = segment_hash
                segment.word_count = len(content)
                segment.tokens = tokens
                segment.status = "completed"
                segment.indexing_at = now
                segment.completed_at = now
                segment.updated_by = current_user.id
                segment.updated_at = now
                segment.enabled = True
                segment.disabled_at = None
                segment.disabled_by = None
                if document.doc_form == "qa_model":
                    segment.answer = args.answer
                    segment.word_count += len(args.answer) if args.answer else 0
                word_count_change = segment.word_count - previous_word_count
                # update document word count
                if word_count_change != 0:
                    document.word_count = max(0, document.word_count + word_count_change)
                    db.session.add(document)
                db.session.add(segment)
                db.session.commit()

                if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
                    cls._regenerate_child_chunks_for_segment(segment, document, dataset)
                elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
                    # update segment vector index
                    VectorService.update_segment_vector(args.keywords, segment, dataset)
        except Exception as e:
            logging.exception("update segment index failed")
            segment.enabled = False
            segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            segment.status = "error"
            segment.error = str(e)
            db.session.commit()

        return db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()

    @classmethod
    def _regenerate_child_chunks_for_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
        """Regenerate the child chunks of ``segment`` with the dataset's embedding model.

        Raises:
            ValueError: If the dataset is not indexed with ``"high_quality"`` or the
                document's processing rule cannot be found.
        """
        if dataset.indexing_technique != "high_quality":
            raise ValueError("The knowledge base index technique is not high quality!")

        # check embedding model setting; fall back to the tenant default model
        model_manager = ModelManager()
        if dataset.embedding_model_provider:
            embedding_model_instance = model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model,
            )
        else:
            embedding_model_instance = model_manager.get_default_model_instance(
                tenant_id=dataset.tenant_id,
                model_type=ModelType.TEXT_EMBEDDING,
            )

        # get the process rule
        processing_rule = (
            db.session.query(DatasetProcessRule)
            .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
            .first()
        )
        if not processing_rule:
            raise ValueError("No processing rule found.")

        VectorService.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, True)
    @classmethod
    def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
        """Delete a segment and subtract its words from the document word count.

        For an enabled segment a ``segment_{id}_delete_indexing`` cache key is set for
        600 seconds and the index-removal task is queued before the row is deleted.

        Args:
            segment: Segment to delete.
            document: Document owning the segment.
            dataset: Dataset owning the document.

        Raises:
            ValueError: If the delete cache key for this segment already exists.
        """
        indexing_cache_key = f"segment_{segment.id}_delete_indexing"
        if redis_client.get(indexing_cache_key) is not None:
            raise ValueError("Segment is deleting.")

        # enabled segment need to delete index
        if segment.enabled:
            # send delete segment index task
            redis_client.setex(indexing_cache_key, 600, 1)
            delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id)

        db.session.delete(segment)
        # update document word count; clamp at zero (as update_segment does) and
        # treat missing counts as zero so the total never becomes negative
        document.word_count = max(0, (document.word_count or 0) - (segment.word_count or 0))
        db.session.add(document)
        db.session.commit()
    @classmethod
    def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
        """Delete several segments of ``document`` and queue their index removal.

        Only segments that belong to the given dataset, document and the current
        user's tenant are looked up and deleted; other ids are ignored. The document
        word count is reduced by the words of the deleted segments.

        Args:
            segment_ids: Ids of the segments to delete.
            document: Document owning the segments.
            dataset: Dataset owning the document.
        """
        # The same scope is used for the lookup and the DELETE so that ids from other
        # documents or tenants can never be removed.
        segment_scope = (
            DocumentSegment.id.in_(segment_ids),
            DocumentSegment.dataset_id == dataset.id,
            DocumentSegment.document_id == document.id,
            DocumentSegment.tenant_id == current_user.current_tenant_id,
        )
        rows = (
            db.session.query(DocumentSegment)
            .with_entities(DocumentSegment.index_node_id, DocumentSegment.word_count)
            .filter(*segment_scope)
            .all()
        )
        if not rows:
            # Nothing in scope: do not queue an empty index-removal task.
            return

        index_node_ids = [row[0] for row in rows]
        removed_word_count = sum(row[1] or 0 for row in rows)

        delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
        db.session.query(DocumentSegment).filter(*segment_scope).delete()

        # update document word count, mirroring delete_segment
        if removed_word_count:
            document.word_count = max(0, (document.word_count or 0) - removed_word_count)
            db.session.add(document)
        db.session.commit()
    @classmethod
    def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document):
        """Enable or disable a batch of segments and queue the matching index task.

        Only segments of the given dataset and document whose current ``enabled``
        state differs from the requested one are considered. Segments whose
        ``segment_{id}_indexing`` cache key exists are skipped.

        Args:
            segment_ids: Ids of the segments to change.
            action: ``"enable"`` or ``"disable"``.
            dataset: Dataset owning the document.
            document: Document owning the segments.

        Raises:
            InvalidActionError: If ``action`` is neither ``"enable"`` nor ``"disable"``.
        """
        if action not in ("enable", "disable"):
            raise InvalidActionError()

        if action == "enable":
            disabled_segments = (
                db.session.query(DocumentSegment)
                .filter(
                    DocumentSegment.id.in_(segment_ids),
                    DocumentSegment.dataset_id == dataset.id,
                    DocumentSegment.document_id == document.id,
                    DocumentSegment.enabled == False,  # noqa: E712
                )
                .all()
            )
            if not disabled_segments:
                return

            enabled_ids = []
            for target in disabled_segments:
                # Skip segments that are currently being indexed.
                if redis_client.get(f"segment_{target.id}_indexing") is not None:
                    continue
                target.enabled = True
                target.disabled_at = None
                target.disabled_by = None
                db.session.add(target)
                enabled_ids.append(target.id)
            db.session.commit()
            enable_segments_to_index_task.delay(enabled_ids, dataset.id, document.id)
            return

        # action == "disable"
        enabled_segments = (
            db.session.query(DocumentSegment)
            .filter(
                DocumentSegment.id.in_(segment_ids),
                DocumentSegment.dataset_id == dataset.id,
                DocumentSegment.document_id == document.id,
                DocumentSegment.enabled == True,  # noqa: E712
            )
            .all()
        )
        if not enabled_segments:
            return

        disabled_ids = []
        for target in enabled_segments:
            # Skip segments that are currently being indexed.
            if redis_client.get(f"segment_{target.id}_indexing") is not None:
                continue
            target.enabled = False
            target.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            target.disabled_by = current_user.id
            db.session.add(target)
            disabled_ids.append(target.id)
        db.session.commit()
        disable_segments_from_index_task.delay(disabled_ids, dataset.id, document.id)
    @classmethod
    def create_child_chunk(
        cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset
    ) -> ChildChunk:
        """Create a customized child chunk under ``segment`` and index it.

        The chunk is appended after the highest existing position of the segment's
        child chunks. The row is only committed when the vector index call succeeds.

        Args:
            content: Text of the new child chunk.
            segment: Parent segment.
            document: Document owning the segment.
            dataset: Dataset owning the document.

        Returns:
            The created ``ChildChunk``.

        Raises:
            ChildChunkIndexingError: If indexing the chunk fails; the session is rolled
                back first.
        """
        # The lock serializes position assignment for chunks of the same segment.
        with redis_client.lock(f"add_child_lock_{segment.id}", timeout=20):
            index_node_id = str(uuid.uuid4())
            index_node_hash = helper.generate_text_hash(content)
            max_position = (
                db.session.query(func.max(ChildChunk.position))
                .filter(
                    ChildChunk.tenant_id == current_user.current_tenant_id,
                    ChildChunk.dataset_id == dataset.id,
                    ChildChunk.document_id == document.id,
                    ChildChunk.segment_id == segment.id,
                )
                .scalar()
            )
            child_chunk = ChildChunk(
                tenant_id=current_user.current_tenant_id,
                dataset_id=dataset.id,
                document_id=document.id,
                segment_id=segment.id,
                position=max_position + 1 if max_position else 1,
                index_node_id=index_node_id,
                index_node_hash=index_node_hash,
                content=content,
                word_count=len(content),
                type="customized",
                created_by=current_user.id,
            )
            db.session.add(child_chunk)
            # save vector index
            try:
                VectorService.create_child_chunk_vector(child_chunk, dataset)
            except Exception as e:
                logging.exception("create child chunk index failed")
                db.session.rollback()
                raise ChildChunkIndexingError(str(e)) from e
            db.session.commit()
            return child_chunk
    @classmethod
    def update_child_chunks(
        cls,
        child_chunks_update_args: list[ChildChunkUpdateArgs],
        segment: DocumentSegment,
        document: Document,
        dataset: Dataset,
    ) -> list[ChildChunk]:
        """Synchronize the child chunks of ``segment`` with the given list.

        Entries with an id update the matching chunk when its content changed,
        entries without an id create new chunks, and existing chunks that are not
        referenced by any entry are deleted. Entries whose id matches no existing
        chunk are ignored.

        Args:
            child_chunks_update_args: Desired child chunks (``id`` optional, ``content``).
            segment: Parent segment.
            document: Document owning the segment.
            dataset: Dataset owning the document.

        Returns:
            The created and updated chunks, sorted by position.

        Raises:
            ChildChunkIndexingError: If persisting or indexing fails; the session is
                rolled back first.
        """
        existing_chunks = (
            db.session.query(ChildChunk)
            .filter(
                ChildChunk.dataset_id == dataset.id,
                ChildChunk.document_id == document.id,
                ChildChunk.segment_id == segment.id,
            )
            .all()
        )
        # New chunks go after the highest existing position (as in create_child_chunk);
        # counting rows instead would reuse a position when earlier deletions left gaps.
        next_position = max((chunk.position or 0 for chunk in existing_chunks), default=0) + 1

        unreferenced_by_id = {chunk.id: chunk for chunk in existing_chunks}
        created_chunks: list[ChildChunk] = []
        updated_chunks: list[ChildChunk] = []
        new_chunk_args = []
        for update_args in child_chunks_update_args:
            if not update_args.id:
                new_chunk_args.append(update_args)
                continue
            child_chunk = unreferenced_by_id.pop(update_args.id, None)
            if child_chunk and child_chunk.content != update_args.content:
                child_chunk.content = update_args.content
                child_chunk.word_count = len(child_chunk.content)
                child_chunk.updated_by = current_user.id
                child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                child_chunk.type = "customized"
                updated_chunks.append(child_chunk)

        # Whatever was not referenced by the request is removed.
        deleted_chunks = list(unreferenced_by_id.values())

        try:
            if updated_chunks:
                db.session.bulk_save_objects(updated_chunks)
            for stale_chunk in deleted_chunks:
                db.session.delete(stale_chunk)
            for position, chunk_args in enumerate(new_chunk_args, start=next_position):
                child_chunk = ChildChunk(
                    tenant_id=current_user.current_tenant_id,
                    dataset_id=dataset.id,
                    document_id=document.id,
                    segment_id=segment.id,
                    position=position,
                    index_node_id=str(uuid.uuid4()),
                    index_node_hash=helper.generate_text_hash(chunk_args.content),
                    content=chunk_args.content,
                    word_count=len(chunk_args.content),
                    type="customized",
                    created_by=current_user.id,
                )
                db.session.add(child_chunk)
                db.session.flush()
                created_chunks.append(child_chunk)
            VectorService.update_child_chunk_vector(created_chunks, updated_chunks, deleted_chunks, dataset)
            db.session.commit()
        except Exception as e:
            logging.exception("update child chunk index failed")
            db.session.rollback()
            raise ChildChunkIndexingError(str(e)) from e
        return sorted(created_chunks + updated_chunks, key=lambda chunk: chunk.position)
    @classmethod
    def update_child_chunk(
        cls,
        content: str,
        child_chunk: ChildChunk,
        segment: DocumentSegment,
        document: Document,
        dataset: Dataset,
    ) -> ChildChunk:
        """Replace the content of a child chunk and refresh its vector index entry.

        Args:
            content: New text of the chunk.
            child_chunk: Chunk to update; it is marked as ``"customized"``.
            segment: Parent segment (not used by this method).
            document: Owning document (not used by this method).
            dataset: Dataset passed to the vector service.

        Returns:
            The updated ``ChildChunk``.

        Raises:
            ChildChunkIndexingError: If updating or indexing fails; the session is
                rolled back first.
        """
        try:
            child_chunk.content = content
            child_chunk.word_count = len(content)
            child_chunk.updated_by = current_user.id
            child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            child_chunk.type = "customized"
            db.session.add(child_chunk)
            # Only this chunk is updated: nothing is created or deleted.
            VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
            db.session.commit()
        except Exception as error:
            logging.exception("update child chunk index failed")
            db.session.rollback()
            raise ChildChunkIndexingError(str(error))
        else:
            return child_chunk
    @classmethod
    def delete_child_chunk(cls, child_chunk: ChildChunk, dataset: Dataset):
        """Delete a child chunk and remove its vector index entry.

        The deletion is only committed when the vector service call succeeds.

        Args:
            child_chunk: Chunk to delete.
            dataset: Dataset passed to the vector service.

        Raises:
            ChildChunkDeleteIndexError: If removing the index entry fails; the session
                is rolled back first.
        """
        db.session.delete(child_chunk)
        try:
            VectorService.delete_child_chunk_vector(child_chunk, dataset)
        except Exception as error:
            logging.exception("delete child chunk index failed")
            db.session.rollback()
            raise ChildChunkDeleteIndexError(str(error))
        else:
            db.session.commit()
    @classmethod
    def get_child_chunks(
        cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
    ):
        """Return a page of the child chunks of a segment, ordered by position.

        Args:
            segment_id: Id of the parent segment.
            document_id: Id of the owning document.
            dataset_id: Id of the owning dataset.
            page: Page number to return.
            limit: Page size (capped at 100 by the paginator).
            keyword: Optional text; only chunks whose content contains it
                (case-insensitive) are returned.

        Returns:
            The pagination object produced by ``db.paginate``.
        """
        query = (
            select(ChildChunk)
            .filter_by(
                tenant_id=current_user.current_tenant_id,
                dataset_id=dataset_id,
                document_id=document_id,
                segment_id=segment_id,
            )
            .order_by(ChildChunk.position.asc())
        )
        if keyword:
            # Escape LIKE wildcards so "%" and "_" in the keyword are matched literally.
            escaped_keyword = keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            query = query.where(ChildChunk.content.ilike(f"%{escaped_keyword}%", escape="\\"))
        return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
    @classmethod
    def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
        """Look up a child chunk by id within a tenant; return ``None`` if not found."""
        lookup = db.session.query(ChildChunk).filter(
            ChildChunk.id == child_chunk_id,
            ChildChunk.tenant_id == tenant_id,
        )
        found = lookup.first()
        if isinstance(found, ChildChunk):
            return found
        return None
    @classmethod
    def get_segments(
        cls,
        document_id: str,
        tenant_id: str,
        status_list: list[str] | None = None,
        keyword: str | None = None,
        page: int = 1,
        limit: int = 20,
    ):
        """Get a page of segments for a document with optional filtering.

        Args:
            document_id: Id of the document.
            tenant_id: Id of the tenant owning the segments.
            status_list: Optional statuses; only segments with one of them are returned.
            keyword: Optional text; only segments whose content contains it
                (case-insensitive) are returned.
            page: Page number to return.
            limit: Page size (capped at 100 by the paginator).

        Returns:
            A ``(items, total)`` tuple taken from the pagination result; segments are
            ordered by position.
        """
        query = select(DocumentSegment).filter(
            DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
        )
        if status_list:
            query = query.filter(DocumentSegment.status.in_(status_list))
        if keyword:
            # Escape LIKE wildcards so "%" and "_" in the keyword are matched literally.
            escaped_keyword = keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            query = query.filter(DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"))
        query = query.order_by(DocumentSegment.position.asc())
        paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
        return paginated_segments.items, paginated_segments.total
    @classmethod
    def update_segment_by_id(
        cls, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, segment_data: dict, user_id: str
    ) -> tuple[DocumentSegment, Document]:
        """Update a segment by its ID with validation and checks.

        The dataset, document and segment are resolved within ``tenant_id``, the
        embedding model is checked for high-quality datasets, the payload is
        validated and the update is delegated to ``update_segment``.

        Args:
            tenant_id: Tenant that must own the dataset and the segment.
            dataset_id: Id of the dataset.
            document_id: Id of the document.
            segment_id: Id of the segment to update.
            segment_data: Payload used to build ``SegmentUpdateArgs``.
            user_id: Id of the acting user (kept for interface compatibility; tenant
                scoping is done with ``tenant_id``).

        Returns:
            A ``(segment, document)`` tuple with the updated segment.

        Raises:
            NotFound: If the dataset, document or segment cannot be found.
            ValueError: If no usable embedding model is configured or the payload is
                invalid.
        """
        # check dataset
        dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
        if not dataset:
            raise NotFound("Dataset not found.")

        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)

        # check document
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")

        # check embedding model setting if high quality
        if dataset.indexing_technique == "high_quality":
            try:
                # Models are resolved per tenant, so the tenant id (not the user id) is required.
                ModelManager().get_model_instance(
                    tenant_id=tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            except LLMBadRequestError as ex:
                raise ValueError(
                    "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                ) from ex
            except ProviderTokenNotInitError as ex:
                raise ValueError(ex.description) from ex

        # check segment: it must belong to the tenant and to the document resolved above
        segment = (
            db.session.query(DocumentSegment)
            .filter(
                DocumentSegment.id == segment_id,
                DocumentSegment.tenant_id == tenant_id,
                DocumentSegment.document_id == document.id,
            )
            .first()
        )
        if not segment:
            raise NotFound("Segment not found.")

        # validate and update segment
        cls.segment_create_args_validate(segment_data, document)
        updated_segment = cls.update_segment(SegmentUpdateArgs(**segment_data), segment, document, dataset)
        return updated_segment, document
    @classmethod
    def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
        """Look up a segment by id within a tenant; return ``None`` if not found."""
        lookup = db.session.query(DocumentSegment).filter(
            DocumentSegment.id == segment_id,
            DocumentSegment.tenant_id == tenant_id,
        )
        found = lookup.first()
        if isinstance(found, DocumentSegment):
            return found
        return None
class DatasetCollectionBindingService:
    """Lookup helpers for ``DatasetCollectionBinding`` rows."""

    @classmethod
    def get_dataset_collection_binding(
        cls, provider_name: str, model_name: str, collection_type: str = "dataset"
    ) -> DatasetCollectionBinding:
        """Return the binding for a provider/model/type, creating one if none exists.

        Args:
            provider_name: Matched against ``DatasetCollectionBinding.provider_name``.
            model_name: Matched against ``DatasetCollectionBinding.model_name``.
            collection_type: Matched against ``DatasetCollectionBinding.type``.

        Returns:
            The matching binding with the earliest ``created_at``; when there
            is no match, a newly created binding that has been added to the
            session and committed.
        """
        matching = db.session.query(DatasetCollectionBinding).filter(
            DatasetCollectionBinding.provider_name == provider_name,
            DatasetCollectionBinding.model_name == model_name,
            DatasetCollectionBinding.type == collection_type,
        )
        existing = matching.order_by(DatasetCollectionBinding.created_at).first()
        if existing:
            return existing

        # Nothing stored yet: derive a collection name from a fresh UUID and
        # persist a new binding.
        collection_name = Dataset.gen_collection_name_by_id(str(uuid.uuid4()))
        created = DatasetCollectionBinding(
            provider_name=provider_name,
            model_name=model_name,
            collection_name=collection_name,
            type=collection_type,
        )
        db.session.add(created)
        db.session.commit()
        return created

    @classmethod
    def get_dataset_collection_binding_by_id_and_type(
        cls, collection_binding_id: str, collection_type: str = "dataset"
    ) -> DatasetCollectionBinding:
        """Fetch a binding by its ID and type.

        Args:
            collection_binding_id: Matched against ``DatasetCollectionBinding.id``.
            collection_type: Matched against ``DatasetCollectionBinding.type``.

        Returns:
            The matching binding.

        Raises:
            ValueError: If no binding matches both the ID and the type.
        """
        matching = db.session.query(DatasetCollectionBinding).filter(
            DatasetCollectionBinding.id == collection_binding_id,
            DatasetCollectionBinding.type == collection_type,
        )
        binding = matching.order_by(DatasetCollectionBinding.created_at).first()
        if binding:
            return binding
        raise ValueError("Dataset collection binding not found")
class DatasetPermissionService:
    """Read, replace, validate and clear ``DatasetPermission`` rows for a dataset."""

    @classmethod
    def get_dataset_partial_member_list(cls, dataset_id):
        """Return the ``account_id`` of every ``DatasetPermission`` row for ``dataset_id``."""
        rows = (
            db.session.query(DatasetPermission.account_id)
            .filter(DatasetPermission.dataset_id == dataset_id)
            .all()
        )
        return [row.account_id for row in rows]

    @classmethod
    def update_partial_member_list(cls, tenant_id, dataset_id, user_list):
        """Replace the dataset's permission rows with one row per entry of ``user_list``.

        Each entry of ``user_list`` is indexed with ``"user_id"`` to obtain the
        account ID. Existing rows for the dataset are deleted first; the delete
        and the inserts are committed together. On any exception the session is
        rolled back and the exception is re-raised.
        """
        try:
            existing_rows = db.session.query(DatasetPermission).filter(
                DatasetPermission.dataset_id == dataset_id
            )
            existing_rows.delete()
            db.session.add_all(
                [
                    DatasetPermission(
                        tenant_id=tenant_id,
                        dataset_id=dataset_id,
                        account_id=member["user_id"],
                    )
                    for member in user_list
                ]
            )
            db.session.commit()
        except Exception:
            db.session.rollback()
            raise

    @classmethod
    def check_permission(cls, user, dataset, requested_permission, requested_partial_member_list):
        """Validate that ``user`` may apply the requested permission settings.

        Raises:
            NoPermissionError: If ``user.is_dataset_editor`` is falsy, or if the
                user is a dataset operator and ``requested_permission`` differs
                from ``dataset.permission``.
            ValueError: If the user is a dataset operator, the requested
                permission is ``"partial_members"``, and the requested member
                list is empty or names a different set of ``user_id`` values
                than the stored list.
        """
        if not user.is_dataset_editor:
            raise NoPermissionError("User does not have permission to edit this dataset.")

        # Every remaining check applies to dataset operators only.
        if not user.is_dataset_operator:
            return

        if dataset.permission != requested_permission:
            raise NoPermissionError("Dataset operators cannot change the dataset permissions.")

        if requested_permission != "partial_members":
            return

        if not requested_partial_member_list:
            raise ValueError("Partial member list is required when setting to partial members.")

        # Compared as sets, so order and duplicates are ignored.
        stored_ids = set(cls.get_dataset_partial_member_list(dataset.id))
        requested_ids = {member["user_id"] for member in requested_partial_member_list}
        if stored_ids != requested_ids:
            raise ValueError("Dataset operators cannot change the dataset permissions.")

    @classmethod
    def clear_partial_member_list(cls, dataset_id):
        """Delete every ``DatasetPermission`` row for ``dataset_id`` and commit.

        On any exception the session is rolled back and the exception is
        re-raised.
        """
        try:
            existing_rows = db.session.query(DatasetPermission).filter(
                DatasetPermission.dataset_id == dataset_id
            )
            existing_rows.delete()
            db.session.commit()
        except Exception:
            db.session.rollback()
            raise
 |