py lint (#12102)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
tags/0.15.0
Jyong 10 months ago
commit 84ac004772

api/commands.py  (+1 / -1)

     click.echo(click.style("Starting database migration.", fg="green"))

     # run db migration
-    import flask_migrate
+    import flask_migrate  # type: ignore

     flask_migrate.upgrade()
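Note on the `# type: ignore` added above (and on several imports further down): Flask-Migrate ships without type stubs, so mypy flags the bare import, and the comment silences just that line. A minimal sketch of the same usage; the wrapper function here is illustrative, not part of the commit:

import flask_migrate  # type: ignore  # package publishes no type stubs

def run_db_upgrade() -> None:
    # applies pending Alembic migrations; needs a Flask app context when actually invoked
    flask_migrate.upgrade()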



api/controllers/console/datasets/datasets_document.py  (+4 / -3)

     indexing_runner = IndexingRunner()

     try:
-        response = indexing_runner.indexing_estimate(
+        estimate_response = indexing_runner.indexing_estimate(
             current_user.current_tenant_id,
             [extract_setting],
             data_process_rule_dict,
             "English",
             dataset_id,
         )
+        return estimate_response.model_dump(), 200
     except LLMBadRequestError:
         raise ProviderNotInitializeError(
             "No Embedding Model available. Please configure a valid provider "
     except Exception as e:
         raise IndexingEstimateError(str(e))

-    return response.model_dump(), 200
+    return response, 200


 class DocumentBatchIndexingEstimateApi(DocumentResource):
             "English",
             dataset_id,
         )
+        return response.model_dump(), 200
     except LLMBadRequestError:
         raise ProviderNotInitializeError(
             "No Embedding Model available. Please configure a valid provider "
         raise ProviderNotInitializeError(ex.description)
     except Exception as e:
         raise IndexingEstimateError(str(e))
-    return response.model_dump(), 200


 class DocumentBatchIndexingStatusApi(DocumentResource):

api/controllers/service_api/dataset/document.py  (+14 / -8)

 from libs.login import current_user
 from models.dataset import Dataset, Document, DocumentSegment
 from services.dataset_service import DocumentService
+from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
 from services.file_service import FileService


             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
         }
         args["data_source"] = data_source
+        knowledge_config = KnowledgeConfig(**args)
         # validate args
-        DocumentService.document_create_args_validate(args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=current_user,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",

         args["data_source"] = data_source
         # validate args
         args["original_document_id"] = str(document_id)
-        DocumentService.document_create_args_validate(args)
+        knowledge_config = KnowledgeConfig(**args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=current_user,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",

         data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
         args["data_source"] = data_source
         # validate args
-        DocumentService.document_create_args_validate(args)
+        knowledge_config = KnowledgeConfig(**args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=dataset.created_by_account,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",

         args["data_source"] = data_source
         # validate args
         args["original_document_id"] = str(document_id)
-        DocumentService.document_create_args_validate(args)
+        knowledge_config = KnowledgeConfig(**args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=dataset.created_by_account,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",

api/core/indexing_runner.py  (+7 / -7)

             tenant_id=tenant_id,
             model_type=ModelType.TEXT_EMBEDDING,
         )
-        preview_texts = []
+        preview_texts = []  # type: ignore

         total_segments = 0
         index_type = doc_form

                 if len(preview_texts) < 10:
                     if doc_form and doc_form == "qa_model":
                         preview_detail = QAPreviewDetail(
-                            question=document.page_content, answer=document.metadata.get("answer")
+                            question=document.page_content, answer=document.metadata.get("answer") or ""
                         )
                         preview_texts.append(preview_detail)
                     else:
-                        preview_detail = PreviewDetail(content=document.page_content)
+                        preview_detail = PreviewDetail(content=document.page_content)  # type: ignore
                         if document.children:
-                            preview_detail.child_chunks = [child.page_content for child in document.children]
+                            preview_detail.child_chunks = [child.page_content for child in document.children]  # type: ignore
                         preview_texts.append(preview_detail)

         # delete image files and related db records

         if doc_form and doc_form == "qa_model":
             return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
-        return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
+        return IndexingEstimate(total_segments=total_segments, preview=preview_texts)  # type: ignore

     def _extract(
         self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict

             embedding_model_instance=embedding_model_instance,
         )

-        return character_splitter
+        return character_splitter  # type: ignore

     def _split_to_documents_for_estimate(
         self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule

         # create keyword index
         create_keyword_thread = threading.Thread(
             target=self._process_keyword_index,
-            args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
+            args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),  # type: ignore
         )
         create_keyword_thread.start()
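Small note on the `answer=document.metadata.get("answer") or ""` change above: assuming QAPreviewDetail declares `answer` as a plain str, a bare `.get()` can hand it None; `or ""` coerces that to an empty string (and, as a side effect, any other falsy value as well). Illustration:

metadata: dict = {"question": "What is Dify?"}  # no "answer" key present
answer = metadata.get("answer")                 # -> None
safe_answer = metadata.get("answer") or ""      # -> "" (never None)
print(repr(answer), repr(safe_answer))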



api/core/rag/datasource/retrieval_service.py  (+65 / -64)

         include_segment_ids = []
         segment_child_map = {}
         for document in documents:
-            document_id = document.metadata["document_id"]
+            document_id = document.metadata.get("document_id")
             dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
-            if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-                child_index_node_id = document.metadata["doc_id"]
-                result = (
-                    db.session.query(ChildChunk, DocumentSegment)
-                    .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
-                    .filter(
-                        ChildChunk.index_node_id == child_index_node_id,
-                        DocumentSegment.dataset_id == dataset_document.dataset_id,
-                        DocumentSegment.enabled == True,
-                        DocumentSegment.status == "completed",
-                    )
-                    .first()
-                )
-                if result:
-                    child_chunk, segment = result
-                    if not segment:
-                        continue
-                    if segment.id not in include_segment_ids:
-                        include_segment_ids.append(segment.id)
-                        child_chunk_detail = {
-                            "id": child_chunk.id,
-                            "content": child_chunk.content,
-                            "position": child_chunk.position,
-                            "score": document.metadata.get("score", 0.0),
-                        }
-                        map_detail = {
-                            "max_score": document.metadata.get("score", 0.0),
-                            "child_chunks": [child_chunk_detail],
-                        }
-                        segment_child_map[segment.id] = map_detail
-                        record = {
-                            "segment": segment,
-                        }
-                        records.append(record)
-                    else:
-                        child_chunk_detail = {
-                            "id": child_chunk.id,
-                            "content": child_chunk.content,
-                            "position": child_chunk.position,
-                            "score": document.metadata.get("score", 0.0),
-                        }
-                        segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
-                        segment_child_map[segment.id]["max_score"] = max(
-                            segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
-                        )
-                else:
-                    continue
-            else:
-                index_node_id = document.metadata["doc_id"]
-
-                segment = (
-                    db.session.query(DocumentSegment)
-                    .filter(
-                        DocumentSegment.dataset_id == dataset_document.dataset_id,
-                        DocumentSegment.enabled == True,
-                        DocumentSegment.status == "completed",
-                        DocumentSegment.index_node_id == index_node_id,
-                    )
-                    .first()
-                )
-
-                if not segment:
-                    continue
-                include_segment_ids.append(segment.id)
-                record = {
-                    "segment": segment,
-                    "score": document.metadata.get("score", None),
-                }
-
-                records.append(record)
+            if dataset_document:
+                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    child_index_node_id = document.metadata.get("doc_id")
+                    result = (
+                        db.session.query(ChildChunk, DocumentSegment)
+                        .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
+                        .filter(
+                            ChildChunk.index_node_id == child_index_node_id,
+                            DocumentSegment.dataset_id == dataset_document.dataset_id,
+                            DocumentSegment.enabled == True,
+                            DocumentSegment.status == "completed",
+                        )
+                        .first()
+                    )
+                    if result:
+                        child_chunk, segment = result
+                        if not segment:
+                            continue
+                        if segment.id not in include_segment_ids:
+                            include_segment_ids.append(segment.id)
+                            child_chunk_detail = {
+                                "id": child_chunk.id,
+                                "content": child_chunk.content,
+                                "position": child_chunk.position,
+                                "score": document.metadata.get("score", 0.0),
+                            }
+                            map_detail = {
+                                "max_score": document.metadata.get("score", 0.0),
+                                "child_chunks": [child_chunk_detail],
+                            }
+                            segment_child_map[segment.id] = map_detail
+                            record = {
+                                "segment": segment,
+                            }
+                            records.append(record)
+                        else:
+                            child_chunk_detail = {
+                                "id": child_chunk.id,
+                                "content": child_chunk.content,
+                                "position": child_chunk.position,
+                                "score": document.metadata.get("score", 0.0),
+                            }
+                            segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                            segment_child_map[segment.id]["max_score"] = max(
+                                segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                            )
+                    else:
+                        continue
+                else:
+                    index_node_id = document.metadata["doc_id"]
+
+                    segment = (
+                        db.session.query(DocumentSegment)
+                        .filter(
+                            DocumentSegment.dataset_id == dataset_document.dataset_id,
+                            DocumentSegment.enabled == True,
+                            DocumentSegment.status == "completed",
+                            DocumentSegment.index_node_id == index_node_id,
+                        )
+                        .first()
+                    )
+
+                    if not segment:
+                        continue
+                    include_segment_ids.append(segment.id)
+                    record = {
+                        "segment": segment,
+                        "score": document.metadata.get("score", None),
+                    }
+
+                    records.append(record)
         for record in records:
             if record["segment"].id in segment_child_map:
                 record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)

api/core/rag/docstore/dataset_docstore.py  (+19 / -18)

             db.session.add(segment_document)
             db.session.flush()
             if save_child:
-                for postion, child in enumerate(doc.children, start=1):
-                    child_segment = ChildChunk(
-                        tenant_id=self._dataset.tenant_id,
-                        dataset_id=self._dataset.id,
-                        document_id=self._document_id,
-                        segment_id=segment_document.id,
-                        position=postion,
-                        index_node_id=child.metadata["doc_id"],
-                        index_node_hash=child.metadata["doc_hash"],
-                        content=child.page_content,
-                        word_count=len(child.page_content),
-                        type="automatic",
-                        created_by=self._user_id,
-                    )
-                    db.session.add(child_segment)
+                if doc.children:
+                    for postion, child in enumerate(doc.children, start=1):
+                        child_segment = ChildChunk(
+                            tenant_id=self._dataset.tenant_id,
+                            dataset_id=self._dataset.id,
+                            document_id=self._document_id,
+                            segment_id=segment_document.id,
+                            position=postion,
+                            index_node_id=child.metadata.get("doc_id"),
+                            index_node_hash=child.metadata.get("doc_hash"),
+                            content=child.page_content,
+                            word_count=len(child.page_content),
+                            type="automatic",
+                            created_by=self._user_id,
+                        )
+                        db.session.add(child_segment)
         else:
             segment_document.content = doc.page_content
             if doc.metadata.get("answer"):
                 segment_document.answer = doc.metadata.pop("answer", "")
-            segment_document.index_node_hash = doc.metadata["doc_hash"]
+            segment_document.index_node_hash = doc.metadata.get("doc_hash")
             segment_document.word_count = len(doc.page_content)
             segment_document.tokens = tokens
             if save_child and doc.children:

                         document_id=self._document_id,
                         segment_id=segment_document.id,
                         position=position,
-                        index_node_id=child.metadata["doc_id"],
-                        index_node_hash=child.metadata["doc_hash"],
+                        index_node_id=child.metadata.get("doc_id"),
+                        index_node_hash=child.metadata.get("doc_hash"),
                         content=child.page_content,
                         word_count=len(child.page_content),
                         type="automatic",

api/core/rag/extractor/excel_extractor.py  (+1 / -1)

 from typing import Optional, cast

 import pandas as pd
-from openpyxl import load_workbook
+from openpyxl import load_workbook  # type: ignore

 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document

api/core/rag/index_processor/index_processor_base.py  (+1 / -1)

             embedding_model_instance=embedding_model_instance,
         )

-        return character_splitter
+        return character_splitter  # type: ignore

api/core/rag/index_processor/processor/paragraph_index_processor.py  (+6 / -0)

     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
         if process_rule.get("mode") == "automatic":
             automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
             rules = Rule(**automatic_rule)
         else:
+            if not process_rule.get("rules"):
+                raise ValueError("No rules found in process rule.")
             rules = Rule(**process_rule.get("rules"))
         # Split the text documents into nodes.
+        if not rules.segmentation:
+            raise ValueError("No segmentation found in rules.")
         splitter = self._get_splitter(
             processing_rule_mode=process_rule.get("mode"),
             max_tokens=rules.segmentation.max_tokens,
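The added guards follow a fail-fast pattern: reject a missing process rule or missing rules up front with a ValueError instead of letting a None or absent key blow up later. A generic sketch of that pattern (names here are illustrative, not the Dify API):

from typing import Any, Optional

def resolve_rules(process_rule: Optional[dict[str, Any]]) -> dict[str, Any]:
    # fail fast so the rest of the pipeline can assume real values
    if not process_rule:
        raise ValueError("No process rule found.")
    rules = process_rule.get("rules")
    if not rules:
        raise ValueError("No rules found in process rule.")
    return rules

print(resolve_rules({"mode": "custom", "rules": {"segmentation": {"max_tokens": 500}}}))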

api/core/rag/index_processor/processor/parent_child_index_processor.py  (+7 / -1)

     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
+        if not process_rule.get("rules"):
+            raise ValueError("No rules found in process rule.")
         rules = Rule(**process_rule.get("rules"))
-        all_documents = []
+        all_documents = []  # type: ignore
         if rules.parent_mode == ParentMode.PARAGRAPH:
             # Split the text documents into nodes.
             splitter = self._get_splitter(

         process_rule_mode: str,
         embedding_model_instance: Optional[ModelInstance],
     ) -> list[ChildDocument]:
+        if not rules.subchunk_segmentation:
+            raise ValueError("No subchunk segmentation found in rules.")
         child_splitter = self._get_splitter(
             processing_rule_mode=process_rule_mode,
             max_tokens=rules.subchunk_segmentation.max_tokens,

api/core/rag/index_processor/processor/qa_index_processor.py  (+11 / -7)

     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         preview = kwargs.get("preview")
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
+        if not process_rule.get("rules"):
+            raise ValueError("No rules found in process rule.")
         rules = Rule(**process_rule.get("rules"))
         splitter = self._get_splitter(
             processing_rule_mode=process_rule.get("mode"),
-            max_tokens=rules.segmentation.max_tokens,
-            chunk_overlap=rules.segmentation.chunk_overlap,
-            separator=rules.segmentation.separator,
+            max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0,
+            chunk_overlap=rules.segmentation.chunk_overlap if rules.segmentation else 0,
+            separator=rules.segmentation.separator if rules.segmentation else "",
             embedding_model_instance=kwargs.get("embedding_model_instance"),
         )

                 all_documents.extend(split_documents)
             if preview:
                 self._format_qa_document(
-                    current_app._get_current_object(),
-                    kwargs.get("tenant_id"),
+                    current_app._get_current_object(),  # type: ignore
+                    kwargs.get("tenant_id"),  # type: ignore
                     all_documents[0],
                     all_qa_documents,
                     kwargs.get("doc_language", "English"),

                 document_format_thread = threading.Thread(
                     target=self._format_qa_document,
                     kwargs={
-                        "flask_app": current_app._get_current_object(),
-                        "tenant_id": kwargs.get("tenant_id"),
+                        "flask_app": current_app._get_current_object(),  # type: ignore
+                        "tenant_id": kwargs.get("tenant_id"),  # type: ignore
                         "document_node": doc,
                         "all_qa_documents": all_qa_documents,
                         "document_language": kwargs.get("doc_language", "English"),

api/core/rag/models/document.py  (+3 / -3)

 from collections.abc import Sequence
 from typing import Any, Optional

-from pydantic import BaseModel, Field
+from pydantic import BaseModel


 class ChildDocument(BaseModel):
     """Arbitrary metadata about the page content (e.g., source, relationships to other
     documents, etc.).
     """

-    metadata: Optional[dict] = Field(default_factory=dict)
+    metadata: dict = {}


 class Document(BaseModel):
     """Arbitrary metadata about the page content (e.g., source, relationships to other
     documents, etc.).
     """

-    metadata: Optional[dict] = Field(default_factory=dict)
+    metadata: dict = {}

     provider: Optional[str] = "dify"
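Hedged note on the `metadata: dict = {}` spelling that replaces the Field(default_factory=dict) form above: on a Pydantic model this is not the classic shared-mutable-default trap, because Pydantic copies mutable field defaults per instance, so both spellings give every instance its own empty dict. Quick check:

from pydantic import BaseModel

class ModelDoc(BaseModel):
    metadata: dict = {}  # Pydantic copies this default for each instance

a = ModelDoc()
b = ModelDoc()
a.metadata["k"] = "v"
print(a.metadata, b.metadata)  # {'k': 'v'} {} -- no shared state between instances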



api/extensions/ext_blueprints.py  (+1 / -1)

 def init_app(app: DifyApp):
     # register blueprint routers

-    from flask_cors import CORS
+    from flask_cors import CORS  # type: ignore

     from controllers.console import bp as console_app_bp
     from controllers.files import bp as files_bp

api/schedule/mail_clean_document_notify_task.py  (+6 / -9)

 import logging
 import time
+from collections import defaultdict

 import click
 from celery import shared_task  # type: ignore
-from flask import render_template

 from extensions.ext_mail import mail
 from models.account import Account, Tenant, TenantAccountJoin

     try:
         dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
         # group by tenant_id
-        dataset_auto_disable_logs_map = {}
+        dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
         for dataset_auto_disable_log in dataset_auto_disable_logs:
             dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)

             if not tenant:
                 continue
             current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
+            if not current_owner_join:
+                continue
             account = Account.query.filter(Account.id == current_owner_join.account_id).first()
             if not account:
                 continue

-            dataset_auto_dataset_map = {}
+            dataset_auto_dataset_map = {}  # type: ignore
             for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
                 dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
                     dataset_auto_disable_log.document_id

                 document_count = len(document_ids)
                 knowledge_details.append(f"<li>Knowledge base {dataset.name}: {document_count} documents</li>")

-            html_content = render_template(
-                "clean_document_job_mail_template-US.html",
-            )
-            mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content)

         end_at = time.perf_counter()
         logging.info(
             click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green")
         )
     except Exception:
-        logging.exception("Send invite member mail to {} failed".format(to))
+        logging.exception("Send invite member mail to failed")

api/services/app_dsl_service.py  (+2 / -2)

 from typing import Optional, cast
 from uuid import uuid4

-import yaml
+import yaml  # type: ignore
 from packaging import version
 from pydantic import BaseModel
 from sqlalchemy import select

         else:
             cls._append_model_config_export_data(export_data, app_model)

-        return yaml.dump(export_data, allow_unicode=True)
+        return yaml.dump(export_data, allow_unicode=True)  # type: ignore

     @classmethod
     def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None:

api/services/dataset_service.py  (+106 / -77)

 from services.entities.knowledge_entities.knowledge_entities import (
     ChildChunkUpdateArgs,
     KnowledgeConfig,
+    RerankingModel,
     RetrievalModel,
     SegmentUpdateArgs,
 )

     }

     @staticmethod
-    def get_document(dataset_id: str, document_id: str) -> Optional[Document]:
-        document = (
-            db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
-        )
-
-        return document
+    def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
+        if document_id:
+            document = (
+                db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+            )
+            return document
+        else:
+            return None

     @staticmethod
     def get_document_by_id(document_id: str) -> Optional[Document]:

         if features.billing.enabled:
             if not knowledge_config.original_document_id:
                 count = 0
-                if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
-                    count = len(upload_file_list)
-                elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
-                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-                    for notion_info in notion_info_list:
-                        count = count + len(notion_info.pages)
-                elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
-                    website_info = knowledge_config.data_source.info_list.website_info_list
-                    count = len(website_info.urls)
-                batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
-                if count > batch_upload_limit:
-                    raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
-
-                DocumentService.check_documents_upload_quota(count, features)
+                if knowledge_config.data_source:
+                    if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+                        upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
+                        count = len(upload_file_list)
+                    elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+                        notion_info_list = knowledge_config.data_source.info_list.notion_info_list
+                        for notion_info in notion_info_list:  # type: ignore
+                            count = count + len(notion_info.pages)
+                    elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                        website_info = knowledge_config.data_source.info_list.website_info_list
+                        count = len(website_info.urls)  # type: ignore
+                    batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+                    if count > batch_upload_limit:
+                        raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+
+                    DocumentService.check_documents_upload_quota(count, features)

         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
-            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type
+            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore

         if not dataset.indexing_technique:
             if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:

                 "score_threshold_enabled": False,
             }

-            dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model
+            dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model  # type: ignore

         documents = []
         if knowledge_config.original_document_id:
         # save process rule
         if not dataset_process_rule:
             process_rule = knowledge_config.process_rule
-            if process_rule.mode in ("custom", "hierarchical"):
-                dataset_process_rule = DatasetProcessRule(
-                    dataset_id=dataset.id,
-                    mode=process_rule.mode,
-                    rules=process_rule.rules.model_dump_json(),
-                    created_by=account.id,
-                )
-            elif process_rule.mode == "automatic":
-                dataset_process_rule = DatasetProcessRule(
-                    dataset_id=dataset.id,
-                    mode=process_rule.mode,
-                    rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
-                    created_by=account.id,
-                )
-            else:
-                logging.warn(
-                    f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule"
-                )
-                return
-            db.session.add(dataset_process_rule)
-            db.session.commit()
+            if process_rule:
+                if process_rule.mode in ("custom", "hierarchical"):
+                    dataset_process_rule = DatasetProcessRule(
+                        dataset_id=dataset.id,
+                        mode=process_rule.mode,
+                        rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
+                        created_by=account.id,
+                    )
+                elif process_rule.mode == "automatic":
+                    dataset_process_rule = DatasetProcessRule(
+                        dataset_id=dataset.id,
+                        mode=process_rule.mode,
+                        rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
+                        created_by=account.id,
+                    )
+                else:
+                    logging.warn(
+                        f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
+                    )
+                    return
+                db.session.add(dataset_process_rule)
+                db.session.commit()
         lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
         with redis_client.lock(lock_name, timeout=600):
             position = DocumentService.get_documents_position(dataset.id)
             document_ids = []
             duplicate_document_ids = []
             if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
+                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                 for file_id in upload_file_list:
                     file = (
                         db.session.query(UploadFile)

                         name=file_name,
                     ).first()
                     if document:
-                        document.dataset_process_rule_id = dataset_process_rule.id
+                        document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                         document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                         document.created_from = created_from
                         document.doc_form = knowledge_config.doc_form
                         continue
                     document = DocumentService.build_document(
                         dataset,
-                        dataset_process_rule.id,
+                        dataset_process_rule.id,  # type: ignore
                         knowledge_config.data_source.info_list.data_source_type,
                         knowledge_config.doc_form,
                         knowledge_config.doc_language,

                     position += 1
             elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                 notion_info_list = knowledge_config.data_source.info_list.notion_info_list
+                if not notion_info_list:
+                    raise ValueError("No notion info list found.")
                 exist_page_ids = []
                 exist_document = {}
                 documents = Document.query.filter_by(

                     }
                     document = DocumentService.build_document(
                         dataset,
-                        dataset_process_rule.id,
+                        dataset_process_rule.id,  # type: ignore
                         knowledge_config.data_source.info_list.data_source_type,
                         knowledge_config.doc_form,
                         knowledge_config.doc_language,

                 clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
             elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                 website_info = knowledge_config.data_source.info_list.website_info_list
+                if not website_info:
+                    raise ValueError("No website info list found.")
                 urls = website_info.urls
                 for url in urls:
                     data_source_info = {

                     document_name = url
                     document = DocumentService.build_document(
                         dataset,
-                        dataset_process_rule.id,
+                        dataset_process_rule.id,  # type: ignore
                         knowledge_config.data_source.info_list.data_source_type,
                         knowledge_config.doc_form,
                         knowledge_config.doc_language,

             dataset_process_rule = DatasetProcessRule(
                 dataset_id=dataset.id,
                 mode=process_rule.mode,
-                rules=process_rule.rules.model_dump_json(),
+                rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
                 created_by=account.id,
             )
         elif process_rule.mode == "automatic":

         file_name = ""
         data_source_info = {}
         if document_data.data_source.info_list.data_source_type == "upload_file":
+            if not document_data.data_source.info_list.file_info_list:
+                raise ValueError("No file info list found.")
             upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
             for file_id in upload_file_list:
                 file = (

                     "upload_file_id": file_id,
                 }
         elif document_data.data_source.info_list.data_source_type == "notion_import":
+            if not document_data.data_source.info_list.notion_info_list:
+                raise ValueError("No notion info list found.")
             notion_info_list = document_data.data_source.info_list.notion_info_list
             for notion_info in notion_info_list:
                 workspace_id = notion_info.workspace_id
                     data_source_info = {
                         "notion_workspace_id": workspace_id,
                         "notion_page_id": page.page_id,
-                        "notion_page_icon": page.page_icon,
+                        "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,  # type: ignore
                         "type": page.type,
                     }
         elif document_data.data_source.info_list.data_source_type == "website_crawl":
             website_info = document_data.data_source.info_list.website_info_list
-            urls = website_info.urls
-            for url in urls:
-                data_source_info = {
-                    "url": url,
-                    "provider": website_info.provider,
-                    "job_id": website_info.job_id,
-                    "only_main_content": website_info.only_main_content,
-                    "mode": "crawl",
-                }
+            if website_info:
+                urls = website_info.urls
+                for url in urls:
+                    data_source_info = {
+                        "url": url,
+                        "provider": website_info.provider,
+                        "job_id": website_info.job_id,
+                        "only_main_content": website_info.only_main_content,  # type: ignore
+                        "mode": "crawl",
+                    }
         document.data_source_type = document_data.data_source.info_list.data_source_type
         document.data_source_info = json.dumps(data_source_info)
         document.name = file_name
         if features.billing.enabled:
             count = 0
             if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
+                upload_file_list = (
+                    knowledge_config.data_source.info_list.file_info_list.file_ids
+                    if knowledge_config.data_source.info_list.file_info_list
+                    else []
+                )
                 count = len(upload_file_list)
             elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                 notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-                for notion_info in notion_info_list:
-                    count = count + len(notion_info.pages)
+                if notion_info_list:
+                    for notion_info in notion_info_list:
+                        count = count + len(notion_info.pages)
             elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                 website_info = knowledge_config.data_source.info_list.website_info_list
-                count = len(website_info.urls)
+                if website_info:
+                    count = len(website_info.urls)
             batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

         retrieval_model = None
         if knowledge_config.indexing_technique == "high_quality":
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                knowledge_config.embedding_model_provider, knowledge_config.embedding_model
+                knowledge_config.embedding_model_provider,  # type: ignore
+                knowledge_config.embedding_model,  # type: ignore
             )
             dataset_collection_binding_id = dataset_collection_binding.id
             if knowledge_config.retrieval_model:
                 retrieval_model = knowledge_config.retrieval_model
             else:
-                default_retrieval_model = {
-                    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
-                    "reranking_enable": False,
-                    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
-                    "top_k": 2,
-                    "score_threshold_enabled": False,
-                }
-                retrieval_model = RetrievalModel(**default_retrieval_model)
+                retrieval_model = RetrievalModel(
+                    search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
+                    reranking_enable=False,
+                    reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
+                    top_k=2,
+                    score_threshold_enabled=False,
+                )
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,

                 raise ValueError("Can't update disabled segment")
             try:
                 word_count_change = segment.word_count
-                content = args.content
+                content = args.content or segment.content
                 if segment.content == content:
                     segment.word_count = len(content)
                     if document.doc_form == "qa_model":
                         segment.answer = args.answer
-                        segment.word_count += len(args.answer)
+                        segment.word_count += len(args.answer) if args.answer else 0
                     word_count_change = segment.word_count - word_count_change
                     if args.keywords:
                         segment.keywords = args.keywords

                     db.session.add(document)
                     # update segment index task
                     if args.enabled:
-                        VectorService.create_segments_vector([args.keywords], [segment], dataset)
+                        VectorService.create_segments_vector(
+                            [args.keywords] if args.keywords else None,
+                            [segment],
+                            dataset,
+                            document.doc_form,
+                        )
                     if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
                         # regenerate child chunks
                         # get embedding model instance

                         .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                         .first()
                     )
+                    if not processing_rule:
+                        raise ValueError("No processing rule found.")
                     VectorService.generate_child_chunks(
                         segment, document, dataset, embedding_model_instance, processing_rule, True
                     )

                 segment.disabled_by = None
                 if document.doc_form == "qa_model":
                     segment.answer = args.answer
-                    segment.word_count += len(args.answer)
+                    segment.word_count += len(args.answer) if args.answer else 0
                 word_count_change = segment.word_count - word_count_change
                 # update document word count
                 if word_count_change != 0:

                         .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                         .first()
                     )
+                    if not processing_rule:
+                        raise ValueError("No processing rule found.")
                     VectorService.generate_child_chunks(
                         segment, document, dataset, embedding_model_instance, processing_rule, True
                     )
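Most of the added guards in this file share one shape: a value that may be None (a nullable config field, or a query that may match nothing) is either narrowed with an explicit check that raises, or silenced with `# type: ignore`. A generic sketch of the guard variant, with made-up helpers rather than the Dify service API:

from typing import Optional

def first_or_none(rows: list[dict]) -> Optional[dict]:
    # stand-in for db.session.query(...).first(), which returns None when nothing matches
    return rows[0] if rows else None

processing_rule = first_or_none([{"mode": "automatic"}])
if not processing_rule:
    raise ValueError("No processing rule found.")
print(processing_rule["mode"])  # both the type checker and the runtime now know it is not None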

api/services/entities/knowledge_entities/knowledge_entities.py  (+1 / -1)

     original_document_id: Optional[str] = None
     duplicate: bool = True
     indexing_technique: Literal["high_quality", "economy"]
-    data_source: Optional[DataSource] = None
+    data_source: DataSource
     process_rule: Optional[ProcessRule] = None
     retrieval_model: Optional[RetrievalModel] = None
     doc_form: str = "text_model"

api/services/hit_testing_service.py  (+1 / -1)

         db.session.add(dataset_query)
         db.session.commit()

-        return cls.compact_retrieve_response(query, all_documents)
+        return cls.compact_retrieve_response(query, all_documents)  # type: ignore

     @classmethod
     def external_retrieve(

api/services/vector_service.py  (+4 / -2)

             .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
             .first()
         )
+        if not processing_rule:
+            raise ValueError("No processing rule found.")
         # get embedding model instance
         if dataset.indexing_technique == "high_quality":
             # check embedding model setting

     def generate_child_chunks(
         cls,
         segment: DocumentSegment,
-        dataset_document: Document,
+        dataset_document: DatasetDocument,
         dataset: Dataset,
         embedding_model_instance: ModelInstance,
         processing_rule: DatasetProcessRule,

             doc_language=dataset_document.doc_language,
         )
         # save child chunks
-        if len(documents) > 0 and len(documents[0].children) > 0:
+        if documents and documents[0].children:
             index_processor.load(dataset, documents)

             for position, child_chunk in enumerate(documents[0].children, start=1):

api/tasks/batch_clean_document_task.py  (+2 / -1)

     for upload_file_id in image_upload_file_ids:
         image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
         try:
-            storage.delete(image_file.key)
+            if image_file and image_file.key:
+                storage.delete(image_file.key)
         except Exception:
             logging.exception(
                 "Delete image_files failed when storage deleted, \
