
r2

tags/2.0.0-beta.1
jyong 5 months ago
parent
commit
0486aa3445

+1 -1   api/controllers/console/datasets/datasets_document.py

@@ -664,7 +664,7 @@ class DocumentDetailApi(DocumentResource):
             response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
         elif metadata == "without":
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,

+1 -3   api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py

@@ -39,8 +39,6 @@ from libs.helper import TimestampField, uuid_value
 from libs.login import current_user, login_required
 from models.account import Account
 from models.dataset import Pipeline
-from models.model import EndUser
-from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration
 from services.errors.app import WorkflowHashNotEqualError
 from services.errors.llm import InvokeRateLimitError
 from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
@@ -542,7 +540,7 @@ class RagPipelineConfigApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, pipeline_id):
         return {
             "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
         }
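
This hunk is mostly dead-code removal: `EndUser` and `KnowledgeBaseUpdateConfiguration` lost their last uses once the update entity was folded into `KnowledgeConfiguration` (see the entities file below). For orientation, a minimal sketch of the config endpoint the second hunk touches, assuming flask-restful as the codebase uses; the route path and default value here are illustrative, not taken from the repo:

from flask import Flask
from flask_restful import Api, Resource

# Illustrative stand-in for dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT.
WORKFLOW_PARALLEL_DEPTH_LIMIT = 3

class RagPipelineConfigApi(Resource):
    # The real method is wrapped in login_required / account_initialization_required.
    def get(self, pipeline_id):
        return {"parallel_depth_limit": WORKFLOW_PARALLEL_DEPTH_LIMIT}

app = Flask(__name__)
api = Api(app)
# Hypothetical URL rule; the real one is registered on the console blueprint.
api.add_resource(RagPipelineConfigApi, "/rag/pipelines/<string:pipeline_id>/workflows/config")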

+15 -3   api/core/workflow/nodes/knowledge_index/knowledge_index_node.py

@@ -12,7 +12,7 @@ from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
 from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
-from models.dataset import Dataset, Document
+from models.dataset import Dataset, Document, DocumentSegment
 from models.workflow import WorkflowNodeExecutionStatus

 from ..base import BaseNode
@@ -61,11 +61,11 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
             )
-        outputs = self._get_preview_output(node_data.chunk_structure, chunks)

-        # retrieve knowledge
+        # index knowledge
         try:
             if is_preview:
+                outputs = self._get_preview_output(node_data.chunk_structure, chunks)
                 return NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
                     inputs=variables,
@@ -116,6 +116,18 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
         document.indexing_status = "completed"
         document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         db.session.add(document)
+        # update document segment status
+        db.session.query(DocumentSegment).filter(
+            DocumentSegment.document_id == document.id,
+            DocumentSegment.dataset_id == dataset.id,
+        ).update(
+            {
+                DocumentSegment.status: "completed",
+                DocumentSegment.enabled: True,
+                DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+            }
+        )

         db.session.commit()

         return {
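
Two behavior changes here: preview runs now compute the preview output only inside the `is_preview` branch, and a real indexing run marks the document's segments completed, not just the document. Without the added bulk UPDATE, segments could stay in a transient status after indexing finished. A self-contained sketch of the same `query(...).filter(...).update(...)` pattern against in-memory SQLite (the `DocumentSegment` model is a cut-down stand-in; `datetime.UTC` assumes Python 3.11+, matching the diff):

import datetime
from sqlalchemy import Boolean, Column, DateTime, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class DocumentSegment(Base):
    __tablename__ = "document_segments"
    id = Column(String, primary_key=True)
    document_id = Column(String)
    dataset_id = Column(String)
    status = Column(String, default="indexing")
    enabled = Column(Boolean, default=False)
    completed_at = Column(DateTime)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(DocumentSegment(id="seg-1", document_id="doc-1", dataset_id="ds-1"))
    session.commit()

    # Mirrors the added block: one UPDATE statement flips every segment of the
    # freshly indexed document to completed/enabled with a completion timestamp.
    session.query(DocumentSegment).filter(
        DocumentSegment.document_id == "doc-1",
        DocumentSegment.dataset_id == "ds-1",
    ).update(
        {
            DocumentSegment.status: "completed",
            DocumentSegment.enabled: True,
            DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
        }
    )
    session.commit()

    assert session.query(DocumentSegment).one().status == "completed"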

+26 -26   api/services/dataset_service.py

@@ -1,3 +1,4 @@
+from calendar import day_abbr
 import copy
 import datetime
 import json
@@ -52,7 +53,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
     SegmentUpdateArgs,
 )
 from services.entities.knowledge_entities.rag_pipeline_entities import (
-    KnowledgeBaseUpdateConfiguration,
     KnowledgeConfiguration,
     RagPipelineDatasetCreateEntity,
 )
@@ -492,23 +492,23 @@ class DatasetService:
         if action:
             deal_dataset_vector_index_task.delay(dataset_id, action)
         return dataset

     @staticmethod
     def update_rag_pipeline_dataset_settings(session: Session,
-                                             dataset: Dataset,
-                                             knowledge_configuration: KnowledgeConfiguration,
+                                             dataset: Dataset,
+                                             knowledge_configuration: KnowledgeConfiguration,
                                              has_published: bool = False):
         dataset = session.merge(dataset)
         if not has_published:
             dataset.chunk_structure = knowledge_configuration.chunk_structure
-            index_method = knowledge_configuration.index_method
-            dataset.indexing_technique = index_method.indexing_technique
-            if index_method == "high_quality":
+            dataset.indexing_technique = knowledge_configuration.indexing_technique
+            if knowledge_configuration.indexing_technique == "high_quality":
                 model_manager = ModelManager()
                 embedding_model = model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    provider=index_method.embedding_setting.embedding_provider_name,
+                    provider=knowledge_configuration.embedding_model_provider,
                     model_type=ModelType.TEXT_EMBEDDING,
-                    model=index_method.embedding_setting.embedding_model_name,
+                    model=knowledge_configuration.embedding_model,
                 )
                 dataset.embedding_model = embedding_model.model
                 dataset.embedding_model_provider = embedding_model.provider
@@ -516,30 +516,30 @@ class DatasetService:
                     embedding_model.provider, embedding_model.model
                 )
                 dataset.collection_binding_id = dataset_collection_binding.id
-            elif index_method == "economy":
-                dataset.keyword_number = index_method.economy_setting.keyword_number
+            elif knowledge_configuration.indexing_technique == "economy":
+                dataset.keyword_number = knowledge_configuration.keyword_number
             else:
                 raise ValueError("Invalid index method")
-            dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
+            dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
             session.add(dataset)
         else:
             if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure:
                 raise ValueError("Chunk structure is not allowed to be updated.")
             action = None
-            if dataset.indexing_technique != knowledge_configuration.index_method.indexing_technique:
+            if dataset.indexing_technique != knowledge_configuration.indexing_technique:
                 # if update indexing_technique
-                if knowledge_configuration.index_method.indexing_technique == "economy":
+                if knowledge_configuration.indexing_technique == "economy":
                     raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
-                elif knowledge_configuration.index_method.indexing_technique == "high_quality":
+                elif knowledge_configuration.indexing_technique == "high_quality":
                     action = "add"
                     # get embedding model setting
                     try:
                         model_manager = ModelManager()
                         embedding_model = model_manager.get_model_instance(
                             tenant_id=current_user.current_tenant_id,
-                            provider=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
+                            provider=knowledge_configuration.embedding_model_provider,
                             model_type=ModelType.TEXT_EMBEDDING,
-                            model=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                            model=knowledge_configuration.embedding_model,
                         )
                         dataset.embedding_model = embedding_model.model
                         dataset.embedding_model_provider = embedding_model.provider
@@ -567,7 +567,7 @@ class DatasetService:
                 plugin_model_provider_str = str(ModelProviderID(plugin_model_provider))

                 # Handle new model provider from request
-                new_plugin_model_provider = knowledge_base_setting.index_method.embedding_setting.embedding_provider_name
+                new_plugin_model_provider = knowledge_configuration.embedding_model_provider
                 new_plugin_model_provider_str = None
                 if new_plugin_model_provider:
                     new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider))
@@ -575,16 +575,16 @@ class DatasetService:
                 # Only update embedding model if both values are provided and different from current
                 if (
                     plugin_model_provider_str != new_plugin_model_provider_str
-                    or knowledge_base_setting.index_method.embedding_setting.embedding_model_name != dataset.embedding_model
+                    or knowledge_configuration.embedding_model != dataset.embedding_model
                 ):
                     action = "update"
                     model_manager = ModelManager()
                     try:
                         embedding_model = model_manager.get_model_instance(
                             tenant_id=current_user.current_tenant_id,
-                            provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
+                            provider=knowledge_configuration.embedding_model_provider,
                             model_type=ModelType.TEXT_EMBEDDING,
-                            model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
+                            model=knowledge_configuration.embedding_model,
                         )
                     except ProviderTokenNotInitError:
                         # If we can't get the embedding model, skip updating it
@@ -608,14 +608,14 @@ class DatasetService:
                 except ProviderTokenNotInitError as ex:
                     raise ValueError(ex.description)
             elif dataset.indexing_technique == "economy":
-                if dataset.keyword_number != knowledge_configuration.index_method.economy_setting.keyword_number:
-                    dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number
-            dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
-            session.add(dataset)
+                if dataset.keyword_number != knowledge_configuration.keyword_number:
+                    dataset.keyword_number = knowledge_configuration.keyword_number
+            dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
+            session.add(dataset)
         session.commit()
         if action:
             deal_dataset_index_update_task.delay(dataset.id, action)

     @staticmethod
     def delete_dataset(dataset_id, user):
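
Beyond the mechanical accessor swap (nested `knowledge_configuration.index_method.embedding_setting.*` reads become flat `knowledge_configuration.*` fields, and the stale `knowledge_base_setting` name becomes `knowledge_configuration`), the rewrite fixes a latent bug visible in the removed lines: `if index_method == "high_quality":` compared the `IndexMethod` model instance against a string, which is always False in Python, so that branch could never run. A minimal repro of the footgun (stub model, assuming pydantic as in the entities module):

from pydantic import BaseModel

class IndexMethod(BaseModel):
    indexing_technique: str

index_method = IndexMethod(indexing_technique="high_quality")
print(index_method == "high_quality")                     # False: model instance vs. str
print(index_method.indexing_technique == "high_quality")  # True: the intended comparison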

+6 -13   api/services/entities/knowledge_entities/rag_pipeline_entities.py

@@ -105,18 +105,11 @@ class IndexMethod(BaseModel):

 class KnowledgeConfiguration(BaseModel):
     """
-    Knowledge Configuration.
+    Knowledge Base Configuration.
     """

     chunk_structure: str
-    index_method: IndexMethod
-    retrieval_setting: RetrievalSetting
-
-
-class KnowledgeBaseUpdateConfiguration(BaseModel):
-    """
-    Knowledge Base Update Configuration.
-    """
-    index_method: IndexMethod
-    chunk_structure: str
-    retrieval_setting: RetrievalSetting
+    indexing_technique: Literal["high_quality", "economy"]
+    embedding_model_provider: Optional[str] = ""
+    embedding_model: Optional[str] = ""
+    keyword_number: Optional[int] = 10
+    retrieval_model: RetrievalSetting
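
The separate update entity is gone; its fields are flattened into the one `KnowledgeConfiguration`. A runnable sketch of the consolidated shape (assuming pydantic v2; `RetrievalSetting` is reduced to two illustrative fields, and the sample values are illustrative too). Pydantic ignores unknown keys by default, which is what lets the pipeline service below validate a whole node `data` mapping directly:

from typing import Literal, Optional
from pydantic import BaseModel

class RetrievalSetting(BaseModel):
    # Cut-down stand-in for the real retrieval settings model.
    search_method: str = "semantic_search"
    top_k: int = 4

class KnowledgeConfiguration(BaseModel):
    """
    Knowledge Base Configuration.
    """
    chunk_structure: str
    indexing_technique: Literal["high_quality", "economy"]
    embedding_model_provider: Optional[str] = ""
    embedding_model: Optional[str] = ""
    keyword_number: Optional[int] = 10
    retrieval_model: RetrievalSetting

node_data = {
    "type": "knowledge-index",   # extra keys like these are ignored by default
    "title": "Knowledge Index",
    "chunk_structure": "text_model",
    "indexing_technique": "high_quality",
    "embedding_model_provider": "openai",
    "embedding_model": "text-embedding-3-small",
    "retrieval_model": {"search_method": "semantic_search", "top_k": 4},
}
config = KnowledgeConfiguration(**node_data)
print(config.indexing_technique, config.embedding_model)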

+11 -11   api/services/rag_pipeline/rag_pipeline.py

@@ -74,7 +74,7 @@ class RagPipelineService:
         result = retrieval_instance.get_pipeline_templates(language)
         return result

-    @classmethod
+    @classmethod
     def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
         """
         Get pipeline template detail.
@@ -284,7 +284,7 @@ class RagPipelineService:
             graph=draft_workflow.graph,
             features=draft_workflow.features,
             created_by=account.id,
-            environment_variables=draft_workflow.environment_variables,
+            environment_variables=draft_workflow.environment_variables,
             conversation_variables=draft_workflow.conversation_variables,
             rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
             marked_name="",
@@ -296,8 +296,8 @@ class RagPipelineService:
         graph = workflow.graph_dict
         nodes = graph.get("nodes", [])
         for node in nodes:
-            if node.get("data", {}).get("type") == "knowledge_index":
-                knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
+            if node.get("data", {}).get("type") == "knowledge-index":
+                knowledge_configuration = node.get("data", {})
                 knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)

                 # update dataset
@@ -306,8 +306,8 @@ class RagPipelineService:
                     raise ValueError("Dataset not found")
                 DatasetService.update_rag_pipeline_dataset_settings(
                     session=session,
-                    dataset=dataset,
-                    knowledge_configuration=knowledge_configuration,
+                    dataset=dataset,
+                    knowledge_configuration=knowledge_configuration,
                     has_published=pipeline.is_published
                 )
                 # return new workflow
@@ -771,14 +771,14 @@ class RagPipelineService:

         # Use the repository to get the node executions with ordering
         order_config = OrderConfig(order_by=["index"], order_direction="desc")
-        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id,
-                                                         order_config=order_config,
+        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id,
+                                                         order_config=order_config,
                                                          triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN)
         # Convert domain models to database models
         workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]

         return workflow_node_executions

     @classmethod
     def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
         """
@@ -792,5 +792,5 @@ class RagPipelineService:
         workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
         if not workflow:
             raise ValueError("Workflow not found")
-        db.session.commit()
+        db.session.commit()
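
Two fixes hide in the `@@ -296,8 +296,8 @@` hunk: the node type is now matched as the hyphenated "knowledge-index", and the configuration is read from the whole node `data` mapping instead of a nested "knowledge_configuration" key. If the graph stores the hyphenated type, the old underscore comparison never matched, so the dataset-settings sync on publish would have been silently skipped. A small sketch of the corrected scan (graph shape assumed from the diff):

# Graph shape assumed from the diff: {"nodes": [{"data": {...}}, ...]}.
graph = {
    "nodes": [
        {"data": {"type": "datasource", "title": "File Upload"}},
        {"data": {"type": "knowledge-index", "chunk_structure": "text_model"}},
    ]
}

for node in graph.get("nodes", []):
    if node.get("data", {}).get("type") == "knowledge-index":  # hyphenated, as stored
        knowledge_configuration = node.get("data", {})  # whole mapping, validated downstream
        print(knowledge_configuration["chunk_structure"])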

+25 -25   api/services/rag_pipeline/rag_pipeline_dsl_service.py

@@ -1,10 +1,10 @@
 import base64
+from datetime import UTC, datetime
 import hashlib
 import json
 import logging
 import uuid
 from collections.abc import Mapping
-from datetime import UTC, datetime
 from enum import StrEnum
 from typing import Optional, cast
 from urllib.parse import urlparse
@@ -292,20 +292,20 @@ class RagPipelineDslService:
                     "background": icon_background,
                     "url": icon_url,
                 },
-                indexing_technique=knowledge_configuration.index_method.indexing_technique,
+                indexing_technique=knowledge_configuration.indexing_technique,
                 created_by=account.id,
-                retrieval_model=knowledge_configuration.retrieval_setting.model_dump(),
+                retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
                 runtime_mode="rag_pipeline",
                 chunk_structure=knowledge_configuration.chunk_structure,
             )
-            if knowledge_configuration.index_method.indexing_technique == "high_quality":
+            if knowledge_configuration.indexing_technique == "high_quality":
                 dataset_collection_binding = (
                     db.session.query(DatasetCollectionBinding)
                     .filter(
                         DatasetCollectionBinding.provider_name
-                        == knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
+                        == knowledge_configuration.embedding_model_provider,
                         DatasetCollectionBinding.model_name
-                        == knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                        == knowledge_configuration.embedding_model,
                         DatasetCollectionBinding.type == "dataset",
                     )
                     .order_by(DatasetCollectionBinding.created_at)
@@ -314,8 +314,8 @@ class RagPipelineDslService:

                 if not dataset_collection_binding:
                     dataset_collection_binding = DatasetCollectionBinding(
-                        provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
-                        model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                        provider_name=knowledge_configuration.embedding_model_provider,
+                        model_name=knowledge_configuration.embedding_model,
                         collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
                         type="dataset",
                     )
@@ -324,13 +324,13 @@ class RagPipelineDslService:
                     dataset_collection_binding_id = dataset_collection_binding.id
                 dataset.collection_binding_id = dataset_collection_binding_id
                 dataset.embedding_model = (
-                    knowledge_configuration.index_method.embedding_setting.embedding_model_name
+                    knowledge_configuration.embedding_model
                 )
                 dataset.embedding_model_provider = (
-                    knowledge_configuration.index_method.embedding_setting.embedding_provider_name
+                    knowledge_configuration.embedding_model_provider
                 )
-            elif knowledge_configuration.index_method.indexing_technique == "economy":
-                dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number
+            elif knowledge_configuration.indexing_technique == "economy":
+                dataset.keyword_number = knowledge_configuration.keyword_number
             dataset.pipeline_id = pipeline.id
             self._session.add(dataset)
             self._session.commit()
@@ -426,25 +426,25 @@ class RagPipelineDslService:
                     "background": icon_background,
                     "url": icon_url,
                 },
-                indexing_technique=knowledge_configuration.index_method.indexing_technique,
+                indexing_technique=knowledge_configuration.indexing_technique,
                 created_by=account.id,
-                retrieval_model=knowledge_configuration.retrieval_setting.model_dump(),
+                retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
                 runtime_mode="rag_pipeline",
                 chunk_structure=knowledge_configuration.chunk_structure,
            )
         else:
-            dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique
-            dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
+            dataset.indexing_technique = knowledge_configuration.indexing_technique
+            dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
             dataset.runtime_mode = "rag_pipeline"
             dataset.chunk_structure = knowledge_configuration.chunk_structure
-        if knowledge_configuration.index_method.indexing_technique == "high_quality":
+        if knowledge_configuration.indexing_technique == "high_quality":
             dataset_collection_binding = (
                 db.session.query(DatasetCollectionBinding)
                 .filter(
                     DatasetCollectionBinding.provider_name
-                    == knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
+                    == knowledge_configuration.embedding_model_provider,
                     DatasetCollectionBinding.model_name
-                    == knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                    == knowledge_configuration.embedding_model,
                     DatasetCollectionBinding.type == "dataset",
                 )
                 .order_by(DatasetCollectionBinding.created_at)
@@ -453,8 +453,8 @@ class RagPipelineDslService:

             if not dataset_collection_binding:
                 dataset_collection_binding = DatasetCollectionBinding(
-                    provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
-                    model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                    provider_name=knowledge_configuration.embedding_model_provider,
+                    model_name=knowledge_configuration.embedding_model,
                     collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
                     type="dataset",
                 )
@@ -463,13 +463,13 @@ class RagPipelineDslService:
                 dataset_collection_binding_id = dataset_collection_binding.id
             dataset.collection_binding_id = dataset_collection_binding_id
             dataset.embedding_model = (
-                knowledge_configuration.index_method.embedding_setting.embedding_model_name
+                knowledge_configuration.embedding_model
             )
             dataset.embedding_model_provider = (
-                knowledge_configuration.index_method.embedding_setting.embedding_provider_name
+                knowledge_configuration.embedding_model_provider
             )
-        elif knowledge_configuration.index_method.indexing_technique == "economy":
-            dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number
+        elif knowledge_configuration.indexing_technique == "economy":
+            dataset.keyword_number = knowledge_configuration.keyword_number
         dataset.pipeline_id = pipeline.id
         self._session.add(dataset)
         self._session.commit()
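
Both the create and update paths repeat the same get-or-create lookup for a `DatasetCollectionBinding`, now keyed by the flattened `embedding_model_provider` / `embedding_model` fields. A reduced, runnable sketch of that pattern against in-memory SQLite (the model is a stand-in, and `Dataset.gen_collection_name_by_id` is replaced with simplified naming):

import uuid
from sqlalchemy import Column, DateTime, String, create_engine, func
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class DatasetCollectionBinding(Base):
    __tablename__ = "dataset_collection_bindings"
    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    provider_name = Column(String)
    model_name = Column(String)
    type = Column(String)
    collection_name = Column(String)
    created_at = Column(DateTime, server_default=func.now())

def get_or_create_binding(session: Session, provider: str, model: str) -> DatasetCollectionBinding:
    # Oldest matching binding wins, mirroring the order_by(created_at) in the diff.
    binding = (
        session.query(DatasetCollectionBinding)
        .filter(
            DatasetCollectionBinding.provider_name == provider,
            DatasetCollectionBinding.model_name == model,
            DatasetCollectionBinding.type == "dataset",
        )
        .order_by(DatasetCollectionBinding.created_at)
        .first()
    )
    if not binding:
        binding = DatasetCollectionBinding(
            provider_name=provider,
            model_name=model,
            collection_name=f"collection_{uuid.uuid4().hex}",  # simplified naming
            type="dataset",
        )
        session.add(binding)
        session.flush()
    return binding

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    first = get_or_create_binding(session, "openai", "text-embedding-3-small")
    second = get_or_create_binding(session, "openai", "text-embedding-3-small")
    assert first.id == second.id  # the second call reuses the existing binding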
