jyong пре 5 месеци
родитељ
комит
797d044714

+ 0
- 13
api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py Прегледај датотеку

@@ -462,18 +462,6 @@ class PublishedRagPipelineApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()

parser = reqparse.RequestParser()
parser.add_argument("knowledge_base_setting", type=dict, location="json", help="Invalid knowledge base setting.")
args = parser.parse_args()

if not args.get("knowledge_base_setting"):
raise ValueError("Missing knowledge base setting.")

knowledge_base_setting_data = args.get("knowledge_base_setting")
if not knowledge_base_setting_data:
raise ValueError("Missing knowledge base setting.")
knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**knowledge_base_setting_data)
rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session:
pipeline = session.merge(pipeline)
@@ -481,7 +469,6 @@ class PublishedRagPipelineApi(Resource):
session=session,
pipeline=pipeline,
account=current_user,
knowledge_base_setting=knowledge_base_setting,
)
pipeline.is_published = True
pipeline.workflow_id = workflow.id

+ 6
- 5
api/core/plugin/impl/datasource.py Прегледај датотеку

@@ -22,11 +22,12 @@ class PluginDatasourceManager(BasePluginClient):
"""

def transformer(json_response: dict[str, Any]) -> dict:
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name")
for datasource in declaration.get("datasources", []):
datasource["identity"]["provider"] = provider_name
if json_response.get("data"):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name")
for datasource in declaration.get("datasources", []):
datasource["identity"]["provider"] = provider_name

return json_response


+ 12
- 6
api/core/workflow/nodes/datasource/datasource_node.py Прегледај датотеку

@@ -9,6 +9,7 @@ from core.datasource.entities.datasource_entities import (
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.file import File
from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment, FileSegment
from core.variables.variables import ArrayAnyVariable
@@ -118,7 +119,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
},
)
case DatasourceProviderType.LOCAL_FILE:
upload_file = db.session.query(UploadFile).filter(UploadFile.id == datasource_info["related_id"]).first()
related_id = datasource_info.get("related_id")
if not related_id:
raise DatasourceNodeError(
"File is not exist"
)
upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
if not upload_file:
raise ValueError("Invalid upload file Info")

@@ -128,14 +134,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=self.tenant_id,
type=datasource_info.get("type", ""),
transfer_method=datasource_info.get("transfer_method", ""),
type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
)
variable_pool.add([self.node_id, "file"], [FileSegment(value=file_info)])
variable_pool.add([self.node_id, "file"], [file_info])
for key, value in datasource_info.items():
# construct new key list
new_key_list = ["file", key]
@@ -147,7 +153,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file_info": file_info,
"file_info": datasource_info,
"datasource_type": datasource_type,
},
)
@@ -220,7 +226,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []

def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""

+ 14
- 13
api/services/dataset_service.py Прегледај датотеку

@@ -53,6 +53,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeBaseUpdateConfiguration,
KnowledgeConfiguration,
RagPipelineDatasetCreateEntity,
)
from services.errors.account import InvalidActionError, NoPermissionError
@@ -495,11 +496,11 @@ class DatasetService:
@staticmethod
def update_rag_pipeline_dataset_settings(session: Session,
dataset: Dataset,
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
knowledge_configuration: KnowledgeConfiguration,
has_published: bool = False):
if not has_published:
dataset.chunk_structure = knowledge_base_setting.chunk_structure
index_method = knowledge_base_setting.index_method
dataset.chunk_structure = knowledge_configuration.chunk_structure
index_method = knowledge_configuration.index_method
dataset.indexing_technique = index_method.indexing_technique
if index_method == "high_quality":
model_manager = ModelManager()
@@ -519,26 +520,26 @@ class DatasetService:
dataset.keyword_number = index_method.economy_setting.keyword_number
else:
raise ValueError("Invalid index method")
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
session.add(dataset)
else:
if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure:
if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure:
raise ValueError("Chunk structure is not allowed to be updated.")
action = None
if dataset.indexing_technique != knowledge_base_setting.index_method.indexing_technique:
if dataset.indexing_technique != knowledge_configuration.index_method.indexing_technique:
# if update indexing_technique
if knowledge_base_setting.index_method.indexing_technique == "economy":
if knowledge_configuration.index_method.indexing_technique == "economy":
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
elif knowledge_base_setting.index_method.indexing_technique == "high_quality":
elif knowledge_configuration.index_method.indexing_technique == "high_quality":
action = "add"
# get embedding model setting
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
provider=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
model=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
@@ -607,9 +608,9 @@ class DatasetService:
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
elif dataset.indexing_technique == "economy":
if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number:
dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
if dataset.keyword_number != knowledge_configuration.index_method.economy_setting.keyword_number:
dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number
dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
session.add(dataset)
session.commit()
if action:

+ 18
- 12
api/services/rag_pipeline/rag_pipeline.py Прегледај датотеку

@@ -47,7 +47,7 @@ from models.workflow import (
WorkflowType,
)
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, PipelineTemplateInfoEntity
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, KnowledgeConfiguration, PipelineTemplateInfoEntity
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory

@@ -262,7 +262,6 @@ class RagPipelineService:
session: Session,
pipeline: Pipeline,
account: Account,
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
) -> Workflow:
draft_workflow_stmt = select(Workflow).where(
Workflow.tenant_id == pipeline.tenant_id,
@@ -291,16 +290,23 @@ class RagPipelineService:
# commit db session changes
session.add(workflow)

# update dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Dataset not found")
DatasetService.update_rag_pipeline_dataset_settings(
session=session,
dataset=dataset,
knowledge_base_setting=knowledge_base_setting,
has_published=pipeline.is_published
)
graph = workflow.graph_dict
nodes = graph.get("nodes", [])
for node in nodes:
if node.get("data", {}).get("type") == "knowledge_index":
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)

# update dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Dataset not found")
DatasetService.update_rag_pipeline_dataset_settings(
session=session,
dataset=dataset,
knowledge_configuration=knowledge_configuration,
has_published=pipeline.is_published
)
# return new workflow
return workflow


Loading…
Откажи
Сачувај