# Conflicts:
#	api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
#	api/core/workflow/entities/node_entities.py
#	api/core/workflow/enums.py

tags/2.0.0-beta.1
| @@ -1,6 +1,6 @@ | |||
| #!/bin/bash | |||
| npm add -g pnpm@10.8.0 | |||
| npm add -g pnpm@10.11.1 | |||
| cd web && pnpm install | |||
| pipx install uv | |||
| @@ -846,6 +846,9 @@ def clear_orphaned_file_records(force: bool): | |||
| {"type": "text", "table": "workflow_node_executions", "column": "outputs"}, | |||
| {"type": "text", "table": "conversations", "column": "introduction"}, | |||
| {"type": "text", "table": "conversations", "column": "system_instruction"}, | |||
| {"type": "text", "table": "accounts", "column": "avatar"}, | |||
| {"type": "text", "table": "apps", "column": "icon"}, | |||
| {"type": "text", "table": "sites", "column": "icon"}, | |||
| {"type": "json", "table": "messages", "column": "inputs"}, | |||
| {"type": "json", "table": "messages", "column": "message"}, | |||
| ] | |||
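The three newly covered columns above (accounts.avatar, apps.icon, sites.icon) are plain-text columns that may embed upload-file references, while the messages columns hold JSON. A minimal sketch of how one such descriptor could be scanned is below; the UUID pattern and the in-memory row list are assumptions for illustration, not the command's actual implementation.

```python
import json
import re

# Assumption for illustration: upload-file references appear as UUIDs in the text.
UUID_RE = re.compile(r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}")

def referenced_ids(column_spec: dict, values: list) -> set[str]:
    """Collect candidate file IDs from one table/column, honoring its declared type."""
    ids: set[str] = set()
    for value in values:
        if value is None:
            continue
        # "text" columns are scanned as-is; "json" columns are serialized first.
        text = value if column_spec["type"] == "text" else json.dumps(value)
        ids.update(UUID_RE.findall(text))
    return ids

# e.g. scanning the newly covered accounts.avatar column
print(referenced_ids(
    {"type": "text", "table": "accounts", "column": "avatar"},
    ["f2b0a9c4-1111-4222-8333-444455556666", None],
))
```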
| @@ -60,8 +60,7 @@ class NacosHttpClient: | |||
| sign_str = tenant + "+" | |||
| if group: | |||
| sign_str = sign_str + group + "+" | |||
| if sign_str: | |||
| sign_str += ts | |||
| sign_str += ts  # Always append ts without a conditional check; the Nacos auth header requires it. | |||
| return sign_str | |||
| def get_access_token(self, force_refresh=False): | |||
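For context on what consumes this sign string: Nacos' SPAS-style auth typically signs "tenant+group+timestamp" (omitting empty parts) with HMAC-SHA1 and sends the base64 digest alongside the access key and timestamp. The header names and signing scheme in the sketch below are assumptions about that convention, not taken from this diff.

```python
import base64
import hashlib
import hmac
import time

def build_sign_str(tenant: str, group: str, ts: str) -> str:
    # Mirrors the hunk above: tenant and group are optional prefixes,
    # but ts is always appended because the auth header requires it.
    sign_str = tenant + "+"
    if group:
        sign_str = sign_str + group + "+"
    sign_str += ts
    return sign_str

def spas_headers(tenant: str, group: str, access_key: str, secret_key: str) -> dict:
    ts = str(int(time.time() * 1000))
    digest = hmac.new(secret_key.encode(), build_sign_str(tenant, group, ts).encode(), hashlib.sha1).digest()
    return {
        "Spas-AccessKey": access_key,                          # assumed header name
        "timeStamp": ts,                                       # assumed header name
        "Spas-Signature": base64.b64encode(digest).decode(),   # assumed header name
    }
```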
| @@ -6,12 +6,12 @@ from sqlalchemy.orm import Session | |||
| from controllers.console import api | |||
| from controllers.console.app.wraps import get_app_model | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from core.workflow.entities.workflow_execution import WorkflowExecutionStatus | |||
| from extensions.ext_database import db | |||
| from fields.workflow_app_log_fields import workflow_app_log_pagination_fields | |||
| from libs.login import login_required | |||
| from models import App | |||
| from models.model import AppMode | |||
| from models.workflow import WorkflowRunStatus | |||
| from services.workflow_app_service import WorkflowAppService | |||
| @@ -38,7 +38,7 @@ class WorkflowAppLogApi(Resource): | |||
| parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") | |||
| args = parser.parse_args() | |||
| args.status = WorkflowRunStatus(args.status) if args.status else None | |||
| args.status = WorkflowExecutionStatus(args.status) if args.status else None | |||
| if args.created_at__before: | |||
| args.created_at__before = isoparse(args.created_at__before) | |||
| @@ -24,12 +24,13 @@ from core.errors.error import ( | |||
| QuotaExceededError, | |||
| ) | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from core.workflow.entities.workflow_execution import WorkflowExecutionStatus | |||
| from extensions.ext_database import db | |||
| from fields.workflow_app_log_fields import workflow_app_log_pagination_fields | |||
| from libs import helper | |||
| from libs.helper import TimestampField | |||
| from models.model import App, AppMode, EndUser | |||
| from models.workflow import WorkflowRun, WorkflowRunStatus | |||
| from models.workflow import WorkflowRun | |||
| from services.app_generate_service import AppGenerateService | |||
| from services.errors.llm import InvokeRateLimitError | |||
| from services.workflow_app_service import WorkflowAppService | |||
| @@ -138,7 +139,7 @@ class WorkflowAppLogApi(Resource): | |||
| parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") | |||
| args = parser.parse_args() | |||
| args.status = WorkflowRunStatus(args.status) if args.status else None | |||
| args.status = WorkflowExecutionStatus(args.status) if args.status else None | |||
| if args.created_at__before: | |||
| args.created_at__before = isoparse(args.created_at__before) | |||
| @@ -1,19 +1,21 @@ | |||
| from flask import request | |||
| from flask_restful import marshal, reqparse | |||
| from flask_restful import marshal, marshal_with, reqparse | |||
| from werkzeug.exceptions import Forbidden, NotFound | |||
| import services.dataset_service | |||
| from controllers.service_api import api | |||
| from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError | |||
| from controllers.service_api.wraps import DatasetApiResource | |||
| from controllers.service_api.wraps import DatasetApiResource, validate_dataset_token | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from core.provider_manager import ProviderManager | |||
| from fields.dataset_fields import dataset_detail_fields | |||
| from fields.tag_fields import tag_fields | |||
| from libs.login import current_user | |||
| from models.dataset import Dataset, DatasetPermissionEnum | |||
| from services.dataset_service import DatasetPermissionService, DatasetService | |||
| from services.entities.knowledge_entities.knowledge_entities import RetrievalModel | |||
| from services.tag_service import TagService | |||
| def _validate_name(name): | |||
| @@ -320,5 +322,134 @@ class DatasetApi(DatasetApiResource): | |||
| raise DatasetInUseError() | |||
| class DatasetTagsApi(DatasetApiResource): | |||
| @validate_dataset_token | |||
| @marshal_with(tag_fields) | |||
| def get(self, _, dataset_id): | |||
| """Get all knowledge type tags.""" | |||
| tags = TagService.get_tags("knowledge", current_user.current_tenant_id) | |||
| return tags, 200 | |||
| @validate_dataset_token | |||
| def post(self, _, dataset_id): | |||
| """Add a knowledge type tag.""" | |||
| if not (current_user.is_editor or current_user.is_dataset_editor): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| "name", | |||
| nullable=False, | |||
| required=True, | |||
| help="Name must be between 1 to 50 characters.", | |||
| type=DatasetTagsApi._validate_tag_name, | |||
| ) | |||
| args = parser.parse_args() | |||
| args["type"] = "knowledge" | |||
| tag = TagService.save_tags(args) | |||
| response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} | |||
| return response, 200 | |||
| @validate_dataset_token | |||
| def patch(self, _, dataset_id): | |||
| if not (current_user.is_editor or current_user.is_dataset_editor): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| "name", | |||
| nullable=False, | |||
| required=True, | |||
| help="Name must be between 1 to 50 characters.", | |||
| type=DatasetTagsApi._validate_tag_name, | |||
| ) | |||
| parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) | |||
| args = parser.parse_args() | |||
| tag = TagService.update_tags(args, args.get("tag_id")) | |||
| binding_count = TagService.get_tag_binding_count(args.get("tag_id")) | |||
| response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} | |||
| return response, 200 | |||
| @validate_dataset_token | |||
| def delete(self, _, dataset_id): | |||
| """Delete a knowledge type tag.""" | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) | |||
| args = parser.parse_args() | |||
| TagService.delete_tag(args.get("tag_id")) | |||
| return 204 | |||
| @staticmethod | |||
| def _validate_tag_name(name): | |||
| if not name or len(name) < 1 or len(name) > 50: | |||
| raise ValueError("Name must be between 1 to 50 characters.") | |||
| return name | |||
| class DatasetTagBindingApi(DatasetApiResource): | |||
| @validate_dataset_token | |||
| def post(self, _, dataset_id): | |||
| # The current user's role in the tenant must be admin, owner, editor, or dataset_operator | |||
| if not (current_user.is_editor or current_user.is_dataset_editor): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." | |||
| ) | |||
| parser.add_argument( | |||
| "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required." | |||
| ) | |||
| args = parser.parse_args() | |||
| args["type"] = "knowledge" | |||
| TagService.save_tag_binding(args) | |||
| return 204 | |||
| class DatasetTagUnbindingApi(DatasetApiResource): | |||
| @validate_dataset_token | |||
| def post(self, _, dataset_id): | |||
| # The current user's role in the tenant must be admin, owner, editor, or dataset_operator | |||
| if not (current_user.is_editor or current_user.is_dataset_editor): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") | |||
| parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") | |||
| args = parser.parse_args() | |||
| args["type"] = "knowledge" | |||
| TagService.delete_tag_binding(args) | |||
| return 204 | |||
| class DatasetTagsBindingStatusApi(DatasetApiResource): | |||
| @validate_dataset_token | |||
| def get(self, _, *args, **kwargs): | |||
| """Get all knowledge type tags.""" | |||
| dataset_id = kwargs.get("dataset_id") | |||
| tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id)) | |||
| tags_list = [{"id": tag.id, "name": tag.name} for tag in tags] | |||
| response = {"data": tags_list, "total": len(tags)} | |||
| return response, 200 | |||
| api.add_resource(DatasetListApi, "/datasets") | |||
| api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>") | |||
| api.add_resource(DatasetTagsApi, "/datasets/tags") | |||
| api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding") | |||
| api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding") | |||
| api.add_resource(DatasetTagsBindingStatusApi, "/datasets/<uuid:dataset_id>/tags") | |||
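Taken together, the hunk above adds tag CRUD, binding/unbinding, and a binding-status endpoint to the dataset service API. A hedged usage sketch follows; the base URL, the /v1 prefix, and the Bearer dataset API key header are assumptions about the service API conventions, while the paths and payload fields come from the routes registered above.

```python
import requests

BASE = "https://api.example.com/v1"                      # assumed deployment URL + prefix
HEADERS = {"Authorization": "Bearer <dataset-api-key>"}  # assumed auth scheme

# Create a knowledge-type tag (POST /datasets/tags)
tag = requests.post(f"{BASE}/datasets/tags", json={"name": "finance"}, headers=HEADERS).json()

# Bind it to a dataset (POST /datasets/tags/binding)
requests.post(
    f"{BASE}/datasets/tags/binding",
    json={"tag_ids": [tag["id"]], "target_id": "<dataset-id>"},
    headers=HEADERS,
)

# List tags bound to that dataset (GET /datasets/<dataset_id>/tags)
print(requests.get(f"{BASE}/datasets/<dataset-id>/tags", headers=HEADERS).json())
```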
| @@ -208,6 +208,28 @@ class DatasetSegmentApi(DatasetApiResource): | |||
| ) | |||
| return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 | |||
| def get(self, tenant_id, dataset_id, document_id, segment_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 | |||
| class ChildChunkApi(DatasetApiResource): | |||
| """Resource for child chunks.""" | |||
| @@ -70,7 +70,7 @@ class ModelConfigConverter: | |||
| if not model_mode: | |||
| model_mode = LLMMode.CHAT.value | |||
| if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE): | |||
| model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value | |||
| model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value | |||
| if not model_schema: | |||
| raise ValueError(f"Model {model_name} not exist.") | |||
| @@ -27,8 +27,8 @@ from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.prompt.utils.get_thread_messages_length import get_thread_messages_length | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom | |||
| @@ -140,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count, | |||
| SystemVariableKey.APP_ID: app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, | |||
| SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, | |||
| SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id, | |||
| } | |||
| # init variable pool | |||
| @@ -1,4 +1,3 @@ | |||
| import json | |||
| import logging | |||
| import time | |||
| from collections.abc import Generator, Mapping | |||
| @@ -57,26 +56,23 @@ from core.app.entities.task_entities import ( | |||
| WorkflowTaskState, | |||
| ) | |||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | |||
| from core.app.task_pipeline.message_cycle_manage import MessageCycleManage | |||
| from core.app.task_pipeline.message_cycle_manager import MessageCycleManager | |||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_cycle_manager import WorkflowCycleManager | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager | |||
| from events.message_event import message_was_created | |||
| from extensions.ext_database import db | |||
| from models import Conversation, EndUser, Message, MessageFile | |||
| from models.account import Account | |||
| from models.enums import CreatorUserRole | |||
| from models.workflow import ( | |||
| Workflow, | |||
| WorkflowRunStatus, | |||
| ) | |||
| from models.workflow import Workflow | |||
| logger = logging.getLogger(__name__) | |||
| @@ -126,8 +122,14 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| SystemVariableKey.DIALOGUE_COUNT: dialogue_count, | |||
| SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||
| SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, | |||
| SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id, | |||
| }, | |||
| workflow_info=CycleManagerWorkflowInfo( | |||
| workflow_id=workflow.id, | |||
| workflow_type=WorkflowType(workflow.type), | |||
| version=workflow.version, | |||
| graph_data=workflow.graph_dict, | |||
| ), | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| ) | |||
| @@ -137,7 +139,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| ) | |||
| self._task_state = WorkflowTaskState() | |||
| self._message_cycle_manager = MessageCycleManage( | |||
| self._message_cycle_manager = MessageCycleManager( | |||
| application_generate_entity=application_generate_entity, task_state=self._task_state | |||
| ) | |||
| @@ -158,7 +160,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| :return: | |||
| """ | |||
| # start generate conversation name thread | |||
| self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name( | |||
| self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name( | |||
| conversation_id=self._conversation_id, query=self._application_generate_entity.query | |||
| ) | |||
| @@ -302,15 +304,12 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| # init workflow run | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start( | |||
| session=session, | |||
| workflow_id=self._workflow_id, | |||
| ) | |||
| self._workflow_run_id = workflow_execution.id | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start() | |||
| self._workflow_run_id = workflow_execution.id_ | |||
| message = self._get_message(session=session) | |||
| if not message: | |||
| raise ValueError(f"Message not found: {self._message_id}") | |||
| message.workflow_run_id = workflow_execution.id | |||
| message.workflow_run_id = workflow_execution.id_ | |||
| workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution=workflow_execution, | |||
| @@ -550,7 +549,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| workflow_run_id=self._workflow_run_id, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.FAILED, | |||
| status=WorkflowExecutionStatus.FAILED, | |||
| error_message=event.error, | |||
| conversation_id=self._conversation_id, | |||
| trace_manager=trace_manager, | |||
| @@ -576,7 +575,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| workflow_run_id=self._workflow_run_id, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.STOPPED, | |||
| status=WorkflowExecutionStatus.STOPPED, | |||
| error_message=event.get_stop_reason(), | |||
| conversation_id=self._conversation_id, | |||
| trace_manager=trace_manager, | |||
| @@ -604,22 +603,18 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| yield self._message_end_to_stream_response() | |||
| break | |||
| elif isinstance(event, QueueRetrieverResourcesEvent): | |||
| self._message_cycle_manager._handle_retriever_resources(event) | |||
| self._message_cycle_manager.handle_retriever_resources(event) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| message = self._get_message(session=session) | |||
| message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| message.message_metadata = self._task_state.metadata.model_dump_json() | |||
| session.commit() | |||
| elif isinstance(event, QueueAnnotationReplyEvent): | |||
| self._message_cycle_manager._handle_annotation_reply(event) | |||
| self._message_cycle_manager.handle_annotation_reply(event) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| message = self._get_message(session=session) | |||
| message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| message.message_metadata = self._task_state.metadata.model_dump_json() | |||
| session.commit() | |||
| elif isinstance(event, QueueTextChunkEvent): | |||
| delta_text = event.text | |||
| @@ -636,12 +631,12 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| tts_publisher.publish(queue_message) | |||
| self._task_state.answer += delta_text | |||
| yield self._message_cycle_manager._message_to_stream_response( | |||
| yield self._message_cycle_manager.message_to_stream_response( | |||
| answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector | |||
| ) | |||
| elif isinstance(event, QueueMessageReplaceEvent): | |||
| # published by moderation | |||
| yield self._message_cycle_manager._message_replace_to_stream_response( | |||
| yield self._message_cycle_manager.message_replace_to_stream_response( | |||
| answer=event.text, reason=event.reason | |||
| ) | |||
| elif isinstance(event, QueueAdvancedChatMessageEndEvent): | |||
| @@ -653,7 +648,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| ) | |||
| if output_moderation_answer: | |||
| self._task_state.answer = output_moderation_answer | |||
| yield self._message_cycle_manager._message_replace_to_stream_response( | |||
| yield self._message_cycle_manager.message_replace_to_stream_response( | |||
| answer=output_moderation_answer, | |||
| reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, | |||
| ) | |||
| @@ -682,9 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| message = self._get_message(session=session) | |||
| message.answer = self._task_state.answer | |||
| message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at | |||
| message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| message.message_metadata = self._task_state.metadata.model_dump_json() | |||
| message_files = [ | |||
| MessageFile( | |||
| message_id=message.id, | |||
| @@ -712,9 +705,9 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| message.answer_price_unit = usage.completion_price_unit | |||
| message.total_price = usage.total_price | |||
| message.currency = usage.currency | |||
| self._task_state.metadata["usage"] = jsonable_encoder(usage) | |||
| self._task_state.metadata.usage = usage | |||
| else: | |||
| self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) | |||
| self._task_state.metadata.usage = LLMUsage.empty_usage() | |||
| message_was_created.send( | |||
| message, | |||
| application_generate_entity=self._application_generate_entity, | |||
| @@ -725,18 +718,16 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| Message end to stream response. | |||
| :return: | |||
| """ | |||
| extras = {} | |||
| if self._task_state.metadata: | |||
| extras["metadata"] = self._task_state.metadata.copy() | |||
| extras = self._task_state.metadata.model_dump() | |||
| if "annotation_reply" in extras["metadata"]: | |||
| del extras["metadata"]["annotation_reply"] | |||
| if self._task_state.metadata.annotation_reply: | |||
| del extras["annotation_reply"] | |||
| return MessageEndStreamResponse( | |||
| task_id=self._application_generate_entity.task_id, | |||
| id=self._message_id, | |||
| files=self._recorded_files, | |||
| metadata=extras.get("metadata", {}), | |||
| metadata=extras, | |||
| ) | |||
| def _handle_output_moderation_chunk(self, text: str) -> bool: | |||
| @@ -44,15 +44,14 @@ from core.app.entities.task_entities import ( | |||
| ) | |||
| from core.file import FILE_MODEL_IDENTITY, File | |||
| from core.tools.tool_manager import ToolManager | |||
| from core.workflow.entities.node_execution_entities import NodeExecution | |||
| from core.workflow.entities.workflow_execution_entities import WorkflowExecution | |||
| from core.workflow.entities.workflow_execution import WorkflowExecution | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.tool.entities import ToolNodeData | |||
| from models import ( | |||
| Account, | |||
| CreatorUserRole, | |||
| EndUser, | |||
| WorkflowNodeExecutionStatus, | |||
| WorkflowRun, | |||
| ) | |||
| @@ -73,11 +72,10 @@ class WorkflowResponseConverter: | |||
| ) -> WorkflowStartStreamResponse: | |||
| return WorkflowStartStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_execution.id, | |||
| workflow_run_id=workflow_execution.id_, | |||
| data=WorkflowStartStreamResponse.Data( | |||
| id=workflow_execution.id, | |||
| id=workflow_execution.id_, | |||
| workflow_id=workflow_execution.workflow_id, | |||
| sequence_number=workflow_execution.sequence_number, | |||
| inputs=workflow_execution.inputs, | |||
| created_at=int(workflow_execution.started_at.timestamp()), | |||
| ), | |||
| @@ -91,7 +89,7 @@ class WorkflowResponseConverter: | |||
| workflow_execution: WorkflowExecution, | |||
| ) -> WorkflowFinishStreamResponse: | |||
| created_by = None | |||
| workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id)) | |||
| workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_)) | |||
| assert workflow_run is not None | |||
| if workflow_run.created_by_role == CreatorUserRole.ACCOUNT: | |||
| stmt = select(Account).where(Account.id == workflow_run.created_by) | |||
| @@ -122,11 +120,10 @@ class WorkflowResponseConverter: | |||
| return WorkflowFinishStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_execution.id, | |||
| workflow_run_id=workflow_execution.id_, | |||
| data=WorkflowFinishStreamResponse.Data( | |||
| id=workflow_execution.id, | |||
| id=workflow_execution.id_, | |||
| workflow_id=workflow_execution.workflow_id, | |||
| sequence_number=workflow_execution.sequence_number, | |||
| status=workflow_execution.status, | |||
| outputs=workflow_execution.outputs, | |||
| error=workflow_execution.error_message, | |||
| @@ -146,16 +143,16 @@ class WorkflowResponseConverter: | |||
| *, | |||
| event: QueueNodeStartedEvent, | |||
| task_id: str, | |||
| workflow_node_execution: NodeExecution, | |||
| workflow_node_execution: WorkflowNodeExecution, | |||
| ) -> Optional[NodeStartStreamResponse]: | |||
| if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: | |||
| return None | |||
| if not workflow_node_execution.workflow_run_id: | |||
| if not workflow_node_execution.workflow_execution_id: | |||
| return None | |||
| response = NodeStartStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_node_execution.workflow_run_id, | |||
| workflow_run_id=workflow_node_execution.workflow_execution_id, | |||
| data=NodeStartStreamResponse.Data( | |||
| id=workflow_node_execution.id, | |||
| node_id=workflow_node_execution.node_id, | |||
| @@ -196,18 +193,18 @@ class WorkflowResponseConverter: | |||
| | QueueNodeInLoopFailedEvent | |||
| | QueueNodeExceptionEvent, | |||
| task_id: str, | |||
| workflow_node_execution: NodeExecution, | |||
| workflow_node_execution: WorkflowNodeExecution, | |||
| ) -> Optional[NodeFinishStreamResponse]: | |||
| if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: | |||
| return None | |||
| if not workflow_node_execution.workflow_run_id: | |||
| if not workflow_node_execution.workflow_execution_id: | |||
| return None | |||
| if not workflow_node_execution.finished_at: | |||
| return None | |||
| return NodeFinishStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_node_execution.workflow_run_id, | |||
| workflow_run_id=workflow_node_execution.workflow_execution_id, | |||
| data=NodeFinishStreamResponse.Data( | |||
| id=workflow_node_execution.id, | |||
| node_id=workflow_node_execution.node_id, | |||
| @@ -239,18 +236,18 @@ class WorkflowResponseConverter: | |||
| *, | |||
| event: QueueNodeRetryEvent, | |||
| task_id: str, | |||
| workflow_node_execution: NodeExecution, | |||
| workflow_node_execution: WorkflowNodeExecution, | |||
| ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: | |||
| if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: | |||
| return None | |||
| if not workflow_node_execution.workflow_run_id: | |||
| if not workflow_node_execution.workflow_execution_id: | |||
| return None | |||
| if not workflow_node_execution.finished_at: | |||
| return None | |||
| return NodeRetryStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_node_execution.workflow_run_id, | |||
| workflow_run_id=workflow_node_execution.workflow_execution_id, | |||
| data=NodeRetryStreamResponse.Data( | |||
| id=workflow_node_execution.id, | |||
| node_id=workflow_node_execution.node_id, | |||
| @@ -25,8 +25,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom | |||
| @@ -132,7 +132,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| invoke_from=invoke_from, | |||
| call_depth=call_depth, | |||
| trace_manager=trace_manager, | |||
| workflow_run_id=workflow_run_id, | |||
| workflow_execution_id=workflow_run_id, | |||
| ) | |||
| contexts.plugin_tool_providers.set({}) | |||
| @@ -279,7 +279,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( | |||
| node_id=node_id, inputs=args["inputs"] | |||
| ), | |||
| workflow_run_id=str(uuid.uuid4()), | |||
| workflow_execution_id=str(uuid.uuid4()), | |||
| ) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| @@ -355,7 +355,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| extras={"auto_generate_conversation_name": False}, | |||
| single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), | |||
| workflow_run_id=str(uuid.uuid4()), | |||
| workflow_execution_id=str(uuid.uuid4()), | |||
| ) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| @@ -95,7 +95,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| SystemVariableKey.USER_ID: user_id, | |||
| SystemVariableKey.APP_ID: app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, | |||
| SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, | |||
| SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id, | |||
| } | |||
| variable_pool = VariablePool( | |||
| @@ -50,16 +50,15 @@ from core.app.entities.task_entities import ( | |||
| WorkflowAppStreamResponse, | |||
| WorkflowFinishStreamResponse, | |||
| WorkflowStartStreamResponse, | |||
| WorkflowTaskState, | |||
| ) | |||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | |||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.entities.workflow_execution_entities import WorkflowExecution | |||
| from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_cycle_manager import WorkflowCycleManager | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.enums import CreatorUserRole | |||
| @@ -69,7 +68,6 @@ from models.workflow import ( | |||
| WorkflowAppLog, | |||
| WorkflowAppLogCreatedFrom, | |||
| WorkflowRun, | |||
| WorkflowRunStatus, | |||
| ) | |||
| logger = logging.getLogger(__name__) | |||
| @@ -114,8 +112,14 @@ class WorkflowAppGenerateTaskPipeline: | |||
| SystemVariableKey.USER_ID: user_session_id, | |||
| SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||
| SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, | |||
| SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id, | |||
| }, | |||
| workflow_info=CycleManagerWorkflowInfo( | |||
| workflow_id=workflow.id, | |||
| workflow_type=WorkflowType(workflow.type), | |||
| version=workflow.version, | |||
| graph_data=workflow.graph_dict, | |||
| ), | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| ) | |||
| @@ -125,9 +129,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| ) | |||
| self._application_generate_entity = application_generate_entity | |||
| self._workflow_id = workflow.id | |||
| self._workflow_features_dict = workflow.features_dict | |||
| self._task_state = WorkflowTaskState() | |||
| self._workflow_run_id = "" | |||
| def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: | |||
| @@ -266,17 +268,13 @@ class WorkflowAppGenerateTaskPipeline: | |||
| # override graph runtime state | |||
| graph_runtime_state = event.graph_runtime_state | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| # init workflow run | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start( | |||
| session=session, | |||
| workflow_id=self._workflow_id, | |||
| ) | |||
| self._workflow_run_id = workflow_execution.id | |||
| start_resp = self._workflow_response_converter.workflow_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| # init workflow run | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start() | |||
| self._workflow_run_id = workflow_execution.id_ | |||
| start_resp = self._workflow_response_converter.workflow_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| yield start_resp | |||
| elif isinstance( | |||
| @@ -511,9 +509,9 @@ class WorkflowAppGenerateTaskPipeline: | |||
| workflow_run_id=self._workflow_run_id, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.FAILED | |||
| status=WorkflowExecutionStatus.FAILED | |||
| if isinstance(event, QueueWorkflowFailedEvent) | |||
| else WorkflowRunStatus.STOPPED, | |||
| else WorkflowExecutionStatus.STOPPED, | |||
| error_message=event.error | |||
| if isinstance(event, QueueWorkflowFailedEvent) | |||
| else event.get_stop_reason(), | |||
| @@ -542,7 +540,6 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if tts_publisher: | |||
| tts_publisher.publish(queue_message) | |||
| self._task_state.answer += delta_text | |||
| yield self._text_chunk_to_stream_response( | |||
| delta_text, from_variable_selector=event.from_variable_selector | |||
| ) | |||
| @@ -557,7 +554,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| tts_publisher.publish(None) | |||
| def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: | |||
| workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id)) | |||
| workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_)) | |||
| assert workflow_run is not None | |||
| invoke_from = self._application_generate_entity.invoke_from | |||
| if invoke_from == InvokeFrom.SERVICE_API: | |||
| @@ -29,8 +29,8 @@ from core.app.entities.queue_entities import ( | |||
| QueueWorkflowStartedEvent, | |||
| QueueWorkflowSucceededEvent, | |||
| ) | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| AgentLogEvent, | |||
| GraphEngineEvent, | |||
| @@ -295,7 +295,7 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| inputs: Mapping[str, Any] | None = {} | |||
| process_data: Mapping[str, Any] | None = {} | |||
| outputs: Mapping[str, Any] | None = {} | |||
| execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {} | |||
| execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {} | |||
| if node_run_result: | |||
| inputs = node_run_result.inputs | |||
| process_data = node_run_result.process_data | |||
| @@ -77,6 +77,8 @@ class AppGenerateEntity(BaseModel): | |||
| App Generate Entity. | |||
| """ | |||
| model_config = ConfigDict(arbitrary_types_allowed=True) | |||
| task_id: str | |||
| # app config | |||
| @@ -100,9 +102,6 @@ class AppGenerateEntity(BaseModel): | |||
| # tracing instance | |||
| trace_manager: Optional[TraceQueueManager] = None | |||
| class Config: | |||
| arbitrary_types_allowed = True | |||
| class EasyUIBasedAppGenerateEntity(AppGenerateEntity): | |||
| """ | |||
| @@ -206,7 +205,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): | |||
| # app config | |||
| app_config: WorkflowUIBasedAppConfig | |||
| workflow_run_id: str | |||
| workflow_execution_id: str | |||
| class SingleIterationRunEntity(BaseModel): | |||
| """ | |||
| @@ -1,4 +1,4 @@ | |||
| from collections.abc import Mapping | |||
| from collections.abc import Mapping, Sequence | |||
| from datetime import datetime | |||
| from enum import Enum, StrEnum | |||
| from typing import Any, Optional | |||
| @@ -6,7 +6,9 @@ from typing import Any, Optional | |||
| from pydantic import BaseModel | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk | |||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.base import BaseNodeData | |||
| @@ -282,7 +284,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): | |||
| """ | |||
| event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES | |||
| retriever_resources: list[dict] | |||
| retriever_resources: Sequence[RetrievalSourceMetadata] | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| in_loop_id: Optional[str] = None | |||
| @@ -412,7 +414,7 @@ class QueueNodeSucceededEvent(AppQueueEvent): | |||
| inputs: Optional[Mapping[str, Any]] = None | |||
| process_data: Optional[Mapping[str, Any]] = None | |||
| outputs: Optional[Mapping[str, Any]] = None | |||
| execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None | |||
| execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None | |||
| error: Optional[str] = None | |||
| """single iteration duration map""" | |||
| @@ -446,7 +448,7 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent): | |||
| inputs: Optional[Mapping[str, Any]] = None | |||
| process_data: Optional[Mapping[str, Any]] = None | |||
| outputs: Optional[Mapping[str, Any]] = None | |||
| execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None | |||
| execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None | |||
| error: str | |||
| retry_index: int # retry index | |||
| @@ -480,7 +482,7 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent): | |||
| inputs: Optional[Mapping[str, Any]] = None | |||
| process_data: Optional[Mapping[str, Any]] = None | |||
| outputs: Optional[Mapping[str, Any]] = None | |||
| execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None | |||
| execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None | |||
| error: str | |||
| @@ -513,7 +515,7 @@ class QueueNodeInLoopFailedEvent(AppQueueEvent): | |||
| inputs: Optional[Mapping[str, Any]] = None | |||
| process_data: Optional[Mapping[str, Any]] = None | |||
| outputs: Optional[Mapping[str, Any]] = None | |||
| execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None | |||
| execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None | |||
| error: str | |||
| @@ -546,7 +548,7 @@ class QueueNodeExceptionEvent(AppQueueEvent): | |||
| inputs: Optional[Mapping[str, Any]] = None | |||
| process_data: Optional[Mapping[str, Any]] = None | |||
| outputs: Optional[Mapping[str, Any]] = None | |||
| execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None | |||
| execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None | |||
| error: str | |||
| @@ -579,7 +581,7 @@ class QueueNodeFailedEvent(AppQueueEvent): | |||
| inputs: Optional[Mapping[str, Any]] = None | |||
| process_data: Optional[Mapping[str, Any]] = None | |||
| outputs: Optional[Mapping[str, Any]] = None | |||
| execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None | |||
| execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None | |||
| error: str | |||
| @@ -2,12 +2,29 @@ from collections.abc import Mapping, Sequence | |||
| from enum import Enum | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel, ConfigDict | |||
| from pydantic import BaseModel, ConfigDict, Field | |||
| from core.model_runtime.entities.llm_entities import LLMResult | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| class AnnotationReplyAccount(BaseModel): | |||
| id: str | |||
| name: str | |||
| class AnnotationReply(BaseModel): | |||
| id: str | |||
| account: AnnotationReplyAccount | |||
| class TaskStateMetadata(BaseModel): | |||
| annotation_reply: AnnotationReply | None = None | |||
| retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list) | |||
| usage: LLMUsage | None = None | |||
| class TaskState(BaseModel): | |||
| @@ -15,7 +32,7 @@ class TaskState(BaseModel): | |||
| TaskState entity | |||
| """ | |||
| metadata: dict = {} | |||
| metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata) | |||
| class EasyUITaskState(TaskState): | |||
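The TaskStateMetadata model introduced above is what lets the pipelines replace json.dumps(jsonable_encoder(...)) with model_dump_json() and attribute access. A simplified, self-contained sketch is below; usage and retriever_resources are plain stand-ins here, not the real LLMUsage/RetrievalSourceMetadata types.

```python
from pydantic import BaseModel, Field

class AnnotationReplyAccount(BaseModel):
    id: str
    name: str

class AnnotationReply(BaseModel):
    id: str
    account: AnnotationReplyAccount

class TaskStateMetadata(BaseModel):
    annotation_reply: AnnotationReply | None = None
    retriever_resources: list[dict] = Field(default_factory=list)  # stand-in for RetrievalSourceMetadata
    usage: dict | None = None                                      # stand-in for LLMUsage

meta = TaskStateMetadata()
meta.annotation_reply = AnnotationReply(
    id="a-1", account=AnnotationReplyAccount(id="u-1", name="Dify user")
)
print(meta.model_dump_json())  # what message.message_metadata now stores
print(meta.model_dump())       # dict form handed to the stream responses
```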
| @@ -189,7 +206,6 @@ class WorkflowStartStreamResponse(StreamResponse): | |||
| id: str | |||
| workflow_id: str | |||
| sequence_number: int | |||
| inputs: Mapping[str, Any] | |||
| created_at: int | |||
| @@ -210,7 +226,6 @@ class WorkflowFinishStreamResponse(StreamResponse): | |||
| id: str | |||
| workflow_id: str | |||
| sequence_number: int | |||
| status: str | |||
| outputs: Optional[Mapping[str, Any]] = None | |||
| error: Optional[str] = None | |||
| @@ -307,7 +322,7 @@ class NodeFinishStreamResponse(StreamResponse): | |||
| status: str | |||
| error: Optional[str] = None | |||
| elapsed_time: float | |||
| execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None | |||
| execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None | |||
| created_at: int | |||
| finished_at: int | |||
| files: Optional[Sequence[Mapping[str, Any]]] = [] | |||
| @@ -376,7 +391,7 @@ class NodeRetryStreamResponse(StreamResponse): | |||
| status: str | |||
| error: Optional[str] = None | |||
| elapsed_time: float | |||
| execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None | |||
| execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None | |||
| created_at: int | |||
| finished_at: int | |||
| files: Optional[Sequence[Mapping[str, Any]]] = [] | |||
| @@ -1,4 +1,3 @@ | |||
| import json | |||
| import logging | |||
| import time | |||
| from collections.abc import Generator | |||
| @@ -43,7 +42,7 @@ from core.app.entities.task_entities import ( | |||
| StreamResponse, | |||
| ) | |||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | |||
| from core.app.task_pipeline.message_cycle_manage import MessageCycleManage | |||
| from core.app.task_pipeline.message_cycle_manager import MessageCycleManager | |||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage | |||
| @@ -51,7 +50,6 @@ from core.model_runtime.entities.message_entities import ( | |||
| AssistantPromptMessage, | |||
| ) | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.ops.entities.trace_entity import TraceTaskName | |||
| from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | |||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| @@ -63,7 +61,7 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought | |||
| logger = logging.getLogger(__name__) | |||
| class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage): | |||
| class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| """ | |||
| EasyUIBasedGenerateTaskPipeline is a class that generates stream output and manages state for the application. | |||
| """ | |||
| @@ -104,6 +102,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| ) | |||
| ) | |||
| self._message_cycle_manager = MessageCycleManager( | |||
| application_generate_entity=application_generate_entity, | |||
| task_state=self._task_state, | |||
| ) | |||
| self._conversation_name_generate_thread: Optional[Thread] = None | |||
| def process( | |||
| @@ -115,7 +118,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| ]: | |||
| if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: | |||
| # start generate conversation name thread | |||
| self._conversation_name_generate_thread = self._generate_conversation_name( | |||
| self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name( | |||
| conversation_id=self._conversation_id, query=self._application_generate_entity.query or "" | |||
| ) | |||
| @@ -136,9 +139,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| if isinstance(stream_response, ErrorStreamResponse): | |||
| raise stream_response.err | |||
| elif isinstance(stream_response, MessageEndStreamResponse): | |||
| extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} | |||
| extras = {"usage": self._task_state.llm_result.usage.model_dump()} | |||
| if self._task_state.metadata: | |||
| extras["metadata"] = self._task_state.metadata | |||
| extras["metadata"] = self._task_state.metadata.model_dump() | |||
| response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] | |||
| if self._conversation_mode == AppMode.COMPLETION.value: | |||
| response = CompletionAppBlockingResponse( | |||
| @@ -277,7 +280,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| ) | |||
| if output_moderation_answer: | |||
| self._task_state.llm_result.message.content = output_moderation_answer | |||
| yield self._message_replace_to_stream_response(answer=output_moderation_answer) | |||
| yield self._message_cycle_manager.message_replace_to_stream_response( | |||
| answer=output_moderation_answer | |||
| ) | |||
| with Session(db.engine) as session: | |||
| # Save message | |||
| @@ -286,9 +291,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| message_end_resp = self._message_end_to_stream_response() | |||
| yield message_end_resp | |||
| elif isinstance(event, QueueRetrieverResourcesEvent): | |||
| self._handle_retriever_resources(event) | |||
| self._message_cycle_manager.handle_retriever_resources(event) | |||
| elif isinstance(event, QueueAnnotationReplyEvent): | |||
| annotation = self._handle_annotation_reply(event) | |||
| annotation = self._message_cycle_manager.handle_annotation_reply(event) | |||
| if annotation: | |||
| self._task_state.llm_result.message.content = annotation.content | |||
| elif isinstance(event, QueueAgentThoughtEvent): | |||
| @@ -296,7 +301,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| if agent_thought_response is not None: | |||
| yield agent_thought_response | |||
| elif isinstance(event, QueueMessageFileEvent): | |||
| response = self._message_file_to_stream_response(event) | |||
| response = self._message_cycle_manager.message_file_to_stream_response(event) | |||
| if response: | |||
| yield response | |||
| elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent): | |||
| @@ -318,7 +323,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| self._task_state.llm_result.message.content = current_content | |||
| if isinstance(event, QueueLLMChunkEvent): | |||
| yield self._message_to_stream_response( | |||
| yield self._message_cycle_manager.message_to_stream_response( | |||
| answer=cast(str, delta_text), | |||
| message_id=self._message_id, | |||
| ) | |||
| @@ -328,7 +333,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| message_id=self._message_id, | |||
| ) | |||
| elif isinstance(event, QueueMessageReplaceEvent): | |||
| yield self._message_replace_to_stream_response(answer=event.text) | |||
| yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text) | |||
| elif isinstance(event, QueuePingEvent): | |||
| yield self._ping_stream_response() | |||
| else: | |||
| @@ -372,9 +377,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| message.provider_response_latency = time.perf_counter() - self._start_at | |||
| message.total_price = usage.total_price | |||
| message.currency = usage.currency | |||
| message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| message.message_metadata = self._task_state.metadata.model_dump_json() | |||
| if trace_manager: | |||
| trace_manager.add_trace_task( | |||
| @@ -423,16 +426,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| Message end to stream response. | |||
| :return: | |||
| """ | |||
| self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage) | |||
| extras = {} | |||
| if self._task_state.metadata: | |||
| extras["metadata"] = self._task_state.metadata | |||
| self._task_state.metadata.usage = self._task_state.llm_result.usage | |||
| metadata_dict = self._task_state.metadata.model_dump() | |||
| return MessageEndStreamResponse( | |||
| task_id=self._application_generate_entity.task_id, | |||
| id=self._message_id, | |||
| metadata=extras.get("metadata", {}), | |||
| metadata=metadata_dict, | |||
| ) | |||
| def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: | |||
| @@ -17,6 +17,8 @@ from core.app.entities.queue_entities import ( | |||
| QueueRetrieverResourcesEvent, | |||
| ) | |||
| from core.app.entities.task_entities import ( | |||
| AnnotationReply, | |||
| AnnotationReplyAccount, | |||
| EasyUITaskState, | |||
| MessageFileStreamResponse, | |||
| MessageReplaceStreamResponse, | |||
| @@ -30,7 +32,7 @@ from models.model import AppMode, Conversation, MessageAnnotation, MessageFile | |||
| from services.annotation_service import AppAnnotationService | |||
| class MessageCycleManage: | |||
| class MessageCycleManager: | |||
| def __init__( | |||
| self, | |||
| *, | |||
| @@ -45,7 +47,7 @@ class MessageCycleManage: | |||
| self._application_generate_entity = application_generate_entity | |||
| self._task_state = task_state | |||
| def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: | |||
| def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: | |||
| """ | |||
| Generate conversation name. | |||
| :param conversation_id: conversation id | |||
| @@ -102,7 +104,7 @@ class MessageCycleManage: | |||
| db.session.commit() | |||
| db.session.close() | |||
| def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: | |||
| def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: | |||
| """ | |||
| Handle annotation reply. | |||
| :param event: event | |||
| @@ -111,25 +113,28 @@ class MessageCycleManage: | |||
| annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) | |||
| if annotation: | |||
| account = annotation.account | |||
| self._task_state.metadata["annotation_reply"] = { | |||
| "id": annotation.id, | |||
| "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"}, | |||
| } | |||
| self._task_state.metadata.annotation_reply = AnnotationReply( | |||
| id=annotation.id, | |||
| account=AnnotationReplyAccount( | |||
| id=annotation.account_id, | |||
| name=account.name if account else "Dify user", | |||
| ), | |||
| ) | |||
| return annotation | |||
| return None | |||
| def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: | |||
| def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: | |||
| """ | |||
| Handle retriever resources. | |||
| :param event: event | |||
| :return: | |||
| """ | |||
| if self._application_generate_entity.app_config.additional_features.show_retrieve_source: | |||
| self._task_state.metadata["retriever_resources"] = event.retriever_resources | |||
| self._task_state.metadata.retriever_resources = event.retriever_resources | |||
| def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: | |||
| def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: | |||
| """ | |||
| Message file to stream response. | |||
| :param event: event | |||
| @@ -166,7 +171,7 @@ class MessageCycleManage: | |||
| return None | |||
| def _message_to_stream_response( | |||
| def message_to_stream_response( | |||
| self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None | |||
| ) -> MessageStreamResponse: | |||
| """ | |||
| @@ -182,7 +187,7 @@ class MessageCycleManage: | |||
| from_variable_selector=from_variable_selector, | |||
| ) | |||
| def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: | |||
| def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: | |||
| """ | |||
| Message replace to stream response. | |||
| :param answer: answer | |||
| @@ -1,8 +1,10 @@ | |||
| import logging | |||
| from collections.abc import Sequence | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.app.entities.queue_entities import QueueRetrieverResourcesEvent | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_database import db | |||
| @@ -85,7 +87,8 @@ class DatasetIndexToolCallbackHandler: | |||
| db.session.commit() | |||
| def return_retriever_resource_info(self, resource: list): | |||
| # TODO(-LAN-): Improve type check | |||
| def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]): | |||
| """Handle return_retriever_resource_info.""" | |||
| self._queue_manager.publish( | |||
| QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER | |||
| @@ -15,6 +15,7 @@ from core.helper.code_executor.python3.python3_transformer import Python3Templat | |||
| from core.helper.code_executor.template_transformer import TemplateTransformer | |||
| logger = logging.getLogger(__name__) | |||
| code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) | |||
| class CodeExecutionError(Exception): | |||
| @@ -64,7 +65,7 @@ class CodeExecutor: | |||
| :param code: code | |||
| :return: | |||
| """ | |||
| url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run" | |||
| url = code_execution_endpoint_url / "v1" / "sandbox" / "run" | |||
| headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY} | |||
| @@ -7,29 +7,28 @@ from configs import dify_config | |||
| from core.helper.download import download_with_size_limit | |||
| from core.plugin.entities.marketplace import MarketplacePluginDeclaration | |||
| marketplace_api_url = URL(str(dify_config.MARKETPLACE_API_URL)) | |||
| def get_plugin_pkg_url(plugin_unique_identifier: str): | |||
| return (URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/download").with_query( | |||
| unique_identifier=plugin_unique_identifier | |||
| ) | |||
| def get_plugin_pkg_url(plugin_unique_identifier: str) -> str: | |||
| return str((marketplace_api_url / "api/v1/plugins/download").with_query(unique_identifier=plugin_unique_identifier)) | |||
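A small sketch of how the module-level URL composition above behaves, assuming URL here is yarl.URL and using a hypothetical base address:

from yarl import URL

base = URL("https://marketplace.example.com")  # hypothetical MARKETPLACE_API_URL
url = (base / "api/v1/plugins/download").with_query(unique_identifier="plugin-abc")
print(str(url))
# https://marketplace.example.com/api/v1/plugins/download?unique_identifier=plugin-abc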
| def download_plugin_pkg(plugin_unique_identifier: str): | |||
| url = str(get_plugin_pkg_url(plugin_unique_identifier)) | |||
| return download_with_size_limit(url, dify_config.PLUGIN_MAX_PACKAGE_SIZE) | |||
| return download_with_size_limit(get_plugin_pkg_url(plugin_unique_identifier), dify_config.PLUGIN_MAX_PACKAGE_SIZE) | |||
| def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]: | |||
| if len(plugin_ids) == 0: | |||
| return [] | |||
| url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/batch") | |||
| url = str(marketplace_api_url / "api/v1/plugins/batch") | |||
| response = requests.post(url, json={"plugin_ids": plugin_ids}) | |||
| response.raise_for_status() | |||
| return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]] | |||
| def record_install_plugin_event(plugin_unique_identifier: str): | |||
| url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/stats/plugins/install_count") | |||
| url = str(marketplace_api_url / "api/v1/stats/plugins/install_count") | |||
| response = requests.post(url, json={"unique_identifier": plugin_unique_identifier}) | |||
| response.raise_for_status() | |||
| @@ -1,61 +1,20 @@ | |||
| # Written by YORKI MINAKO🤡, Edited by Xiaoyi | |||
| CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is. | |||
| Notice: the language type user uses could be diverse, which can be English, Chinese, Italian, Español, Arabic, Japanese, French, and etc. | |||
| ENSURE your output is in the SAME language as the user's input! | |||
| Your output is restricted only to: (Input language) Intention + Subject(short as possible) | |||
| Your output MUST be a valid JSON. | |||
| # Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh | |||
| CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”. | |||
| Tip: When the user's question is directed at you (the language model), you can add an emoji to make it more fun. | |||
| 1. Detect Input Language | |||
| Automatically identify the language of the user’s input (e.g. English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.). | |||
| 2. Generate Title | |||
| - Combine Intention + Subject into a single, as-short-as-possible phrase. | |||
| - The title must be natural, friendly, and in the same language as the input. | |||
| - If the input is a direct question to the model, you may add an emoji at the end. | |||
| example 1: | |||
| User Input: hi, yesterday i had some burgers. | |||
| 3. Output Format | |||
| Return **only** a valid JSON object with these exact keys and no additional text: | |||
| { | |||
| "Language Type": "The user's input is pure English", | |||
| "Your Reasoning": "The language of my output must be pure English.", | |||
| "Your Output": "sharing yesterday's food" | |||
| } | |||
| example 2: | |||
| User Input: hello | |||
| { | |||
| "Language Type": "The user's input is pure English", | |||
| "Your Reasoning": "The language of my output must be pure English.", | |||
| "Your Output": "Greeting myself☺️" | |||
| } | |||
| example 3: | |||
| User Input: why mmap file: oom | |||
| { | |||
| "Language Type": "The user's input is written in pure English", | |||
| "Your Reasoning": "The language of my output must be pure English.", | |||
| "Your Output": "Asking about the reason for mmap file: oom" | |||
| } | |||
| example 4: | |||
| User Input: www.convinceme.yesterday-you-ate-seafood.tv讲了什么? | |||
| { | |||
| "Language Type": "The user's input English-Chinese mixed", | |||
| "Your Reasoning": "The English-part is an URL, the main intention is still written in Chinese, so the language of my output must be using Chinese.", | |||
| "Your Output": "询问网站www.convinceme.yesterday-you-ate-seafood.tv" | |||
| } | |||
| example 5: | |||
| User Input: why小红的年龄is老than小明? | |||
| { | |||
| "Language Type": "The user's input is English-Chinese mixed", | |||
| "Your Reasoning": "The English parts are filler words, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.", | |||
| "Your Output": "询问小红和小明的年龄" | |||
| } | |||
| example 6: | |||
| User Input: yo, 你今天咋样? | |||
| { | |||
| "Language Type": "The user's input is English-Chinese mixed", | |||
| "Your Reasoning": "The English-part is a subjective particle, the main intention is written in Chinese, so the language of my output must be using Chinese.", | |||
| "Your Output": "查询今日我的状态☺️" | |||
| "Language Type": "<Detected language>", | |||
| "Your Reasoning": "<Brief explanation in that language>", | |||
| "Your Output": "<Intention + Subject>" | |||
| } | |||
| User Input: | |||
| @@ -17,19 +17,6 @@ class LLMMode(StrEnum): | |||
| COMPLETION = "completion" | |||
| CHAT = "chat" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> "LLMMode": | |||
| """ | |||
| Get value of given mode. | |||
| :param value: mode value | |||
| :return: mode | |||
| """ | |||
| for mode in cls: | |||
| if mode.value == value: | |||
| return mode | |||
| raise ValueError(f"invalid mode value {value}") | |||
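The removed value_of helper duplicates what StrEnum already provides: calling the enum class with a value performs the same lookup and raises ValueError for unknown values.

from enum import StrEnum

class LLMMode(StrEnum):
    COMPLETION = "completion"
    CHAT = "chat"

assert LLMMode("chat") is LLMMode.CHAT
# LLMMode("invalid")  # would raise ValueError, matching the removed helper's behaviour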
| class LLMUsage(ModelUsage): | |||
| """ | |||
| @@ -129,17 +129,18 @@ def jsonable_encoder( | |||
| sqlalchemy_safe=sqlalchemy_safe, | |||
| ) | |||
| if dataclasses.is_dataclass(obj): | |||
| # FIXME: mypy error, try to fix it instead of using type: ignore | |||
| obj_dict = dataclasses.asdict(obj) # type: ignore | |||
| return jsonable_encoder( | |||
| obj_dict, | |||
| by_alias=by_alias, | |||
| exclude_unset=exclude_unset, | |||
| exclude_defaults=exclude_defaults, | |||
| exclude_none=exclude_none, | |||
| custom_encoder=custom_encoder, | |||
| sqlalchemy_safe=sqlalchemy_safe, | |||
| ) | |||
| # Ensure obj is a dataclass instance, not a dataclass type | |||
| if not isinstance(obj, type): | |||
| obj_dict = dataclasses.asdict(obj) | |||
| return jsonable_encoder( | |||
| obj_dict, | |||
| by_alias=by_alias, | |||
| exclude_unset=exclude_unset, | |||
| exclude_defaults=exclude_defaults, | |||
| exclude_none=exclude_none, | |||
| custom_encoder=custom_encoder, | |||
| sqlalchemy_safe=sqlalchemy_safe, | |||
| ) | |||
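The new isinstance(obj, type) guard matters because dataclasses.is_dataclass returns True for both a dataclass type and its instances, while dataclasses.asdict accepts only instances; a quick illustration:

import dataclasses

@dataclasses.dataclass
class Point:
    x: int
    y: int

assert dataclasses.is_dataclass(Point) and dataclasses.is_dataclass(Point(1, 2))
print(dataclasses.asdict(Point(1, 2)))  # {'x': 1, 'y': 2}
# dataclasses.asdict(Point)  # TypeError: asdict() should be called on dataclass instances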
| if isinstance(obj, Enum): | |||
| return obj.value | |||
| if isinstance(obj, PurePath): | |||
| @@ -1,7 +1,11 @@ | |||
| from abc import ABC, abstractmethod | |||
| from sqlalchemy.orm import Session | |||
| from core.ops.entities.config_entity import BaseTracingConfig | |||
| from core.ops.entities.trace_entity import BaseTraceInfo | |||
| from extensions.ext_database import db | |||
| from models import Account, App, TenantAccountJoin | |||
| class BaseTraceInstance(ABC): | |||
| @@ -24,3 +28,38 @@ class BaseTraceInstance(ABC): | |||
| Subclasses must implement specific tracing logic for activities. | |||
| """ | |||
| ... | |||
| def get_service_account_with_tenant(self, app_id: str) -> Account: | |||
| """ | |||
| Get service account for an app and set up its tenant. | |||
| Args: | |||
| app_id: The ID of the app | |||
| Returns: | |||
| Account: The service account with tenant set up | |||
| Raises: | |||
| ValueError: If app, creator account or tenant cannot be found | |||
| """ | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| # Get the app to find its creator | |||
| app = session.query(App).filter(App.id == app_id).first() | |||
| if not app: | |||
| raise ValueError(f"App with id {app_id} not found") | |||
| if not app.created_by: | |||
| raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | |||
| service_account = session.query(Account).filter(Account.id == app.created_by).first() | |||
| if not service_account: | |||
| raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | |||
| current_tenant = ( | |||
| session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first() | |||
| ) | |||
| if not current_tenant: | |||
| raise ValueError(f"Current tenant not found for account {service_account.id}") | |||
| service_account.set_tenant_id(current_tenant.tenant_id) | |||
| return service_account | |||
| @@ -3,7 +3,7 @@ from datetime import datetime | |||
| from enum import StrEnum | |||
| from typing import Any, Optional, Union | |||
| from pydantic import BaseModel, ConfigDict, field_validator | |||
| from pydantic import BaseModel, ConfigDict, field_serializer, field_validator | |||
| class BaseTraceInfo(BaseModel): | |||
| @@ -24,10 +24,13 @@ class BaseTraceInfo(BaseModel): | |||
| return v | |||
| return "" | |||
| class Config: | |||
| json_encoders = { | |||
| datetime: lambda v: v.isoformat(), | |||
| } | |||
| model_config = ConfigDict(protected_namespaces=()) | |||
| @field_serializer("start_time", "end_time") | |||
| def serialize_datetime(self, dt: datetime | None) -> str | None: | |||
| if dt is None: | |||
| return None | |||
| return dt.isoformat() | |||
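The change swaps the Pydantic v1 json_encoders Config for a v2 field_serializer while keeping ISO 8601 output; a self-contained sketch of the pattern (the model name here is a stand-in):

from datetime import datetime
from typing import Optional

from pydantic import BaseModel, field_serializer

class TraceTimes(BaseModel):  # stand-in for BaseTraceInfo
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None

    @field_serializer("start_time", "end_time")
    def serialize_datetime(self, dt: Optional[datetime]) -> Optional[str]:
        return dt.isoformat() if dt else None

print(TraceTimes(start_time=datetime(2024, 1, 1, 12, 0)).model_dump_json())
# {"start_time":"2024-01-01T12:00:00","end_time":null}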
| class WorkflowTraceInfo(BaseTraceInfo): | |||
| @@ -4,7 +4,7 @@ from datetime import datetime, timedelta | |||
| from typing import Optional | |||
| from langfuse import Langfuse # type: ignore | |||
| from sqlalchemy.orm import Session, sessionmaker | |||
| from sqlalchemy.orm import sessionmaker | |||
| from core.ops.base_trace_instance import BaseTraceInstance | |||
| from core.ops.entities.config_entity import LangfuseConfig | |||
| @@ -31,7 +31,7 @@ from core.ops.utils import filter_none_values | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.workflow.nodes.enums import NodeType | |||
| from extensions.ext_database import db | |||
| from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom | |||
| from models import EndUser, WorkflowNodeExecutionTriggeredFrom | |||
| logger = logging.getLogger(__name__) | |||
| @@ -114,22 +114,11 @@ class LangFuseDataTrace(BaseTraceInstance): | |||
| # through workflow_run_id get all_nodes_execution using repository | |||
| session_factory = sessionmaker(bind=db.engine) | |||
| # Find the app's creator account | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| # Get the app to find its creator | |||
| app_id = trace_info.metadata.get("app_id") | |||
| if not app_id: | |||
| raise ValueError("No app_id found in trace_info metadata") | |||
| app = session.query(App).filter(App.id == app_id).first() | |||
| if not app: | |||
| raise ValueError(f"App with id {app_id} not found") | |||
| if not app.created_by: | |||
| raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | |||
| service_account = session.query(Account).filter(Account.id == app.created_by).first() | |||
| if not service_account: | |||
| raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | |||
| app_id = trace_info.metadata.get("app_id") | |||
| if not app_id: | |||
| raise ValueError("No app_id found in trace_info metadata") | |||
| service_account = self.get_service_account_with_tenant(app_id) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=session_factory, | |||
| @@ -6,7 +6,7 @@ from typing import Optional, cast | |||
| from langsmith import Client | |||
| from langsmith.schemas import RunBase | |||
| from sqlalchemy.orm import Session, sessionmaker | |||
| from sqlalchemy.orm import sessionmaker | |||
| from core.ops.base_trace_instance import BaseTraceInstance | |||
| from core.ops.entities.config_entity import LangSmithConfig | |||
| @@ -28,10 +28,10 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( | |||
| ) | |||
| from core.ops.utils import filter_none_values, generate_dotted_order | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | |||
| from core.workflow.nodes.enums import NodeType | |||
| from extensions.ext_database import db | |||
| from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom | |||
| from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom | |||
| logger = logging.getLogger(__name__) | |||
| @@ -139,22 +139,11 @@ class LangSmithDataTrace(BaseTraceInstance): | |||
| # through workflow_run_id get all_nodes_execution using repository | |||
| session_factory = sessionmaker(bind=db.engine) | |||
| # Find the app's creator account | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| # Get the app to find its creator | |||
| app_id = trace_info.metadata.get("app_id") | |||
| if not app_id: | |||
| raise ValueError("No app_id found in trace_info metadata") | |||
| app_id = trace_info.metadata.get("app_id") | |||
| if not app_id: | |||
| raise ValueError("No app_id found in trace_info metadata") | |||
| app = session.query(App).filter(App.id == app_id).first() | |||
| if not app: | |||
| raise ValueError(f"App with id {app_id} not found") | |||
| if not app.created_by: | |||
| raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | |||
| service_account = session.query(Account).filter(Account.id == app.created_by).first() | |||
| if not service_account: | |||
| raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | |||
| service_account = self.get_service_account_with_tenant(app_id) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=session_factory, | |||
| @@ -185,7 +174,7 @@ class LangSmithDataTrace(BaseTraceInstance): | |||
| finished_at = created_at + timedelta(seconds=elapsed_time) | |||
| execution_metadata = node_execution.metadata if node_execution.metadata else {} | |||
| node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0 | |||
| node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 | |||
| metadata = {str(key): value for key, value in execution_metadata.items()} | |||
| metadata.update( | |||
| { | |||
| @@ -6,7 +6,7 @@ from typing import Optional, cast | |||
| from opik import Opik, Trace | |||
| from opik.id_helpers import uuid4_to_uuid7 | |||
| from sqlalchemy.orm import Session, sessionmaker | |||
| from sqlalchemy.orm import sessionmaker | |||
| from core.ops.base_trace_instance import BaseTraceInstance | |||
| from core.ops.entities.config_entity import OpikConfig | |||
| @@ -22,10 +22,10 @@ from core.ops.entities.trace_entity import ( | |||
| WorkflowTraceInfo, | |||
| ) | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | |||
| from core.workflow.nodes.enums import NodeType | |||
| from extensions.ext_database import db | |||
| from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom | |||
| from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom | |||
| logger = logging.getLogger(__name__) | |||
| @@ -154,22 +154,11 @@ class OpikDataTrace(BaseTraceInstance): | |||
| # through workflow_run_id get all_nodes_execution using repository | |||
| session_factory = sessionmaker(bind=db.engine) | |||
| # Find the app's creator account | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| # Get the app to find its creator | |||
| app_id = trace_info.metadata.get("app_id") | |||
| if not app_id: | |||
| raise ValueError("No app_id found in trace_info metadata") | |||
| app_id = trace_info.metadata.get("app_id") | |||
| if not app_id: | |||
| raise ValueError("No app_id found in trace_info metadata") | |||
| app = session.query(App).filter(App.id == app_id).first() | |||
| if not app: | |||
| raise ValueError(f"App with id {app_id} not found") | |||
| if not app.created_by: | |||
| raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | |||
| service_account = session.query(Account).filter(Account.id == app.created_by).first() | |||
| if not service_account: | |||
| raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | |||
| service_account = self.get_service_account_with_tenant(app_id) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=session_factory, | |||
| @@ -246,7 +235,7 @@ class OpikDataTrace(BaseTraceInstance): | |||
| parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id | |||
| if not total_tokens: | |||
| total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0 | |||
| total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 | |||
| span_data = { | |||
| "trace_id": opik_trace_id, | |||
| @@ -30,7 +30,7 @@ from core.ops.entities.trace_entity import ( | |||
| WorkflowTraceInfo, | |||
| ) | |||
| from core.ops.utils import get_message_data | |||
| from core.workflow.entities.workflow_execution_entities import WorkflowExecution | |||
| from core.workflow.entities.workflow_execution import WorkflowExecution | |||
| from extensions.ext_database import db | |||
| from extensions.ext_storage import storage | |||
| from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig | |||
| @@ -386,7 +386,7 @@ class TraceTask: | |||
| ): | |||
| self.trace_type = trace_type | |||
| self.message_id = message_id | |||
| self.workflow_run_id = workflow_execution.id if workflow_execution else None | |||
| self.workflow_run_id = workflow_execution.id_ if workflow_execution else None | |||
| self.conversation_id = conversation_id | |||
| self.user_id = user_id | |||
| self.timer = timer | |||
| @@ -487,6 +487,7 @@ class TraceTask: | |||
| "file_list": file_list, | |||
| "triggered_from": workflow_run.triggered_from, | |||
| "user_id": user_id, | |||
| "app_id": workflow_run.app_id, | |||
| } | |||
| workflow_trace_info = WorkflowTraceInfo( | |||
| @@ -6,7 +6,7 @@ from typing import Any, Optional, cast | |||
| import wandb | |||
| import weave | |||
| from sqlalchemy.orm import Session, sessionmaker | |||
| from sqlalchemy.orm import sessionmaker | |||
| from core.ops.base_trace_instance import BaseTraceInstance | |||
| from core.ops.entities.config_entity import WeaveConfig | |||
| @@ -23,10 +23,10 @@ from core.ops.entities.trace_entity import ( | |||
| ) | |||
| from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | |||
| from core.workflow.nodes.enums import NodeType | |||
| from extensions.ext_database import db | |||
| from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom | |||
| from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom | |||
| logger = logging.getLogger(__name__) | |||
| @@ -133,22 +133,11 @@ class WeaveDataTrace(BaseTraceInstance): | |||
| # through workflow_run_id get all_nodes_execution using repository | |||
| session_factory = sessionmaker(bind=db.engine) | |||
| # Find the app's creator account | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| # Get the app to find its creator | |||
| app_id = trace_info.metadata.get("app_id") | |||
| if not app_id: | |||
| raise ValueError("No app_id found in trace_info metadata") | |||
| app_id = trace_info.metadata.get("app_id") | |||
| if not app_id: | |||
| raise ValueError("No app_id found in trace_info metadata") | |||
| app = session.query(App).filter(App.id == app_id).first() | |||
| if not app: | |||
| raise ValueError(f"App with id {app_id} not found") | |||
| if not app.created_by: | |||
| raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | |||
| service_account = session.query(Account).filter(Account.id == app.created_by).first() | |||
| if not service_account: | |||
| raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | |||
| service_account = self.get_service_account_with_tenant(app_id) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=session_factory, | |||
| @@ -179,7 +168,7 @@ class WeaveDataTrace(BaseTraceInstance): | |||
| finished_at = created_at + timedelta(seconds=elapsed_time) | |||
| execution_metadata = node_execution.metadata if node_execution.metadata else {} | |||
| node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0 | |||
| node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 | |||
| attributes = {str(k): v for k, v in execution_metadata.items()} | |||
| attributes.update( | |||
| { | |||
| @@ -31,8 +31,7 @@ from core.plugin.impl.exc import ( | |||
| PluginUniqueIdentifierError, | |||
| ) | |||
| plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_DAEMON_URL | |||
| plugin_daemon_inner_api_key = dify_config.PLUGIN_DAEMON_KEY | |||
| plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) | |||
| T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) | |||
| @@ -53,9 +52,9 @@ class BasePluginClient: | |||
| """ | |||
| Make a request to the plugin daemon inner API. | |||
| """ | |||
| url = URL(str(plugin_daemon_inner_api_baseurl)) / path | |||
| url = plugin_daemon_inner_api_baseurl / path | |||
| headers = headers or {} | |||
| headers["X-Api-Key"] = plugin_daemon_inner_api_key | |||
| headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY | |||
| headers["Accept-Encoding"] = "gzip, deflate, br" | |||
| if headers.get("Content-Type") == "application/json" and isinstance(data, dict): | |||
| @@ -85,7 +85,6 @@ class BaiduVector(BaseVector): | |||
| end = min(start + batch_size, total_count) | |||
| rows = [] | |||
| assert len(metadatas) == total_count, "metadatas length should be equal to total_count" | |||
| # FIXME do you need this assert? | |||
| for i in range(start, end, 1): | |||
| row = Row( | |||
| id=metadatas[i].get("doc_id", str(uuid.uuid4())), | |||
| @@ -142,7 +142,7 @@ class ElasticSearchVector(BaseVector): | |||
| if score > score_threshold: | |||
| if doc.metadata is not None: | |||
| doc.metadata["score"] = score | |||
| docs.append(doc) | |||
| docs.append(doc) | |||
| return docs | |||
| @@ -97,6 +97,10 @@ class MilvusVector(BaseVector): | |||
| try: | |||
| milvus_version = self._client.get_server_version() | |||
| # Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility | |||
| if "Zilliz Cloud" in milvus_version: | |||
| return True | |||
| # For standard Milvus installations, check version number | |||
| return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version | |||
| except Exception as e: | |||
| logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.") | |||
| @@ -245,4 +245,4 @@ class TidbService: | |||
| return cluster_infos | |||
| else: | |||
| response.raise_for_status() | |||
| return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception | |||
| return [] | |||
| @@ -0,0 +1,23 @@ | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel | |||
| class RetrievalSourceMetadata(BaseModel): | |||
| position: Optional[int] = None | |||
| dataset_id: Optional[str] = None | |||
| dataset_name: Optional[str] = None | |||
| document_id: Optional[str] = None | |||
| document_name: Optional[str] = None | |||
| data_source_type: Optional[str] = None | |||
| segment_id: Optional[str] = None | |||
| retriever_from: Optional[str] = None | |||
| score: Optional[float] = None | |||
| hit_count: Optional[int] = None | |||
| word_count: Optional[int] = None | |||
| segment_position: Optional[int] = None | |||
| index_node_hash: Optional[str] = None | |||
| content: Optional[str] = None | |||
| page: Optional[int] = None | |||
| doc_metadata: Optional[dict[str, Any]] = None | |||
| title: Optional[str] = None | |||
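Because every field is optional, call sites can construct the model with whatever they have and enrich it later, which is how the retrieval code below uses it; a hypothetical example:

source = RetrievalSourceMetadata(
    dataset_id="ds-001",  # hypothetical values
    dataset_name="FAQ",
    data_source_type="upload_file",
    score=0.82,
    content="...",
)
source.position = 1  # fields stay mutable, so later enrichment (position, hit_count, ...) works
print(source.model_dump(exclude_none=True))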
| @@ -27,6 +27,8 @@ class WebsiteInfo(BaseModel): | |||
| website import info. | |||
| """ | |||
| model_config = ConfigDict(arbitrary_types_allowed=True) | |||
| provider: str | |||
| job_id: str | |||
| url: str | |||
| @@ -34,12 +36,6 @@ class WebsiteInfo(BaseModel): | |||
| tenant_id: str | |||
| only_main_content: bool = False | |||
| class Config: | |||
| arbitrary_types_allowed = True | |||
| def __init__(self, **data) -> None: | |||
| super().__init__(**data) | |||
| class ExtractSetting(BaseModel): | |||
| """ | |||
| @@ -70,13 +70,12 @@ class BaseDocumentTransformer(ABC): | |||
| .. code-block:: python | |||
| class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): | |||
| model_config = ConfigDict(arbitrary_types_allowed=True) | |||
| embeddings: Embeddings | |||
| similarity_fn: Callable = cosine_similarity | |||
| similarity_threshold: float = 0.95 | |||
| class Config: | |||
| arbitrary_types_allowed = True | |||
| def transform_documents( | |||
| self, documents: Sequence[Document], **kwargs: Any | |||
| ) -> Sequence[Document]: | |||
| @@ -35,6 +35,7 @@ from core.prompt.simple_prompt_transform import ModelMode | |||
| from core.rag.data_post_processor.data_post_processor import DataPostProcessor | |||
| from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.rag.entities.context_entities import DocumentContext | |||
| from core.rag.entities.metadata_entities import Condition, MetadataCondition | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| @@ -198,21 +199,21 @@ class DatasetRetrieval: | |||
| dify_documents = [item for item in all_documents if item.provider == "dify"] | |||
| external_documents = [item for item in all_documents if item.provider == "external"] | |||
| document_context_list = [] | |||
| retrieval_resource_list = [] | |||
| document_context_list: list[DocumentContext] = [] | |||
| retrieval_resource_list: list[RetrievalSourceMetadata] = [] | |||
| # deal with external documents | |||
| for item in external_documents: | |||
| document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score"))) | |||
| source = { | |||
| "dataset_id": item.metadata.get("dataset_id"), | |||
| "dataset_name": item.metadata.get("dataset_name"), | |||
| "document_id": item.metadata.get("document_id") or item.metadata.get("title"), | |||
| "document_name": item.metadata.get("title"), | |||
| "data_source_type": "external", | |||
| "retriever_from": invoke_from.to_source(), | |||
| "score": item.metadata.get("score"), | |||
| "content": item.page_content, | |||
| } | |||
| source = RetrievalSourceMetadata( | |||
| dataset_id=item.metadata.get("dataset_id"), | |||
| dataset_name=item.metadata.get("dataset_name"), | |||
| document_id=item.metadata.get("document_id") or item.metadata.get("title"), | |||
| document_name=item.metadata.get("title"), | |||
| data_source_type="external", | |||
| retriever_from=invoke_from.to_source(), | |||
| score=item.metadata.get("score"), | |||
| content=item.page_content, | |||
| ) | |||
| retrieval_resource_list.append(source) | |||
| # deal with dify documents | |||
| if dify_documents: | |||
| @@ -248,32 +249,32 @@ class DatasetRetrieval: | |||
| .first() | |||
| ) | |||
| if dataset and document: | |||
| source = { | |||
| "dataset_id": dataset.id, | |||
| "dataset_name": dataset.name, | |||
| "document_id": document.id, | |||
| "document_name": document.name, | |||
| "data_source_type": document.data_source_type, | |||
| "segment_id": segment.id, | |||
| "retriever_from": invoke_from.to_source(), | |||
| "score": record.score or 0.0, | |||
| "doc_metadata": document.doc_metadata, | |||
| } | |||
| source = RetrievalSourceMetadata( | |||
| dataset_id=dataset.id, | |||
| dataset_name=dataset.name, | |||
| document_id=document.id, | |||
| document_name=document.name, | |||
| data_source_type=document.data_source_type, | |||
| segment_id=segment.id, | |||
| retriever_from=invoke_from.to_source(), | |||
| score=record.score or 0.0, | |||
| doc_metadata=document.doc_metadata, | |||
| ) | |||
| if invoke_from.to_source() == "dev": | |||
| source["hit_count"] = segment.hit_count | |||
| source["word_count"] = segment.word_count | |||
| source["segment_position"] = segment.position | |||
| source["index_node_hash"] = segment.index_node_hash | |||
| source.hit_count = segment.hit_count | |||
| source.word_count = segment.word_count | |||
| source.segment_position = segment.position | |||
| source.index_node_hash = segment.index_node_hash | |||
| if segment.answer: | |||
| source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" | |||
| source.content = f"question:{segment.content} \nanswer:{segment.answer}" | |||
| else: | |||
| source["content"] = segment.content | |||
| source.content = segment.content | |||
| retrieval_resource_list.append(source) | |||
| if hit_callback and retrieval_resource_list: | |||
| retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True) | |||
| retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True) | |||
| for position, item in enumerate(retrieval_resource_list, start=1): | |||
| item["position"] = position | |||
| item.position = position | |||
| hit_callback.return_retriever_resource_info(retrieval_resource_list) | |||
| if document_context_list: | |||
| document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) | |||
| @@ -936,6 +937,9 @@ class DatasetRetrieval: | |||
| return metadata_filter_document_ids, metadata_condition | |||
| def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str: | |||
| if not inputs: | |||
| return text | |||
| def replacer(match): | |||
| key = match.group(1) | |||
| return str(inputs.get(key, f"{{{{{key}}}}}")) | |||
| @@ -10,12 +10,12 @@ from sqlalchemy import select | |||
| from sqlalchemy.engine import Engine | |||
| from sqlalchemy.orm import sessionmaker | |||
| from core.workflow.entities.workflow_execution_entities import ( | |||
| from core.workflow.entities.workflow_execution import ( | |||
| WorkflowExecution, | |||
| WorkflowExecutionStatus, | |||
| WorkflowType, | |||
| ) | |||
| from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| from models import ( | |||
| Account, | |||
| CreatorUserRole, | |||
| @@ -104,10 +104,9 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): | |||
| status = WorkflowExecutionStatus(db_model.status) | |||
| return WorkflowExecution( | |||
| id=db_model.id, | |||
| id_=db_model.id, | |||
| workflow_id=db_model.workflow_id, | |||
| sequence_number=db_model.sequence_number, | |||
| type=WorkflowType(db_model.type), | |||
| workflow_type=WorkflowType(db_model.type), | |||
| workflow_version=db_model.version, | |||
| graph=graph, | |||
| inputs=inputs, | |||
| @@ -140,14 +139,29 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): | |||
| raise ValueError("created_by_role is required in repository constructor") | |||
| db_model = WorkflowRun() | |||
| db_model.id = domain_model.id | |||
| db_model.id = domain_model.id_ | |||
| db_model.tenant_id = self._tenant_id | |||
| if self._app_id is not None: | |||
| db_model.app_id = self._app_id | |||
| db_model.workflow_id = domain_model.workflow_id | |||
| db_model.triggered_from = self._triggered_from | |||
| db_model.sequence_number = domain_model.sequence_number | |||
| db_model.type = domain_model.type | |||
| # Check if this is a new record | |||
| with self._session_factory() as session: | |||
| existing = session.scalar(select(WorkflowRun).where(WorkflowRun.id == domain_model.id_)) | |||
| if not existing: | |||
| # For new records, get the next sequence number | |||
| stmt = select(WorkflowRun.sequence_number).where( | |||
| WorkflowRun.app_id == self._app_id, | |||
| WorkflowRun.tenant_id == self._tenant_id, | |||
| ) | |||
| max_sequence = session.scalar(stmt.order_by(WorkflowRun.sequence_number.desc())) | |||
| db_model.sequence_number = (max_sequence or 0) + 1 | |||
| else: | |||
| # For updates, keep the existing sequence number | |||
| db_model.sequence_number = existing.sequence_number | |||
| db_model.type = domain_model.workflow_type | |||
| db_model.version = domain_model.workflow_version | |||
| db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None | |||
| db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None | |||
| @@ -12,19 +12,18 @@ from sqlalchemy.engine import Engine | |||
| from sqlalchemy.orm import sessionmaker | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.entities.node_execution_entities import ( | |||
| NodeExecution, | |||
| NodeExecutionStatus, | |||
| from core.workflow.entities.workflow_node_execution import ( | |||
| WorkflowNodeExecution, | |||
| WorkflowNodeExecutionMetadataKey, | |||
| WorkflowNodeExecutionStatus, | |||
| ) | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository | |||
| from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository | |||
| from models import ( | |||
| Account, | |||
| CreatorUserRole, | |||
| EndUser, | |||
| WorkflowNodeExecution, | |||
| WorkflowNodeExecutionStatus, | |||
| WorkflowNodeExecutionModel, | |||
| WorkflowNodeExecutionTriggeredFrom, | |||
| ) | |||
| @@ -87,9 +86,9 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| # Initialize in-memory cache for node executions | |||
| # Key: node_execution_id, Value: WorkflowNodeExecution (DB model) | |||
| self._node_execution_cache: dict[str, WorkflowNodeExecution] = {} | |||
| self._node_execution_cache: dict[str, WorkflowNodeExecutionModel] = {} | |||
| def _to_domain_model(self, db_model: WorkflowNodeExecution) -> NodeExecution: | |||
| def _to_domain_model(self, db_model: WorkflowNodeExecutionModel) -> WorkflowNodeExecution: | |||
| """ | |||
| Convert a database model to a domain model. | |||
| @@ -103,16 +102,16 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| inputs = db_model.inputs_dict | |||
| process_data = db_model.process_data_dict | |||
| outputs = db_model.outputs_dict | |||
| metadata = {NodeRunMetadataKey(k): v for k, v in db_model.execution_metadata_dict.items()} | |||
| metadata = {WorkflowNodeExecutionMetadataKey(k): v for k, v in db_model.execution_metadata_dict.items()} | |||
| # Convert status to domain enum | |||
| status = NodeExecutionStatus(db_model.status) | |||
| status = WorkflowNodeExecutionStatus(db_model.status) | |||
| return NodeExecution( | |||
| return WorkflowNodeExecution( | |||
| id=db_model.id, | |||
| node_execution_id=db_model.node_execution_id, | |||
| workflow_id=db_model.workflow_id, | |||
| workflow_run_id=db_model.workflow_run_id, | |||
| workflow_execution_id=db_model.workflow_run_id, | |||
| index=db_model.index, | |||
| predecessor_node_id=db_model.predecessor_node_id, | |||
| node_id=db_model.node_id, | |||
| @@ -129,7 +128,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| finished_at=db_model.finished_at, | |||
| ) | |||
| def to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution: | |||
| def to_db_model(self, domain_model: WorkflowNodeExecution) -> WorkflowNodeExecutionModel: | |||
| """ | |||
| Convert a domain model to a database model. | |||
| @@ -147,14 +146,14 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| if not self._creator_user_role: | |||
| raise ValueError("created_by_role is required in repository constructor") | |||
| db_model = WorkflowNodeExecution() | |||
| db_model = WorkflowNodeExecutionModel() | |||
| db_model.id = domain_model.id | |||
| db_model.tenant_id = self._tenant_id | |||
| if self._app_id is not None: | |||
| db_model.app_id = self._app_id | |||
| db_model.workflow_id = domain_model.workflow_id | |||
| db_model.triggered_from = self._triggered_from | |||
| db_model.workflow_run_id = domain_model.workflow_run_id | |||
| db_model.workflow_run_id = domain_model.workflow_execution_id | |||
| db_model.index = domain_model.index | |||
| db_model.predecessor_node_id = domain_model.predecessor_node_id | |||
| db_model.node_execution_id = domain_model.node_execution_id | |||
| @@ -176,7 +175,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| db_model.finished_at = domain_model.finished_at | |||
| return db_model | |||
| def save(self, execution: NodeExecution) -> None: | |||
| def save(self, execution: WorkflowNodeExecution) -> None: | |||
| """ | |||
| Save or update a NodeExecution domain entity to the database. | |||
| @@ -208,7 +207,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}") | |||
| self._node_execution_cache[db_model.node_execution_id] = db_model | |||
| def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]: | |||
| def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: | |||
| """ | |||
| Retrieve a NodeExecution by its node_execution_id. | |||
| @@ -231,13 +230,13 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| # If not in cache, query the database | |||
| logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database") | |||
| with self._session_factory() as session: | |||
| stmt = select(WorkflowNodeExecution).where( | |||
| WorkflowNodeExecution.node_execution_id == node_execution_id, | |||
| WorkflowNodeExecution.tenant_id == self._tenant_id, | |||
| stmt = select(WorkflowNodeExecutionModel).where( | |||
| WorkflowNodeExecutionModel.node_execution_id == node_execution_id, | |||
| WorkflowNodeExecutionModel.tenant_id == self._tenant_id, | |||
| ) | |||
| if self._app_id: | |||
| stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) | |||
| stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) | |||
| db_model = session.scalar(stmt) | |||
| if db_model: | |||
| @@ -254,7 +253,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| workflow_run_id: str, | |||
| order_config: Optional[OrderConfig] = None, | |||
| triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||
| ) -> Sequence[WorkflowNodeExecution]: | |||
| ) -> Sequence[WorkflowNodeExecutionModel]: | |||
| """ | |||
| Retrieve all WorkflowNodeExecution database models for a specific workflow run. | |||
| @@ -272,20 +271,20 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| A list of WorkflowNodeExecution database models | |||
| """ | |||
| with self._session_factory() as session: | |||
| stmt = select(WorkflowNodeExecution).where( | |||
| WorkflowNodeExecution.workflow_run_id == workflow_run_id, | |||
| WorkflowNodeExecution.tenant_id == self._tenant_id, | |||
| WorkflowNodeExecution.triggered_from == triggered_from, | |||
| stmt = select(WorkflowNodeExecutionModel).where( | |||
| WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, | |||
| WorkflowNodeExecutionModel.tenant_id == self._tenant_id, | |||
| WorkflowNodeExecutionModel.triggered_from == triggered_from, | |||
| ) | |||
| if self._app_id: | |||
| stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) | |||
| stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) | |||
| # Apply ordering if provided | |||
| if order_config and order_config.order_by: | |||
| order_columns: list[UnaryExpression] = [] | |||
| for field in order_config.order_by: | |||
| column = getattr(WorkflowNodeExecution, field, None) | |||
| column = getattr(WorkflowNodeExecutionModel, field, None) | |||
| if not column: | |||
| continue | |||
| if order_config.order_direction == "desc": | |||
| @@ -310,7 +309,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| workflow_run_id: str, | |||
| order_config: Optional[OrderConfig] = None, | |||
| triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||
| ) -> Sequence[NodeExecution]: | |||
| ) -> Sequence[WorkflowNodeExecution]: | |||
| """ | |||
| Retrieve all NodeExecution instances for a specific workflow run. | |||
| @@ -337,7 +336,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| return domain_models | |||
| def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]: | |||
| def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: | |||
| """ | |||
| Retrieve all running NodeExecution instances for a specific workflow run. | |||
| @@ -351,15 +350,15 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| A list of running NodeExecution instances | |||
| """ | |||
| with self._session_factory() as session: | |||
| stmt = select(WorkflowNodeExecution).where( | |||
| WorkflowNodeExecution.workflow_run_id == workflow_run_id, | |||
| WorkflowNodeExecution.tenant_id == self._tenant_id, | |||
| WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING, | |||
| WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||
| stmt = select(WorkflowNodeExecutionModel).where( | |||
| WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, | |||
| WorkflowNodeExecutionModel.tenant_id == self._tenant_id, | |||
| WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING, | |||
| WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||
| ) | |||
| if self._app_id: | |||
| stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) | |||
| stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) | |||
| db_models = session.scalars(stmt).all() | |||
| domain_models = [] | |||
| @@ -384,10 +383,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| It also clears the in-memory cache. | |||
| """ | |||
| with self._session_factory() as session: | |||
| stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id) | |||
| stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id) | |||
| if self._app_id: | |||
| stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) | |||
| stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) | |||
| result = session.execute(stmt) | |||
| session.commit() | |||
| @@ -168,7 +168,7 @@ class ApiTool(Tool): | |||
| cookies[parameter["name"]] = value | |||
| elif parameter["in"] == "header": | |||
| headers[parameter["name"]] = value | |||
| headers[parameter["name"]] = str(value) | |||
| # check if there is a request body and handle it | |||
| if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None: | |||
| @@ -279,7 +279,6 @@ class ToolParameter(PluginParameter): | |||
| :param options: the options of the parameter | |||
| """ | |||
| # convert options to ToolParameterOption | |||
| # FIXME fix the type error | |||
| if options: | |||
| option_objs = [ | |||
| PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) | |||
| @@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.rag.models.document import Document as RagDocument | |||
| from core.rag.rerank.rerank_model import RerankModelRunner | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| @@ -107,7 +108,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): | |||
| else: | |||
| document_context_list.append(segment.get_sign_content()) | |||
| if self.return_resource: | |||
| context_list = [] | |||
| context_list: list[RetrievalSourceMetadata] = [] | |||
| resource_number = 1 | |||
| for segment in sorted_segments: | |||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | |||
| @@ -121,28 +122,28 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): | |||
| .first() | |||
| ) | |||
| if dataset and document: | |||
| source = { | |||
| "position": resource_number, | |||
| "dataset_id": dataset.id, | |||
| "dataset_name": dataset.name, | |||
| "document_id": document.id, | |||
| "document_name": document.name, | |||
| "data_source_type": document.data_source_type, | |||
| "segment_id": segment.id, | |||
| "retriever_from": self.retriever_from, | |||
| "score": document_score_list.get(segment.index_node_id, None), | |||
| "doc_metadata": document.doc_metadata, | |||
| } | |||
| source = RetrievalSourceMetadata( | |||
| position=resource_number, | |||
| dataset_id=dataset.id, | |||
| dataset_name=dataset.name, | |||
| document_id=document.id, | |||
| document_name=document.name, | |||
| data_source_type=document.data_source_type, | |||
| segment_id=segment.id, | |||
| retriever_from=self.retriever_from, | |||
| score=document_score_list.get(segment.index_node_id, None), | |||
| doc_metadata=document.doc_metadata, | |||
| ) | |||
| if self.retriever_from == "dev": | |||
| source["hit_count"] = segment.hit_count | |||
| source["word_count"] = segment.word_count | |||
| source["segment_position"] = segment.position | |||
| source["index_node_hash"] = segment.index_node_hash | |||
| source.hit_count = segment.hit_count | |||
| source.word_count = segment.word_count | |||
| source.segment_position = segment.position | |||
| source.index_node_hash = segment.index_node_hash | |||
| if segment.answer: | |||
| source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" | |||
| source.content = f"question:{segment.content} \nanswer:{segment.answer}" | |||
| else: | |||
| source["content"] = segment.content | |||
| source.content = segment.content | |||
| context_list.append(source) | |||
| resource_number += 1 | |||
| @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field | |||
| from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.rag.entities.context_entities import DocumentContext | |||
| from core.rag.models.document import Document as RetrievalDocument | |||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | |||
| @@ -14,7 +15,7 @@ from models.dataset import Dataset | |||
| from models.dataset import Document as DatasetDocument | |||
| from services.external_knowledge_service import ExternalDatasetService | |||
| default_retrieval_model = { | |||
| default_retrieval_model: dict[str, Any] = { | |||
| "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| "reranking_enable": False, | |||
| "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, | |||
| @@ -79,7 +80,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| else: | |||
| document_ids_filter = None | |||
| if dataset.provider == "external": | |||
| results = [] | |||
| results: list[RetrievalDocument] = [] | |||
| external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( | |||
| tenant_id=dataset.tenant_id, | |||
| dataset_id=dataset.id, | |||
| @@ -100,21 +101,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| document.metadata["dataset_name"] = dataset.name | |||
| results.append(document) | |||
| # deal with external documents | |||
| context_list = [] | |||
| context_list: list[RetrievalSourceMetadata] = [] | |||
| for position, item in enumerate(results, start=1): | |||
| if item.metadata is not None: | |||
| source = { | |||
| "position": position, | |||
| "dataset_id": item.metadata.get("dataset_id"), | |||
| "dataset_name": item.metadata.get("dataset_name"), | |||
| "document_id": item.metadata.get("document_id") or item.metadata.get("title"), | |||
| "document_name": item.metadata.get("title"), | |||
| "data_source_type": "external", | |||
| "retriever_from": self.retriever_from, | |||
| "score": item.metadata.get("score"), | |||
| "title": item.metadata.get("title"), | |||
| "content": item.page_content, | |||
| } | |||
| source = RetrievalSourceMetadata( | |||
| position=position, | |||
| dataset_id=item.metadata.get("dataset_id"), | |||
| dataset_name=item.metadata.get("dataset_name"), | |||
| document_id=item.metadata.get("document_id") or item.metadata.get("title"), | |||
| document_name=item.metadata.get("title"), | |||
| data_source_type="external", | |||
| retriever_from=self.retriever_from, | |||
| score=item.metadata.get("score"), | |||
| title=item.metadata.get("title"), | |||
| content=item.page_content, | |||
| ) | |||
| context_list.append(source) | |||
| for hit_callback in self.hit_callbacks: | |||
| hit_callback.return_retriever_resource_info(context_list) | |||
| @@ -125,7 +126,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| return "" | |||
| # get the retrieval model; if it is not set, use the default | |||
| retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model | |||
| retrieval_resource_list = [] | |||
| retrieval_resource_list: list[RetrievalSourceMetadata] = [] | |||
| if dataset.indexing_technique == "economy": | |||
| # use keyword table query | |||
| documents = RetrievalService.retrieve( | |||
| @@ -163,7 +164,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| for item in documents: | |||
| if item.metadata is not None and item.metadata.get("score"): | |||
| document_score_list[item.metadata["doc_id"]] = item.metadata["score"] | |||
| document_context_list = [] | |||
| document_context_list: list[DocumentContext] = [] | |||
| records = RetrievalService.format_retrieval_documents(documents) | |||
| if records: | |||
| for record in records: | |||
| @@ -197,37 +198,37 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| .first() | |||
| ) | |||
| if dataset and document: | |||
| source = { | |||
| "dataset_id": dataset.id, | |||
| "dataset_name": dataset.name, | |||
| "document_id": document.id, # type: ignore | |||
| "document_name": document.name, # type: ignore | |||
| "data_source_type": document.data_source_type, # type: ignore | |||
| "segment_id": segment.id, | |||
| "retriever_from": self.retriever_from, | |||
| "score": record.score or 0.0, | |||
| "doc_metadata": document.doc_metadata, # type: ignore | |||
| } | |||
| source = RetrievalSourceMetadata( | |||
| dataset_id=dataset.id, | |||
| dataset_name=dataset.name, | |||
| document_id=document.id, # type: ignore | |||
| document_name=document.name, # type: ignore | |||
| data_source_type=document.data_source_type, # type: ignore | |||
| segment_id=segment.id, | |||
| retriever_from=self.retriever_from, | |||
| score=record.score or 0.0, | |||
| doc_metadata=document.doc_metadata, # type: ignore | |||
| ) | |||
| if self.retriever_from == "dev": | |||
| source["hit_count"] = segment.hit_count | |||
| source["word_count"] = segment.word_count | |||
| source["segment_position"] = segment.position | |||
| source["index_node_hash"] = segment.index_node_hash | |||
| source.hit_count = segment.hit_count | |||
| source.word_count = segment.word_count | |||
| source.segment_position = segment.position | |||
| source.index_node_hash = segment.index_node_hash | |||
| if segment.answer: | |||
| source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" | |||
| source.content = f"question:{segment.content} \nanswer:{segment.answer}" | |||
| else: | |||
| source["content"] = segment.content | |||
| source.content = segment.content | |||
| retrieval_resource_list.append(source) | |||
| if self.return_resource and retrieval_resource_list: | |||
| retrieval_resource_list = sorted( | |||
| retrieval_resource_list, | |||
| key=lambda x: x.get("score") or 0.0, | |||
| key=lambda x: x.score or 0.0, | |||
| reverse=True, | |||
| ) | |||
| for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore | |||
| item["position"] = position # type: ignore | |||
| item.position = position # type: ignore | |||
| for hit_callback in self.hit_callbacks: | |||
| hit_callback.return_retriever_resource_info(retrieval_resource_list) | |||
| if document_context_list: | |||
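Editor's note: these hunks replace the ad-hoc `source` dictionaries with a typed `RetrievalSourceMetadata` object, so scoring, positioning, and dev-only fields become attribute access instead of string-keyed lookups. A minimal sketch of the pattern, assuming `RetrievalSourceMetadata` is a Pydantic model with optional fields (field names taken from the diff; the stand-in below is not the real class):

```python
from typing import Optional
from pydantic import BaseModel


class RetrievalSourceMetadata(BaseModel):
    # Subset of the fields used in the diff; all optional so partial sources can be built.
    dataset_id: Optional[str] = None
    dataset_name: Optional[str] = None
    document_id: Optional[str] = None
    document_name: Optional[str] = None
    score: Optional[float] = None
    position: Optional[int] = None
    content: Optional[str] = None


sources = [
    RetrievalSourceMetadata(dataset_id="d1", dataset_name="docs", score=0.82),
    RetrievalSourceMetadata(dataset_id="d2", dataset_name="faq", score=None),
]

# Attribute access replaces dict indexing, so sorting and mutation are type-checked.
sources.sort(key=lambda s: s.score or 0.0, reverse=True)
for position, source in enumerate(sources, start=1):
    source.position = position
```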
| @@ -66,7 +66,6 @@ class ToolFileMessageTransformer: | |||
| if not isinstance(message.message, ToolInvokeMessage.BlobMessage): | |||
| raise ValueError("unexpected message type") | |||
| # FIXME: should do a type check here. | |||
| assert isinstance(message.message.blob, bytes) | |||
| tool_file_manager = ToolFileManager() | |||
| file = tool_file_manager.create_file_by_raw( | |||
| @@ -55,6 +55,13 @@ class ApiBasedToolSchemaParser: | |||
| # convert parameters | |||
| parameters = [] | |||
| if "parameters" in interface["operation"]: | |||
| for i, parameter in enumerate(interface["operation"]["parameters"]): | |||
| if "$ref" in parameter: | |||
| root = openapi | |||
| reference = parameter["$ref"].split("/")[1:] | |||
| for ref in reference: | |||
| root = root[ref] | |||
| interface["operation"]["parameters"][i] = root | |||
| for parameter in interface["operation"]["parameters"]: | |||
| tool_parameter = ToolParameter( | |||
| name=parameter["name"], | |||
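Editor's note: the added block resolves document-local `$ref` parameters in place before the `ToolParameter` loop runs. A standalone sketch of the same lookup, using a hypothetical OpenAPI fragment; only pointers of the `#/a/b/c` form shown in the diff are handled:

```python
openapi = {
    "components": {
        "parameters": {
            "Page": {"name": "page", "in": "query", "schema": {"type": "integer"}},
        }
    },
    "paths": {},
}

parameter = {"$ref": "#/components/parameters/Page"}

if "$ref" in parameter:
    root = openapi
    # "#/components/parameters/Page" -> ["components", "parameters", "Page"]
    for key in parameter["$ref"].split("/")[1:]:
        root = root[key]
    parameter = root  # the reference is replaced by the resolved object

print(parameter["name"])  # -> "page"
```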
| @@ -1,37 +1,10 @@ | |||
| from collections.abc import Mapping | |||
| from enum import StrEnum | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| class NodeRunMetadataKey(StrEnum): | |||
| """ | |||
| Node Run Metadata Key. | |||
| """ | |||
| TOTAL_TOKENS = "total_tokens" | |||
| TOTAL_PRICE = "total_price" | |||
| CURRENCY = "currency" | |||
| TOOL_INFO = "tool_info" | |||
| DATASOURCE_INFO = "datasource_info" | |||
| AGENT_LOG = "agent_log" | |||
| ITERATION_ID = "iteration_id" | |||
| ITERATION_INDEX = "iteration_index" | |||
| LOOP_ID = "loop_id" | |||
| LOOP_INDEX = "loop_index" | |||
| PARALLEL_ID = "parallel_id" | |||
| PARALLEL_START_NODE_ID = "parallel_start_node_id" | |||
| PARENT_PARALLEL_ID = "parent_parallel_id" | |||
| PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" | |||
| PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" | |||
| ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs | |||
| LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs | |||
| ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field | |||
| LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| class NodeRunResult(BaseModel): | |||
| @@ -44,7 +17,7 @@ class NodeRunResult(BaseModel): | |||
| inputs: Optional[Mapping[str, Any]] = None # node inputs | |||
| process_data: Optional[Mapping[str, Any]] = None # process data | |||
| outputs: Optional[Mapping[str, Any]] = None # node outputs | |||
| metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata | |||
| metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # node metadata | |||
| llm_usage: Optional[LLMUsage] = None # llm usage | |||
| edge_source_handle: Optional[str] = None # source handle id of node with multiple branches | |||
| @@ -37,12 +37,10 @@ class WorkflowExecution(BaseModel): | |||
| user, tenant, and app attributes. | |||
| """ | |||
| id: str = Field(...) | |||
| id_: str = Field(...) | |||
| workflow_id: str = Field(...) | |||
| workflow_version: str = Field(...) | |||
| sequence_number: int = Field(...) | |||
| type: WorkflowType = Field(...) | |||
| workflow_type: WorkflowType = Field(...) | |||
| graph: Mapping[str, Any] = Field(...) | |||
| inputs: Mapping[str, Any] = Field(...) | |||
| @@ -70,20 +68,18 @@ class WorkflowExecution(BaseModel): | |||
| def new( | |||
| cls, | |||
| *, | |||
| id: str, | |||
| id_: str, | |||
| workflow_id: str, | |||
| sequence_number: int, | |||
| type: WorkflowType, | |||
| workflow_type: WorkflowType, | |||
| workflow_version: str, | |||
| graph: Mapping[str, Any], | |||
| inputs: Mapping[str, Any], | |||
| started_at: datetime, | |||
| ) -> "WorkflowExecution": | |||
| return WorkflowExecution( | |||
| id=id, | |||
| id_=id_, | |||
| workflow_id=workflow_id, | |||
| sequence_number=sequence_number, | |||
| type=type, | |||
| workflow_type=workflow_type, | |||
| workflow_version=workflow_version, | |||
| graph=graph, | |||
| inputs=inputs, | |||
| @@ -13,11 +13,35 @@ from typing import Any, Optional | |||
| from pydantic import BaseModel, Field | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.nodes.enums import NodeType | |||
| class NodeExecutionStatus(StrEnum): | |||
| class WorkflowNodeExecutionMetadataKey(StrEnum): | |||
| """ | |||
| Node Run Metadata Key. | |||
| """ | |||
| TOTAL_TOKENS = "total_tokens" | |||
| TOTAL_PRICE = "total_price" | |||
| CURRENCY = "currency" | |||
| TOOL_INFO = "tool_info" | |||
| AGENT_LOG = "agent_log" | |||
| ITERATION_ID = "iteration_id" | |||
| ITERATION_INDEX = "iteration_index" | |||
| LOOP_ID = "loop_id" | |||
| LOOP_INDEX = "loop_index" | |||
| PARALLEL_ID = "parallel_id" | |||
| PARALLEL_START_NODE_ID = "parallel_start_node_id" | |||
| PARENT_PARALLEL_ID = "parent_parallel_id" | |||
| PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" | |||
| PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" | |||
| ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs | |||
| LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs | |||
| ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field | |||
| LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output | |||
| class WorkflowNodeExecutionStatus(StrEnum): | |||
| """ | |||
| Node Execution Status Enum. | |||
| """ | |||
| @@ -29,7 +53,7 @@ class NodeExecutionStatus(StrEnum): | |||
| RETRY = "retry" | |||
| class NodeExecution(BaseModel): | |||
| class WorkflowNodeExecution(BaseModel): | |||
| """ | |||
| Domain model for workflow node execution. | |||
| @@ -46,7 +70,7 @@ class NodeExecution(BaseModel): | |||
| id: str # Unique identifier for this execution record | |||
| node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing | |||
| workflow_id: str # ID of the workflow this node belongs to | |||
| workflow_run_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging) | |||
| workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging) | |||
| # Execution positioning and flow | |||
| index: int # Sequence number for ordering in trace visualization | |||
| @@ -61,12 +85,12 @@ class NodeExecution(BaseModel): | |||
| outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node | |||
| # Execution state | |||
| status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status | |||
| status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status | |||
| error: Optional[str] = None # Error message if execution failed | |||
| elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds | |||
| # Additional metadata | |||
| metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.) | |||
| metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.) | |||
| # Timing information | |||
| created_at: datetime # When execution started | |||
| @@ -77,7 +101,7 @@ class NodeExecution(BaseModel): | |||
| inputs: Optional[Mapping[str, Any]] = None, | |||
| process_data: Optional[Mapping[str, Any]] = None, | |||
| outputs: Optional[Mapping[str, Any]] = None, | |||
| metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None, | |||
| metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None, | |||
| ) -> None: | |||
| """ | |||
| Update the model from mappings. | |||
| @@ -13,7 +13,7 @@ class SystemVariableKey(StrEnum): | |||
| DIALOGUE_COUNT = "dialogue_count" | |||
| APP_ID = "app_id" | |||
| WORKFLOW_ID = "workflow_id" | |||
| WORKFLOW_RUN_ID = "workflow_run_id" | |||
| WORKFLOW_EXECUTION_ID = "workflow_run_id" | |||
| # RAG Pipeline | |||
| DOCUMENT_ID = "document_id" | |||
| BATCH = "batch" | |||
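Editor's note: only the enum member is renamed here; the underlying value stays `"workflow_run_id"`, so anything persisted or matched by its string form keeps resolving. A quick illustration, trimmed to the one member:

```python
from enum import StrEnum


class SystemVariableKey(StrEnum):
    WORKFLOW_EXECUTION_ID = "workflow_run_id"


# The rename is source-level only: the serialized value is unchanged.
assert SystemVariableKey.WORKFLOW_EXECUTION_ID == "workflow_run_id"
assert SystemVariableKey("workflow_run_id") is SystemVariableKey.WORKFLOW_EXECUTION_ID
```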
| @@ -1,9 +1,10 @@ | |||
| from collections.abc import Mapping | |||
| from collections.abc import Mapping, Sequence | |||
| from datetime import datetime | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel, Field | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit | |||
| from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState | |||
| from core.workflow.nodes import NodeType | |||
| @@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent): | |||
| class NodeRunRetrieverResourceEvent(BaseNodeEvent): | |||
| retriever_resources: list[dict] = Field(..., description="retriever resources") | |||
| retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") | |||
| context: str = Field(..., description="context") | |||
| @@ -6,7 +6,7 @@ from typing import Optional | |||
| from pydantic import BaseModel, Field | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| class RouteNodeState(BaseModel): | |||
| @@ -14,8 +14,9 @@ from flask import Flask, current_app, has_request_context | |||
| from configs import dify_config | |||
| from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey, NodeRunResult | |||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool, VariableValue | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| BaseAgentEvent, | |||
| @@ -52,9 +53,8 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor | |||
| from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle | |||
| from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent | |||
| from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING | |||
| from extensions.ext_database import db | |||
| from models.enums import UserFrom | |||
| from models.workflow import WorkflowNodeExecutionStatus, WorkflowType | |||
| from models.workflow import WorkflowType | |||
| logger = logging.getLogger(__name__) | |||
| @@ -606,8 +606,6 @@ class GraphEngine: | |||
| error=str(e), | |||
| ) | |||
| ) | |||
| finally: | |||
| db.session.remove() | |||
| def _run_node( | |||
| self, | |||
| @@ -645,7 +643,6 @@ class GraphEngine: | |||
| agent_strategy=agent_strategy, | |||
| ) | |||
| db.session.close() | |||
| max_retries = node_instance.node_data.retry_config.max_retries | |||
| retry_interval = node_instance.node_data.retry_config.retry_interval_seconds | |||
| retries = 0 | |||
| @@ -759,10 +756,12 @@ class GraphEngine: | |||
| and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH | |||
| ): | |||
| run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS | |||
| if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): | |||
| if run_result.metadata and run_result.metadata.get( | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS | |||
| ): | |||
| # plus state total_tokens | |||
| self.graph_runtime_state.total_tokens += int( | |||
| run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] | |||
| run_result.metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] | |||
| ) | |||
| if run_result.llm_usage: | |||
| @@ -785,13 +784,17 @@ class GraphEngine: | |||
| if parallel_id and parallel_start_node_id: | |||
| metadata_dict = dict(run_result.metadata) | |||
| metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id | |||
| metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id | |||
| metadata_dict[WorkflowNodeExecutionMetadataKey.PARALLEL_ID] = parallel_id | |||
| metadata_dict[WorkflowNodeExecutionMetadataKey.PARALLEL_START_NODE_ID] = ( | |||
| parallel_start_node_id | |||
| ) | |||
| if parent_parallel_id and parent_parallel_start_node_id: | |||
| metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id | |||
| metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( | |||
| parent_parallel_start_node_id | |||
| metadata_dict[WorkflowNodeExecutionMetadataKey.PARENT_PARALLEL_ID] = ( | |||
| parent_parallel_id | |||
| ) | |||
| metadata_dict[ | |||
| WorkflowNodeExecutionMetadataKey.PARENT_PARALLEL_START_NODE_ID | |||
| ] = parent_parallel_start_node_id | |||
| run_result.metadata = metadata_dict | |||
| yield NodeRunSucceededEvent( | |||
| @@ -856,8 +859,6 @@ class GraphEngine: | |||
| except Exception as e: | |||
| logger.exception(f"Node {node_instance.node_data.title} run failed") | |||
| raise e | |||
| finally: | |||
| db.session.close() | |||
| def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): | |||
| """ | |||
| @@ -923,7 +924,7 @@ class GraphEngine: | |||
| "error": error_result.error, | |||
| "inputs": error_result.inputs, | |||
| "metadata": { | |||
| NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy, | |||
| WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy, | |||
| }, | |||
| } | |||
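Editor's note: the engine hunks above all swap `NodeRunMetadataKey` for `WorkflowNodeExecutionMetadataKey` as the key type of node-run metadata. A small sketch of reading and extending such a StrEnum-keyed mapping, limited to members visible in the diff:

```python
from enum import StrEnum
from typing import Any


class WorkflowNodeExecutionMetadataKey(StrEnum):
    TOTAL_TOKENS = "total_tokens"
    PARALLEL_ID = "parallel_id"
    PARALLEL_START_NODE_ID = "parallel_start_node_id"
    ERROR_STRATEGY = "error_strategy"


metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 128,
}

# Accumulate token usage the way the engine does after a node run.
total_tokens = int(metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0))

# Attach positioning info for a node that ran inside a parallel branch.
metadata[WorkflowNodeExecutionMetadataKey.PARALLEL_ID] = "parallel-1"
metadata[WorkflowNodeExecutionMetadataKey.PARALLEL_START_NODE_ID] = "node-42"
```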
| @@ -2,6 +2,9 @@ import json | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import Any, Optional, cast | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from core.agent.entities import AgentToolEntity | |||
| from core.agent.plugin_entities import AgentStrategyParameter | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| @@ -15,6 +18,7 @@ from core.tools.tool_manager import ToolManager | |||
| from core.variables.segments import StringSegment | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated | |||
| from core.workflow.nodes.base.entities import BaseNodeData | |||
| @@ -25,7 +29,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| from extensions.ext_database import db | |||
| from factories.agent_factory import get_plugin_agent_strategy | |||
| from models.model import Conversation | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| class AgentNode(ToolNode): | |||
| @@ -320,15 +323,12 @@ class AgentNode(ToolNode): | |||
| return None | |||
| conversation_id = conversation_id_variable.value | |||
| # get conversation | |||
| conversation = ( | |||
| db.session.query(Conversation) | |||
| .filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id) | |||
| .first() | |||
| ) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id) | |||
| conversation = session.scalar(stmt) | |||
| if not conversation: | |||
| return None | |||
| if not conversation: | |||
| return None | |||
| memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) | |||
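Editor's note: the conversation lookup moves from the legacy `db.session.query(...).filter(...).first()` chain to a short-lived `Session` with a 2.0-style `select`. A self-contained sketch of that shape, with a minimal stand-in `Conversation` model (not the real one):

```python
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Conversation(Base):
    # Stand-in for the real model; only the columns used in the lookup.
    __tablename__ = "conversations"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    app_id: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite:///:memory:")  # placeholder for db.engine
Base.metadata.create_all(engine)


def find_conversation(app_id: str, conversation_id: str):
    # Short-lived session per lookup; expire_on_commit=False keeps the loaded
    # attributes readable after the session closes.
    with Session(engine, expire_on_commit=False) as session:
        stmt = select(Conversation).where(
            Conversation.app_id == app_id,
            Conversation.id == conversation_id,
        )
        return session.scalar(stmt)  # first mapped object or None
```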
| @@ -3,6 +3,7 @@ from typing import Any, cast | |||
| from core.variables import ArrayFileSegment, FileSegment | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter | |||
| from core.workflow.nodes.answer.entities import ( | |||
| AnswerNodeData, | |||
| @@ -13,7 +14,6 @@ from core.workflow.nodes.answer.entities import ( | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| class AnswerNode(BaseNode[AnswerNodeData]): | |||
| @@ -4,9 +4,9 @@ from collections.abc import Generator, Mapping, Sequence | |||
| from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType | |||
| from core.workflow.nodes.event import NodeEvent, RunCompletedEvent | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from .entities import BaseNodeData | |||
| @@ -8,10 +8,10 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc | |||
| from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider | |||
| from core.variables.segments import ArrayFileSegment | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.code.entities import CodeNodeData | |||
| from core.workflow.nodes.enums import NodeType | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from .exc import ( | |||
| CodeNodeError, | |||
| @@ -26,9 +26,9 @@ from core.helper import ssrf_proxy | |||
| from core.variables import ArrayFileSegment | |||
| from core.variables.segments import FileSegment | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from .entities import DocumentExtractorNodeData | |||
| from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError | |||
| @@ -1,8 +1,8 @@ | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.end.entities import EndNodeData | |||
| from core.workflow.nodes.enums import NodeType | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| class EndNode(BaseNode[EndNodeData]): | |||
| @@ -1,10 +1,12 @@ | |||
| from collections.abc import Sequence | |||
| from datetime import datetime | |||
| from pydantic import BaseModel, Field | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| class RunCompletedEvent(BaseModel): | |||
| @@ -17,7 +19,7 @@ class RunStreamChunkEvent(BaseModel): | |||
| class RunRetrieverResourceEvent(BaseModel): | |||
| retriever_resources: list[dict] = Field(..., description="retriever resources") | |||
| retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") | |||
| context: str = Field(..., description="context") | |||
| @@ -8,12 +8,12 @@ from core.file import File, FileTransferMethod | |||
| from core.tools.tool_file_manager import ToolFileManager | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_entities import VariableSelector | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.http_request.executor import Executor | |||
| from core.workflow.utils import variable_template_parser | |||
| from factories import file_factory | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from .entities import ( | |||
| HttpRequestNodeData, | |||
| @@ -4,12 +4,12 @@ from typing_extensions import deprecated | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.if_else.entities import IfElseNodeData | |||
| from core.workflow.utils.condition.entities import Condition | |||
| from core.workflow.utils.condition.processor import ConditionProcessor | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| class IfElseNode(BaseNode[IfElseNodeData]): | |||
| @@ -12,10 +12,10 @@ from flask import Flask, current_app, has_request_context | |||
| from configs import dify_config | |||
| from core.variables import ArrayVariable, IntegerVariable, NoneVariable | |||
| from core.workflow.entities.node_entities import ( | |||
| NodeRunMetadataKey, | |||
| NodeRunResult, | |||
| ) | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| BaseGraphEvent, | |||
| BaseNodeEvent, | |||
| @@ -37,7 +37,6 @@ from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.event import NodeEvent, RunCompletedEvent | |||
| from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from .exc import ( | |||
| InvalidIteratorValueError, | |||
| @@ -249,8 +248,8 @@ class IterationNode(BaseNode[IterationNodeData]): | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| outputs={"output": outputs}, | |||
| metadata={ | |||
| NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map, | |||
| NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, | |||
| WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, | |||
| }, | |||
| ) | |||
| ) | |||
| @@ -361,16 +360,16 @@ class IterationNode(BaseNode[IterationNodeData]): | |||
| event.parallel_mode_run_id = parallel_mode_run_id | |||
| iter_metadata = { | |||
| NodeRunMetadataKey.ITERATION_ID: self.node_id, | |||
| NodeRunMetadataKey.ITERATION_INDEX: iter_run_index, | |||
| WorkflowNodeExecutionMetadataKey.ITERATION_ID: self.node_id, | |||
| WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index, | |||
| } | |||
| if parallel_mode_run_id: | |||
| # for parallel, the specific branch ID is more important than the sequential index | |||
| iter_metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id | |||
| iter_metadata[WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id | |||
| if event.route_node_state.node_run_result: | |||
| current_metadata = event.route_node_state.node_run_result.metadata or {} | |||
| if NodeRunMetadataKey.ITERATION_ID not in current_metadata: | |||
| if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata: | |||
| event.route_node_state.node_run_result.metadata = {**current_metadata, **iter_metadata} | |||
| return event | |||
| @@ -1,8 +1,8 @@ | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.iteration.entities import IterationStartNodeData | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| class IterationStartNode(BaseNode[IterationStartNodeData]): | |||
| @@ -8,6 +8,7 @@ from typing import Any, Optional, cast | |||
| from sqlalchemy import Float, and_, func, or_, text | |||
| from sqlalchemy import cast as sqlalchemy_cast | |||
| from sqlalchemy.orm import Session | |||
| from core.app.app_config.entities import DatasetRetrieveConfigEntity | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| @@ -24,6 +25,7 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from core.variables import StringSegment | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.event.event import ModelInvokeCompletedEvent | |||
| from core.workflow.nodes.knowledge_retrieval.template_prompts import ( | |||
| @@ -41,7 +43,6 @@ from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from libs.json_in_md_parser import parse_and_check_json_markdown | |||
| from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from services.feature_service import FeatureService | |||
| from .entities import KnowledgeRetrievalNodeData, ModelConfig | |||
| @@ -95,14 +96,15 @@ class KnowledgeRetrievalNode(LLMNode): | |||
| redis_client.zremrangebyscore(key, 0, current_time - 60000) | |||
| request_count = redis_client.zcard(key) | |||
| if request_count > knowledge_rate_limit.limit: | |||
| # add ratelimit record | |||
| rate_limit_log = RateLimitLog( | |||
| tenant_id=self.tenant_id, | |||
| subscription_plan=knowledge_rate_limit.subscription_plan, | |||
| operation="knowledge", | |||
| ) | |||
| db.session.add(rate_limit_log) | |||
| db.session.commit() | |||
| with Session(db.engine) as session: | |||
| # add ratelimit record | |||
| rate_limit_log = RateLimitLog( | |||
| tenant_id=self.tenant_id, | |||
| subscription_plan=knowledge_rate_limit.subscription_plan, | |||
| operation="knowledge", | |||
| ) | |||
| session.add(rate_limit_log) | |||
| session.commit() | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs=variables, | |||
| @@ -173,7 +175,9 @@ class KnowledgeRetrievalNode(LLMNode): | |||
| dataset_retrieval = DatasetRetrieval() | |||
| if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value: | |||
| # fetch model config | |||
| model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model) # type: ignore | |||
| if node_data.single_retrieval_config is None: | |||
| raise ValueError("single_retrieval_config is required") | |||
| model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model) | |||
| # check model is support tool calling | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| @@ -424,7 +428,7 @@ class KnowledgeRetrievalNode(LLMNode): | |||
| raise ValueError("metadata_model_config is required") | |||
| # get metadata model instance | |||
| # fetch model config | |||
| model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config) # type: ignore | |||
| model_instance, model_config = self.get_model_config(metadata_model_config) | |||
| # fetch prompt messages | |||
| prompt_template = self._get_prompt_template( | |||
| node_data=node_data, | |||
| @@ -550,14 +554,7 @@ class KnowledgeRetrievalNode(LLMNode): | |||
| variable_mapping[node_id + ".query"] = node_data.query_variable_selector | |||
| return variable_mapping | |||
| def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: # type: ignore | |||
| """ | |||
| Fetch model config | |||
| :param model: model | |||
| :return: | |||
| """ | |||
| if model is None: | |||
| raise ValueError("model is required") | |||
| def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: | |||
| model_name = model.name | |||
| provider_name = model.provider | |||
| @@ -4,9 +4,9 @@ from typing import Any, Literal, Union | |||
| from core.file import File | |||
| from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from .entities import ListOperatorNodeData | |||
| from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError | |||
| @@ -7,6 +7,8 @@ from datetime import UTC, datetime | |||
| from typing import TYPE_CHECKING, Any, Optional, cast | |||
| import json_repair | |||
| from sqlalchemy import select, update | |||
| from sqlalchemy.orm import Session | |||
| from configs import dify_config | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| @@ -43,6 +45,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig | |||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.variables import ( | |||
| ArrayAnySegment, | |||
| ArrayFileSegment, | |||
| @@ -53,9 +56,10 @@ from core.variables import ( | |||
| StringSegment, | |||
| ) | |||
| from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_entities import VariableSelector | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.event import InNodeEvent | |||
| from core.workflow.nodes.base import BaseNode | |||
| @@ -77,7 +81,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| from extensions.ext_database import db | |||
| from models.model import Conversation | |||
| from models.provider import Provider, ProviderType | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from .entities import ( | |||
| LLMNodeChatModelMessage, | |||
| @@ -267,9 +270,9 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| process_data=process_data, | |||
| outputs=outputs, | |||
| metadata={ | |||
| NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, | |||
| NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, | |||
| NodeRunMetadataKey.CURRENCY: usage.currency, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, | |||
| WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, | |||
| }, | |||
| llm_usage=usage, | |||
| ) | |||
| @@ -302,8 +305,6 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| prompt_messages: Sequence[PromptMessage], | |||
| stop: Optional[Sequence[str]] = None, | |||
| ) -> Generator[NodeEvent, None, None]: | |||
| db.session.close() | |||
| invoke_result = model_instance.invoke_llm( | |||
| prompt_messages=list(prompt_messages), | |||
| model_parameters=node_data_model.completion_params, | |||
| @@ -474,7 +475,7 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) | |||
| elif isinstance(context_value_variable, ArraySegment): | |||
| context_str = "" | |||
| original_retriever_resource = [] | |||
| original_retriever_resource: list[RetrievalSourceMetadata] = [] | |||
| for item in context_value_variable.value: | |||
| if isinstance(item, str): | |||
| context_str += item + "\n" | |||
| @@ -492,7 +493,7 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| retriever_resources=original_retriever_resource, context=context_str.strip() | |||
| ) | |||
| def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: | |||
| def _convert_to_original_retriever_resource(self, context_dict: dict): | |||
| if ( | |||
| "metadata" in context_dict | |||
| and "_source" in context_dict["metadata"] | |||
| @@ -500,24 +501,24 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| ): | |||
| metadata = context_dict.get("metadata", {}) | |||
| source = { | |||
| "position": metadata.get("position"), | |||
| "dataset_id": metadata.get("dataset_id"), | |||
| "dataset_name": metadata.get("dataset_name"), | |||
| "document_id": metadata.get("document_id"), | |||
| "document_name": metadata.get("document_name"), | |||
| "data_source_type": metadata.get("data_source_type"), | |||
| "segment_id": metadata.get("segment_id"), | |||
| "retriever_from": metadata.get("retriever_from"), | |||
| "score": metadata.get("score"), | |||
| "hit_count": metadata.get("segment_hit_count"), | |||
| "word_count": metadata.get("segment_word_count"), | |||
| "segment_position": metadata.get("segment_position"), | |||
| "index_node_hash": metadata.get("segment_index_node_hash"), | |||
| "content": context_dict.get("content"), | |||
| "page": metadata.get("page"), | |||
| "doc_metadata": metadata.get("doc_metadata"), | |||
| } | |||
| source = RetrievalSourceMetadata( | |||
| position=metadata.get("position"), | |||
| dataset_id=metadata.get("dataset_id"), | |||
| dataset_name=metadata.get("dataset_name"), | |||
| document_id=metadata.get("document_id"), | |||
| document_name=metadata.get("document_name"), | |||
| data_source_type=metadata.get("data_source_type"), | |||
| segment_id=metadata.get("segment_id"), | |||
| retriever_from=metadata.get("retriever_from"), | |||
| score=metadata.get("score"), | |||
| hit_count=metadata.get("segment_hit_count"), | |||
| word_count=metadata.get("segment_word_count"), | |||
| segment_position=metadata.get("segment_position"), | |||
| index_node_hash=metadata.get("segment_index_node_hash"), | |||
| content=context_dict.get("content"), | |||
| page=metadata.get("page"), | |||
| doc_metadata=metadata.get("doc_metadata"), | |||
| ) | |||
| return source | |||
| @@ -602,15 +603,11 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| return None | |||
| conversation_id = conversation_id_variable.value | |||
| # get conversation | |||
| conversation = ( | |||
| db.session.query(Conversation) | |||
| .filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id) | |||
| .first() | |||
| ) | |||
| if not conversation: | |||
| return None | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id) | |||
| conversation = session.scalar(stmt) | |||
| if not conversation: | |||
| return None | |||
| memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) | |||
| @@ -846,20 +843,24 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| used_quota = 1 | |||
| if used_quota is not None and system_configuration.current_quota_type is not None: | |||
| db.session.query(Provider).filter( | |||
| Provider.tenant_id == tenant_id, | |||
| # TODO: Use provider name with prefix after the data migration. | |||
| Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, | |||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||
| Provider.quota_type == system_configuration.current_quota_type.value, | |||
| Provider.quota_limit > Provider.quota_used, | |||
| ).update( | |||
| { | |||
| "quota_used": Provider.quota_used + used_quota, | |||
| "last_used": datetime.now(tz=UTC).replace(tzinfo=None), | |||
| } | |||
| ) | |||
| db.session.commit() | |||
| with Session(db.engine) as session: | |||
| stmt = ( | |||
| update(Provider) | |||
| .where( | |||
| Provider.tenant_id == tenant_id, | |||
| # TODO: Use provider name with prefix after the data migration. | |||
| Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, | |||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||
| Provider.quota_type == system_configuration.current_quota_type.value, | |||
| Provider.quota_limit > Provider.quota_used, | |||
| ) | |||
| .values( | |||
| quota_used=Provider.quota_used + used_quota, | |||
| last_used=datetime.now(tz=UTC).replace(tzinfo=None), | |||
| ) | |||
| ) | |||
| session.execute(stmt) | |||
| session.commit() | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping( | |||
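Editor's note: the quota bookkeeping above likewise becomes a single 2.0-style `update()` executed in its own session instead of mutating through `db.session.query(...).update(...)`. A sketch of that statement shape with a stand-in model; columns beyond those used in the diff are placeholders:

```python
from datetime import UTC, datetime
from sqlalchemy import DateTime, Integer, String, create_engine, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Provider(Base):
    # Minimal stand-in for the real Provider model.
    __tablename__ = "providers"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String)
    quota_limit: Mapped[int] = mapped_column(Integer, default=100)
    quota_used: Mapped[int] = mapped_column(Integer, default=0)
    last_used: Mapped[datetime | None] = mapped_column(DateTime)


engine = create_engine("sqlite:///:memory:")  # placeholder for db.engine
Base.metadata.create_all(engine)


def consume_quota(tenant_id: str, used_quota: int = 1) -> None:
    stmt = (
        update(Provider)
        .where(
            Provider.tenant_id == tenant_id,
            Provider.quota_limit > Provider.quota_used,  # guard against overspend
        )
        .values(
            quota_used=Provider.quota_used + used_quota,
            last_used=datetime.now(tz=UTC).replace(tzinfo=None),
        )
    )
    with Session(engine) as session:
        session.execute(stmt)
        session.commit()
```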
| @@ -1,8 +1,8 @@ | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.loop.entities import LoopEndNodeData | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| class LoopEndNode(BaseNode[LoopEndNodeData]): | |||
| @@ -15,7 +15,8 @@ from core.variables import ( | |||
| SegmentType, | |||
| StringSegment, | |||
| ) | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| BaseGraphEvent, | |||
| BaseNodeEvent, | |||
| @@ -37,7 +38,6 @@ from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.event import NodeEvent, RunCompletedEvent | |||
| from core.workflow.nodes.loop.entities import LoopNodeData | |||
| from core.workflow.utils.condition.processor import ConditionProcessor | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| if TYPE_CHECKING: | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| @@ -187,10 +187,10 @@ class LoopNode(BaseNode[LoopNodeData]): | |||
| outputs=self.node_data.outputs, | |||
| steps=loop_count, | |||
| metadata={ | |||
| NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, | |||
| "completed_reason": "loop_break" if check_break_result else "loop_completed", | |||
| NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, | |||
| NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, | |||
| WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, | |||
| WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, | |||
| }, | |||
| ) | |||
| @@ -198,9 +198,9 @@ class LoopNode(BaseNode[LoopNodeData]): | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| metadata={ | |||
| NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, | |||
| NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, | |||
| NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, | |||
| WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, | |||
| WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, | |||
| }, | |||
| outputs=self.node_data.outputs, | |||
| inputs=inputs, | |||
| @@ -221,8 +221,8 @@ class LoopNode(BaseNode[LoopNodeData]): | |||
| metadata={ | |||
| "total_tokens": graph_engine.graph_runtime_state.total_tokens, | |||
| "completed_reason": "error", | |||
| NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, | |||
| NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, | |||
| WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, | |||
| WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, | |||
| }, | |||
| error=str(e), | |||
| ) | |||
| @@ -232,9 +232,9 @@ class LoopNode(BaseNode[LoopNodeData]): | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=str(e), | |||
| metadata={ | |||
| NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, | |||
| NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, | |||
| NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, | |||
| WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, | |||
| WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, | |||
| }, | |||
| ) | |||
| ) | |||
| @@ -322,7 +322,9 @@ class LoopNode(BaseNode[LoopNodeData]): | |||
| inputs=inputs, | |||
| steps=current_index, | |||
| metadata={ | |||
| NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: ( | |||
| graph_engine.graph_runtime_state.total_tokens | |||
| ), | |||
| "completed_reason": "error", | |||
| }, | |||
| error=event.error, | |||
| @@ -331,7 +333,11 @@ class LoopNode(BaseNode[LoopNodeData]): | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=event.error, | |||
| metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens}, | |||
| metadata={ | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: ( | |||
| graph_engine.graph_runtime_state.total_tokens | |||
| ) | |||
| }, | |||
| ) | |||
| ) | |||
| return {"check_break_result": True} | |||
| @@ -347,7 +353,7 @@ class LoopNode(BaseNode[LoopNodeData]): | |||
| inputs=inputs, | |||
| steps=current_index, | |||
| metadata={ | |||
| NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, | |||
| "completed_reason": "error", | |||
| }, | |||
| error=event.error, | |||
| @@ -356,7 +362,9 @@ class LoopNode(BaseNode[LoopNodeData]): | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=event.error, | |||
| metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens}, | |||
| metadata={ | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens | |||
| }, | |||
| ) | |||
| ) | |||
| return {"check_break_result": True} | |||
| @@ -411,11 +419,11 @@ class LoopNode(BaseNode[LoopNodeData]): | |||
| metadata = event.route_node_state.node_run_result.metadata | |||
| if not metadata: | |||
| metadata = {} | |||
| if NodeRunMetadataKey.LOOP_ID not in metadata: | |||
| if WorkflowNodeExecutionMetadataKey.LOOP_ID not in metadata: | |||
| metadata = { | |||
| **metadata, | |||
| NodeRunMetadataKey.LOOP_ID: self.node_id, | |||
| NodeRunMetadataKey.LOOP_INDEX: iter_run_index, | |||
| WorkflowNodeExecutionMetadataKey.LOOP_ID: self.node_id, | |||
| WorkflowNodeExecutionMetadataKey.LOOP_INDEX: iter_run_index, | |||
| } | |||
| event.route_node_state.node_run_result.metadata = metadata | |||
| return event | |||
| @@ -1,8 +1,8 @@ | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.loop.entities import LoopStartNodeData | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| class LoopStartNode(BaseNode[LoopStartNodeData]): | |||
| @@ -25,13 +25,12 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform | |||
| from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate | |||
| from core.prompt.simple_prompt_transform import ModelMode | |||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.llm import LLMNode, ModelConfig | |||
| from core.workflow.utils import variable_template_parser | |||
| from extensions.ext_database import db | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from .entities import ParameterExtractorNodeData | |||
| from .exc import ( | |||
| @@ -244,9 +243,9 @@ class ParameterExtractorNode(LLMNode): | |||
| process_data=process_data, | |||
| outputs={"__is_success": 1 if not error else 0, "__reason": error, **result}, | |||
| metadata={ | |||
| NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, | |||
| NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, | |||
| NodeRunMetadataKey.CURRENCY: usage.currency, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, | |||
| WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, | |||
| }, | |||
| llm_usage=usage, | |||
| ) | |||
| @@ -259,8 +258,6 @@ class ParameterExtractorNode(LLMNode): | |||
| tools: list[PromptMessageTool], | |||
| stop: list[str], | |||
| ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: | |||
| db.session.close() | |||
| invoke_result = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| model_parameters=node_data_model.completion_params, | |||
| @@ -816,7 +813,6 @@ class ParameterExtractorNode(LLMNode): | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| # FIXME: fix the type error later | |||
| variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} | |||
| if node_data.instruction: | |||
| @@ -10,7 +10,8 @@ from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.prompt.advanced_prompt_transform import AdvancedPromptTransform | |||
| from core.prompt.simple_prompt_transform import ModelMode | |||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.event import ModelInvokeCompletedEvent | |||
| from core.workflow.nodes.llm import ( | |||
| @@ -20,7 +21,6 @@ from core.workflow.nodes.llm import ( | |||
| ) | |||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| from libs.json_in_md_parser import parse_and_check_json_markdown | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from .entities import QuestionClassifierNodeData | |||
| from .exc import InvalidModelTypeError | |||
| @@ -79,9 +79,13 @@ class QuestionClassifierNode(LLMNode): | |||
| memory=memory, | |||
| max_token_limit=rest_token, | |||
| ) | |||
| # Some models (e.g. Gemma, Mistral) require strictly alternating roles (user/assistant/user/assistant...). | |||
| # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt, | |||
| # two consecutive user prompts will be generated, causing a model error. | |||
| # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end. | |||
| prompt_messages, stop = self._fetch_prompt_messages( | |||
| prompt_template=prompt_template, | |||
| sys_query=query, | |||
| sys_query="", | |||
| memory=memory, | |||
| model_config=model_config, | |||
| sys_files=files, | |||
| @@ -142,9 +146,9 @@ class QuestionClassifierNode(LLMNode): | |||
| outputs=outputs, | |||
| edge_source_handle=category_id, | |||
| metadata={ | |||
| NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, | |||
| NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, | |||
| NodeRunMetadataKey.CURRENCY: usage.currency, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, | |||
| WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, | |||
| }, | |||
| llm_usage=usage, | |||
| ) | |||
| @@ -154,9 +158,9 @@ class QuestionClassifierNode(LLMNode): | |||
| inputs=variables, | |||
| error=str(e), | |||
| metadata={ | |||
| NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, | |||
| NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, | |||
| NodeRunMetadataKey.CURRENCY: usage.currency, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, | |||
| WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, | |||
| }, | |||
| llm_usage=usage, | |||
| ) | |||
| @@ -1,9 +1,9 @@ | |||
| from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.start.entities import StartNodeData | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| class StartNode(BaseNode[StartNodeData]): | |||
| @@ -4,10 +4,10 @@ from typing import Any, Optional | |||
| from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) | |||
| @@ -14,8 +14,9 @@ from core.tools.tool_engine import ToolEngine | |||
| from core.tools.utils.message_transformer import ToolFileMessageTransformer | |||
| from core.variables.segments import ArrayAnySegment | |||
| from core.variables.variables import ArrayAnyVariable | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.event import AgentLogEvent | |||
| from core.workflow.nodes.base import BaseNode | |||
| @@ -25,7 +26,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| from models import ToolFile | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from services.tools.builtin_tools_manage_service import BuiltinToolManageService | |||
| from .entities import ToolNodeData | |||
| @@ -70,7 +70,7 @@ class ToolNode(BaseNode[ToolNodeData]): | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs={}, | |||
| metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, | |||
| metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, | |||
| error=f"Failed to get tool runtime: {str(e)}", | |||
| error_type=type(e).__name__, | |||
| ) | |||
| @@ -110,7 +110,7 @@ class ToolNode(BaseNode[ToolNodeData]): | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs=parameters_for_log, | |||
| metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, | |||
| metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, | |||
| error=f"Failed to invoke tool: {str(e)}", | |||
| error_type=type(e).__name__, | |||
| ) | |||
| @@ -125,7 +125,7 @@ class ToolNode(BaseNode[ToolNodeData]): | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs=parameters_for_log, | |||
| metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, | |||
| metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, | |||
| error=f"Failed to transform tool message: {str(e)}", | |||
| error_type=type(e).__name__, | |||
| ) | |||
| @@ -201,7 +201,7 @@ class ToolNode(BaseNode[ToolNodeData]): | |||
| json: list[dict] = [] | |||
| agent_logs: list[AgentLogEvent] = [] | |||
| agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {} | |||
| agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} | |||
| variables: dict[str, Any] = {} | |||
| @@ -274,7 +274,7 @@ class ToolNode(BaseNode[ToolNodeData]): | |||
| agent_execution_metadata = { | |||
| key: value | |||
| for key, value in msg_metadata.items() | |||
| if key in NodeRunMetadataKey.__members__.values() | |||
| if key in WorkflowNodeExecutionMetadataKey.__members__.values() | |||
| } | |||
| json.append(message.message.json_object) | |||
| elif message.type == ToolInvokeMessage.MessageType.LINK: | |||
| @@ -366,8 +366,8 @@ class ToolNode(BaseNode[ToolNodeData]): | |||
| outputs={"text": text, "files": files, "json": json, **variables}, | |||
| metadata={ | |||
| **agent_execution_metadata, | |||
| NodeRunMetadataKey.TOOL_INFO: tool_info, | |||
| NodeRunMetadataKey.AGENT_LOG: agent_logs, | |||
| WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, | |||
| WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, | |||
| }, | |||
| inputs=parameters_for_log, | |||
| ) | |||
| @@ -1,7 +1,8 @@ | |||
| from typing import Literal, Optional | |||
| from typing import Optional | |||
| from pydantic import BaseModel | |||
| from core.variables.types import SegmentType | |||
| from core.workflow.nodes.base import BaseNodeData | |||
| @@ -17,7 +18,7 @@ class AdvancedSettings(BaseModel): | |||
| Group. | |||
| """ | |||
| output_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] | |||
| output_type: SegmentType | |||
| variables: list[list[str]] | |||
| group_name: str | |||
| @@ -1,8 +1,8 @@ | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): | |||
| @@ -1,11 +1,11 @@ | |||
| from core.variables import SegmentType, Variable | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.variable_assigner.common import helpers as common_helpers | |||
| from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError | |||
| from factories import variable_factory | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from .node_data import VariableAssignerData, WriteMode | |||
| @@ -6,11 +6,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.variables import SegmentType, Variable | |||
| from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.variable_assigner.common import helpers as common_helpers | |||
| from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from . import helpers | |||
| from .constants import EMPTY_VALUE_MAPPING | |||
| @@ -6,7 +6,7 @@ for accessing and manipulating data, regardless of the underlying | |||
| storage mechanism. | |||
| """ | |||
| from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository | |||
| from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository | |||
| __all__ = [ | |||
| "OrderConfig", | |||
| @@ -1,6 +1,6 @@ | |||
| from typing import Optional, Protocol | |||
| from core.workflow.entities.workflow_execution_entities import WorkflowExecution | |||
| from core.workflow.entities.workflow_execution import WorkflowExecution | |||
| class WorkflowExecutionRepository(Protocol): | |||
| @@ -2,7 +2,7 @@ from collections.abc import Sequence | |||
| from dataclasses import dataclass | |||
| from typing import Literal, Optional, Protocol | |||
| from core.workflow.entities.node_execution_entities import NodeExecution | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution | |||
| @dataclass | |||
| @@ -26,7 +26,7 @@ class WorkflowNodeExecutionRepository(Protocol): | |||
| application domains or deployment scenarios. | |||
| """ | |||
| def save(self, execution: NodeExecution) -> None: | |||
| def save(self, execution: WorkflowNodeExecution) -> None: | |||
| """ | |||
| Save or update a NodeExecution instance. | |||
| @@ -39,7 +39,7 @@ class WorkflowNodeExecutionRepository(Protocol): | |||
| """ | |||
| ... | |||
| def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]: | |||
| def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: | |||
| """ | |||
| Retrieve a NodeExecution by its node_execution_id. | |||
| @@ -55,7 +55,7 @@ class WorkflowNodeExecutionRepository(Protocol): | |||
| self, | |||
| workflow_run_id: str, | |||
| order_config: Optional[OrderConfig] = None, | |||
| ) -> Sequence[NodeExecution]: | |||
| ) -> Sequence[WorkflowNodeExecution]: | |||
| """ | |||
| Retrieve all NodeExecution instances for a specific workflow run. | |||
| @@ -70,7 +70,7 @@ class WorkflowNodeExecutionRepository(Protocol): | |||
| """ | |||
| ... | |||
| def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]: | |||
| def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: | |||
| """ | |||
| Retrieve all running NodeExecution instances for a specific workflow run. | |||
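Since the repository is declared as a `typing.Protocol`, any object with structurally matching methods satisfies it, no inheritance needed. A rough in-memory sketch under the assumption that `WorkflowNodeExecution` exposes `node_execution_id`, `workflow_execution_id`, and `status` attributes as used in the hunks above (a test double, not the SQLAlchemy-backed implementation):

```python
from collections.abc import Sequence
from typing import Optional

from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
)
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig


class InMemoryWorkflowNodeExecutionRepository:
    """Toy structural implementation for tests."""

    def __init__(self) -> None:
        self._by_node_execution_id: dict[str, WorkflowNodeExecution] = {}

    def save(self, execution: WorkflowNodeExecution) -> None:
        # Keyed by node_execution_id so the lookups below stay O(1).
        if execution.node_execution_id:
            self._by_node_execution_id[execution.node_execution_id] = execution

    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
        return self._by_node_execution_id.get(node_execution_id)

    def get_by_workflow_run(
        self, workflow_run_id: str, order_config: Optional[OrderConfig] = None
    ) -> Sequence[WorkflowNodeExecution]:
        return [e for e in self._by_node_execution_id.values() if e.workflow_execution_id == workflow_run_id]

    def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
        return [
            e
            for e in self._by_node_execution_id.values()
            if e.workflow_execution_id == workflow_run_id and e.status == WorkflowNodeExecutionStatus.RUNNING
        ]
```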
| @@ -1,11 +1,9 @@ | |||
| from collections.abc import Mapping | |||
| from dataclasses import dataclass | |||
| from datetime import UTC, datetime | |||
| from typing import Any, Optional, Union | |||
| from uuid import uuid4 | |||
| from sqlalchemy import func, select | |||
| from sqlalchemy.orm import Session | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity | |||
| from core.app.entities.queue_entities import ( | |||
| QueueNodeExceptionEvent, | |||
| @@ -19,21 +17,24 @@ from core.app.entities.queue_entities import ( | |||
| from core.app.task_pipeline.exc import WorkflowRunNotFoundError | |||
| from core.ops.entities.trace_entity import TraceTaskName | |||
| from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.entities.node_execution_entities import ( | |||
| NodeExecution, | |||
| NodeExecutionStatus, | |||
| from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType | |||
| from core.workflow.entities.workflow_node_execution import ( | |||
| WorkflowNodeExecution, | |||
| WorkflowNodeExecutionMetadataKey, | |||
| WorkflowNodeExecutionStatus, | |||
| ) | |||
| from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus, WorkflowType | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from models import ( | |||
| Workflow, | |||
| WorkflowRun, | |||
| WorkflowRunStatus, | |||
| ) | |||
| @dataclass | |||
| class CycleManagerWorkflowInfo: | |||
| workflow_id: str | |||
| workflow_type: WorkflowType | |||
| version: str | |||
| graph_data: Mapping[str, Any] | |||
| class WorkflowCycleManager: | |||
| @@ -42,32 +43,17 @@ class WorkflowCycleManager: | |||
| *, | |||
| application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], | |||
| workflow_system_variables: dict[SystemVariableKey, Any], | |||
| workflow_info: CycleManagerWorkflowInfo, | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| ) -> None: | |||
| self._application_generate_entity = application_generate_entity | |||
| self._workflow_system_variables = workflow_system_variables | |||
| self._workflow_info = workflow_info | |||
| self._workflow_execution_repository = workflow_execution_repository | |||
| self._workflow_node_execution_repository = workflow_node_execution_repository | |||
| def handle_workflow_run_start( | |||
| self, | |||
| *, | |||
| session: Session, | |||
| workflow_id: str, | |||
| ) -> WorkflowExecution: | |||
| workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) | |||
| workflow = session.scalar(workflow_stmt) | |||
| if not workflow: | |||
| raise ValueError(f"Workflow not found: {workflow_id}") | |||
| max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where( | |||
| WorkflowRun.tenant_id == workflow.tenant_id, | |||
| WorkflowRun.app_id == workflow.app_id, | |||
| ) | |||
| max_sequence = session.scalar(max_sequence_stmt) or 0 | |||
| new_sequence_number = max_sequence + 1 | |||
| def handle_workflow_run_start(self) -> WorkflowExecution: | |||
| inputs = {**self._application_generate_entity.inputs} | |||
| for key, value in (self._workflow_system_variables or {}).items(): | |||
| if key.value == "conversation": | |||
| @@ -79,14 +65,13 @@ class WorkflowCycleManager: | |||
| # init workflow run | |||
| # TODO: This workflow_run_id should never be None; consider a more elegant way to handle this | |||
| execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4()) | |||
| execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_EXECUTION_ID) or uuid4()) | |||
| execution = WorkflowExecution.new( | |||
| id=execution_id, | |||
| workflow_id=workflow.id, | |||
| sequence_number=new_sequence_number, | |||
| type=WorkflowType(workflow.type), | |||
| workflow_version=workflow.version, | |||
| graph=workflow.graph_dict, | |||
| id_=execution_id, | |||
| workflow_id=self._workflow_info.workflow_id, | |||
| workflow_type=self._workflow_info.workflow_type, | |||
| workflow_version=self._workflow_info.version, | |||
| graph=self._workflow_info.graph_data, | |||
| inputs=inputs, | |||
| started_at=datetime.now(UTC).replace(tzinfo=None), | |||
| ) | |||
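With this change the caller resolves the `Workflow` row once and passes a plain `CycleManagerWorkflowInfo`, so `handle_workflow_run_start()` no longer needs a session or a workflow id. A hedged wiring sketch; the `workflow`, repository, and generate-entity variables are placeholders supplied by the caller:

```python
from core.workflow.entities.workflow_execution import WorkflowType

# Resolve the workflow once (e.g. in the app runner) and pass only plain data in.
workflow_info = CycleManagerWorkflowInfo(
    workflow_id=workflow.id,
    workflow_type=WorkflowType(workflow.type),
    version=workflow.version,
    graph_data=workflow.graph_dict,
)

cycle_manager = WorkflowCycleManager(
    application_generate_entity=application_generate_entity,
    workflow_system_variables=system_variables,  # includes SystemVariableKey.WORKFLOW_EXECUTION_ID
    workflow_info=workflow_info,
    workflow_execution_repository=workflow_execution_repository,
    workflow_node_execution_repository=workflow_node_execution_repository,
)

# No session or workflow_id argument any more: the DB lookup moved to the caller.
execution = cycle_manager.handle_workflow_run_start()
```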
| @@ -168,7 +153,7 @@ class WorkflowCycleManager: | |||
| workflow_run_id: str, | |||
| total_tokens: int, | |||
| total_steps: int, | |||
| status: WorkflowRunStatus, | |||
| status: WorkflowExecutionStatus, | |||
| error_message: str, | |||
| conversation_id: Optional[str] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| @@ -185,7 +170,7 @@ class WorkflowCycleManager: | |||
| # Use the instance repository to find running executions for a workflow run | |||
| running_node_executions = self._workflow_node_execution_repository.get_running_executions( | |||
| workflow_run_id=workflow_execution.id | |||
| workflow_run_id=workflow_execution.id_ | |||
| ) | |||
| # Update the domain models | |||
| @@ -193,7 +178,7 @@ class WorkflowCycleManager: | |||
| for node_execution in running_node_executions: | |||
| if node_execution.node_execution_id: | |||
| # Update the domain model | |||
| node_execution.status = NodeExecutionStatus.FAILED | |||
| node_execution.status = WorkflowNodeExecutionStatus.FAILED | |||
| node_execution.error = error_message | |||
| node_execution.finished_at = now | |||
| node_execution.elapsed_time = (now - node_execution.created_at).total_seconds() | |||
| @@ -219,28 +204,28 @@ class WorkflowCycleManager: | |||
| *, | |||
| workflow_execution_id: str, | |||
| event: QueueNodeStartedEvent, | |||
| ) -> NodeExecution: | |||
| ) -> WorkflowNodeExecution: | |||
| workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) | |||
| # Create a domain model | |||
| created_at = datetime.now(UTC).replace(tzinfo=None) | |||
| metadata = { | |||
| NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, | |||
| NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, | |||
| NodeRunMetadataKey.LOOP_ID: event.in_loop_id, | |||
| WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, | |||
| WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, | |||
| WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id, | |||
| } | |||
| domain_execution = NodeExecution( | |||
| domain_execution = WorkflowNodeExecution( | |||
| id=str(uuid4()), | |||
| workflow_id=workflow_execution.workflow_id, | |||
| workflow_run_id=workflow_execution.id, | |||
| workflow_execution_id=workflow_execution.id_, | |||
| predecessor_node_id=event.predecessor_node_id, | |||
| index=event.node_run_index, | |||
| node_execution_id=event.node_execution_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| title=event.node_data.title, | |||
| status=NodeExecutionStatus.RUNNING, | |||
| status=WorkflowNodeExecutionStatus.RUNNING, | |||
| metadata=metadata, | |||
| created_at=created_at, | |||
| ) | |||
| @@ -250,7 +235,7 @@ class WorkflowCycleManager: | |||
| return domain_execution | |||
| def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution: | |||
| def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: | |||
| # Get the domain model from repository | |||
| domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) | |||
| if not domain_execution: | |||
| @@ -271,7 +256,7 @@ class WorkflowCycleManager: | |||
| elapsed_time = (finished_at - event.start_at).total_seconds() | |||
| # Update domain model | |||
| domain_execution.status = NodeExecutionStatus.SUCCEEDED | |||
| domain_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED | |||
| domain_execution.update_from_mapping( | |||
| inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict | |||
| ) | |||
| @@ -290,7 +275,7 @@ class WorkflowCycleManager: | |||
| | QueueNodeInIterationFailedEvent | |||
| | QueueNodeInLoopFailedEvent | |||
| | QueueNodeExceptionEvent, | |||
| ) -> NodeExecution: | |||
| ) -> WorkflowNodeExecution: | |||
| """ | |||
| Workflow node execution failed | |||
| :param event: queue node failed event | |||
| @@ -317,9 +302,9 @@ class WorkflowCycleManager: | |||
| # Update domain model | |||
| domain_execution.status = ( | |||
| NodeExecutionStatus.FAILED | |||
| WorkflowNodeExecutionStatus.FAILED | |||
| if not isinstance(event, QueueNodeExceptionEvent) | |||
| else NodeExecutionStatus.EXCEPTION | |||
| else WorkflowNodeExecutionStatus.EXCEPTION | |||
| ) | |||
| domain_execution.error = event.error | |||
| domain_execution.update_from_mapping( | |||
| @@ -335,7 +320,7 @@ class WorkflowCycleManager: | |||
| def handle_workflow_node_execution_retried( | |||
| self, *, workflow_execution_id: str, event: QueueNodeRetryEvent | |||
| ) -> NodeExecution: | |||
| ) -> WorkflowNodeExecution: | |||
| workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) | |||
| created_at = event.start_at | |||
| finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| @@ -345,13 +330,13 @@ class WorkflowCycleManager: | |||
| # Seed the origin metadata with the iteration, parallel-mode run and loop identifiers from the event | |||
| origin_metadata = { | |||
| NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, | |||
| NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, | |||
| NodeRunMetadataKey.LOOP_ID: event.in_loop_id, | |||
| WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, | |||
| WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, | |||
| WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id, | |||
| } | |||
| # Copy the execution metadata from the event, keyed by WorkflowNodeExecutionMetadataKey | |||
| execution_metadata_dict: dict[NodeRunMetadataKey, str | None] = {} | |||
| execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {} | |||
| if event.execution_metadata: | |||
| for key, value in event.execution_metadata.items(): | |||
| execution_metadata_dict[key] = value | |||
| @@ -359,16 +344,16 @@ class WorkflowCycleManager: | |||
| merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata | |||
| # Create a domain model | |||
| domain_execution = NodeExecution( | |||
| domain_execution = WorkflowNodeExecution( | |||
| id=str(uuid4()), | |||
| workflow_id=workflow_execution.workflow_id, | |||
| workflow_run_id=workflow_execution.id, | |||
| workflow_execution_id=workflow_execution.id_, | |||
| predecessor_node_id=event.predecessor_node_id, | |||
| node_execution_id=event.node_execution_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| title=event.node_data.title, | |||
| status=NodeExecutionStatus.RETRY, | |||
| status=WorkflowNodeExecutionStatus.RETRY, | |||
| created_at=created_at, | |||
| finished_at=finished_at, | |||
| elapsed_time=elapsed_time, | |||
| @@ -93,8 +93,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen | |||
| raise VariableError("missing value type") | |||
| if (value := mapping.get("value")) is None: | |||
| raise VariableError("missing value") | |||
| # FIXME: using Any here, fix it later | |||
| result: Any | |||
| result: Variable | |||
| match value_type: | |||
| case SegmentType.STRING: | |||
| result = StringVariable.model_validate(mapping) | |||
| @@ -28,7 +28,8 @@ class SMTPClient: | |||
| else: | |||
| smtp = smtplib.SMTP(self.server, self.port, timeout=10) | |||
| if self.username and self.password: | |||
| # Only authenticate when both username and password are non-empty after stripping whitespace | |||
| if self.username and self.password and self.username.strip() and self.password.strip(): | |||
| smtp.login(self.username, self.password) | |||
| msg = MIMEMultipart() | |||
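The tightened guard avoids calling `smtp.login()` with blank or whitespace-only credentials, which some servers reject outright. A tiny self-contained illustration of the condition with dummy values:

```python
def should_authenticate(username: str | None, password: str | None) -> bool:
    # Authenticate only when both values are present and non-blank after stripping.
    return bool(username and password and username.strip() and password.strip())


assert should_authenticate("mailer", "s3cret") is True
assert should_authenticate("mailer", "   ") is False  # blank password: skip login
assert should_authenticate(None, "s3cret") is False
```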
| @@ -85,11 +85,9 @@ from .workflow import ( | |||
| Workflow, | |||
| WorkflowAppLog, | |||
| WorkflowAppLogCreatedFrom, | |||
| WorkflowNodeExecution, | |||
| WorkflowNodeExecutionStatus, | |||
| WorkflowNodeExecutionModel, | |||
| WorkflowNodeExecutionTriggeredFrom, | |||
| WorkflowRun, | |||
| WorkflowRunStatus, | |||
| WorkflowType, | |||
| ) | |||
| @@ -101,14 +99,14 @@ __all__ = [ | |||
| "AccountStatus", | |||
| "ApiRequest", | |||
| "ApiToken", | |||
| "ApiToolProvider", # Added | |||
| "ApiToolProvider", | |||
| "App", | |||
| "AppAnnotationHitHistory", | |||
| "AppAnnotationSetting", | |||
| "AppDatasetJoin", | |||
| "AppMode", | |||
| "AppModelConfig", | |||
| "BuiltinToolProvider", # Added | |||
| "BuiltinToolProvider", | |||
| "CeleryTask", | |||
| "CeleryTaskSet", | |||
| "Conversation", | |||
| @@ -174,11 +172,9 @@ __all__ = [ | |||
| "Workflow", | |||
| "WorkflowAppLog", | |||
| "WorkflowAppLogCreatedFrom", | |||
| "WorkflowNodeExecution", | |||
| "WorkflowNodeExecutionStatus", | |||
| "WorkflowNodeExecutionModel", | |||
| "WorkflowNodeExecutionTriggeredFrom", | |||
| "WorkflowRun", | |||
| "WorkflowRunStatus", | |||
| "WorkflowRunTriggeredFrom", | |||
| "WorkflowToolProvider", | |||
| "WorkflowType", | |||
| @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast | |||
| from core.plugin.entities.plugin import GenericProviderID | |||
| from core.tools.entities.tool_entities import ToolProviderType | |||
| from core.tools.signature import sign_tool_file | |||
| from core.workflow.entities.workflow_execution import WorkflowExecutionStatus | |||
| from services.plugin.plugin_service import PluginService | |||
| if TYPE_CHECKING: | |||
| @@ -31,7 +32,6 @@ from .base import Base | |||
| from .engine import db | |||
| from .enums import CreatorUserRole | |||
| from .types import StringUUID | |||
| from .workflow import WorkflowRunStatus | |||
| if TYPE_CHECKING: | |||
| from .workflow import Workflow | |||
| @@ -795,22 +795,22 @@ class Conversation(Base): | |||
| def status_count(self): | |||
| messages = db.session.query(Message).filter(Message.conversation_id == self.id).all() | |||
| status_counts = { | |||
| WorkflowRunStatus.RUNNING: 0, | |||
| WorkflowRunStatus.SUCCEEDED: 0, | |||
| WorkflowRunStatus.FAILED: 0, | |||
| WorkflowRunStatus.STOPPED: 0, | |||
| WorkflowRunStatus.PARTIAL_SUCCEEDED: 0, | |||
| WorkflowExecutionStatus.RUNNING: 0, | |||
| WorkflowExecutionStatus.SUCCEEDED: 0, | |||
| WorkflowExecutionStatus.FAILED: 0, | |||
| WorkflowExecutionStatus.STOPPED: 0, | |||
| WorkflowExecutionStatus.PARTIAL_SUCCEEDED: 0, | |||
| } | |||
| for message in messages: | |||
| if message.workflow_run: | |||
| status_counts[WorkflowRunStatus(message.workflow_run.status)] += 1 | |||
| status_counts[WorkflowExecutionStatus(message.workflow_run.status)] += 1 | |||
| return ( | |||
| { | |||
| "success": status_counts[WorkflowRunStatus.SUCCEEDED], | |||
| "failed": status_counts[WorkflowRunStatus.FAILED], | |||
| "partial_success": status_counts[WorkflowRunStatus.PARTIAL_SUCCEEDED], | |||
| "success": status_counts[WorkflowExecutionStatus.SUCCEEDED], | |||
| "failed": status_counts[WorkflowExecutionStatus.FAILED], | |||
| "partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED], | |||
| } | |||
| if messages | |||
| else None | |||
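The counting loop above is effectively a `Counter` over each message's run status; a compact equivalent sketch, assuming `WorkflowExecutionStatus` keeps the same string values the removed `WorkflowRunStatus` used:

```python
from collections import Counter

from core.workflow.entities.workflow_execution import WorkflowExecutionStatus

# e.g. message.workflow_run.status values pulled from the DB
statuses = ["succeeded", "failed", "succeeded", "partial-succeeded"]
counts = Counter(WorkflowExecutionStatus(s) for s in statuses)

summary = {
    "success": counts[WorkflowExecutionStatus.SUCCEEDED],
    "failed": counts[WorkflowExecutionStatus.FAILED],
    "partial_success": counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
}
print(summary)  # {'success': 2, 'failed': 1, 'partial_success': 1}
```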
| @@ -401,18 +401,6 @@ class Workflow(Base): | |||
| ) | |||
| class WorkflowRunStatus(StrEnum): | |||
| """ | |||
| Workflow Run Status Enum | |||
| """ | |||
| RUNNING = "running" | |||
| SUCCEEDED = "succeeded" | |||
| FAILED = "failed" | |||
| STOPPED = "stopped" | |||
| PARTIAL_SUCCEEDED = "partial-succeeded" | |||
| class WorkflowRun(Base): | |||
| """ | |||
| Workflow Run | |||
| @@ -473,12 +461,12 @@ class WorkflowRun(Base): | |||
| error: Mapped[Optional[str]] = mapped_column(db.Text) | |||
| elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0")) | |||
| total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) | |||
| total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) | |||
| total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True) | |||
| created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user | |||
| created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) | |||
| exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) | |||
| exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True) | |||
| @property | |||
| def created_by_account(self): | |||
| @@ -578,19 +566,7 @@ class WorkflowNodeExecutionTriggeredFrom(StrEnum): | |||
| RAG_PIPELINE_RUN = "rag-pipeline-run" | |||
| class WorkflowNodeExecutionStatus(StrEnum): | |||
| """ | |||
| Workflow Node Execution Status Enum | |||
| """ | |||
| RUNNING = "running" | |||
| SUCCEEDED = "succeeded" | |||
| FAILED = "failed" | |||
| EXCEPTION = "exception" | |||
| RETRY = "retry" | |||
| class WorkflowNodeExecution(Base): | |||
| class WorkflowNodeExecutionModel(Base): | |||
| """ | |||
| Workflow Node Execution | |||
| @@ -14,7 +14,7 @@ dependencies = [ | |||
| "chardet~=5.1.0", | |||
| "flask~=3.1.0", | |||
| "flask-compress~=1.17", | |||
| "flask-cors~=5.0.0", | |||
| "flask-cors~=6.0.0", | |||
| "flask-login~=0.6.3", | |||
| "flask-migrate~=4.0.7", | |||
| "flask-restful~=0.3.10", | |||
| @@ -36,7 +36,6 @@ dependencies = [ | |||
| "mailchimp-transactional~=1.0.50", | |||
| "markdown~=3.5.1", | |||
| "numpy~=1.26.4", | |||
| "oci~=2.135.1", | |||
| "openai~=1.61.0", | |||
| "openpyxl~=3.1.5", | |||
| "opik~=1.7.25", | |||
| @@ -143,13 +142,16 @@ dev = [ | |||
| "types-requests~=2.32.0", | |||
| "types-requests-oauthlib~=2.0.0", | |||
| "types-shapely~=2.0.0", | |||
| "types-simplejson~=3.20.0", | |||
| "types-six~=1.17.0", | |||
| "types-tensorflow~=2.18.0", | |||
| "types-tqdm~=4.67.0", | |||
| "types-ujson~=5.10.0", | |||
| "types-simplejson>=3.20.0", | |||
| "types-six>=1.17.0", | |||
| "types-tensorflow>=2.18.0", | |||
| "types-tqdm>=4.67.0", | |||
| "types-ujson>=5.10.0", | |||
| "boto3-stubs>=1.38.20", | |||
| "types-jmespath>=1.0.2.20240106", | |||
| "types_pyOpenSSL>=24.1.0", | |||
| "types_cffi>=1.17.0", | |||
| "types_setuptools>=80.9.0", | |||
| ] | |||
| ############################################################ | |||
| @@ -1,5 +1,4 @@ | |||
| [pytest] | |||
| continue-on-collection-errors = true | |||
| addopts = --cov=./api --cov-report=json --cov-report=xml | |||
| env = | |||
| ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz | |||
| @@ -34,9 +34,8 @@ def clean_messages(): | |||
| while True: | |||
| try: | |||
| # Main query with join and filter | |||
| # FIXME:for mypy no paginate method error | |||
| messages = ( | |||
| db.session.query(Message) # type: ignore | |||
| db.session.query(Message) | |||
| .filter(Message.created_at < plan_sandbox_clean_message_day) | |||
| .order_by(Message.created_at.desc()) | |||
| .limit(100) | |||
| @@ -14,7 +14,7 @@ from extensions.ext_database import db | |||
| from extensions.ext_storage import storage | |||
| from models.account import Tenant | |||
| from models.model import App, Conversation, Message | |||
| from models.workflow import WorkflowNodeExecution, WorkflowRun | |||
| from models.workflow import WorkflowNodeExecutionModel, WorkflowRun | |||
| from services.billing_service import BillingService | |||
| logger = logging.getLogger(__name__) | |||
| @@ -108,10 +108,11 @@ class ClearFreePlanTenantExpiredLogs: | |||
| while True: | |||
| with Session(db.engine).no_autoflush as session: | |||
| workflow_node_executions = ( | |||
| session.query(WorkflowNodeExecution) | |||
| session.query(WorkflowNodeExecutionModel) | |||
| .filter( | |||
| WorkflowNodeExecution.tenant_id == tenant_id, | |||
| WorkflowNodeExecution.created_at < datetime.datetime.now() - datetime.timedelta(days=days), | |||
| WorkflowNodeExecutionModel.tenant_id == tenant_id, | |||
| WorkflowNodeExecutionModel.created_at | |||
| < datetime.datetime.now() - datetime.timedelta(days=days), | |||
| ) | |||
| .limit(batch) | |||
| .all() | |||
| @@ -135,8 +136,8 @@ class ClearFreePlanTenantExpiredLogs: | |||
| ] | |||
| # delete workflow node executions | |||
| session.query(WorkflowNodeExecution).filter( | |||
| WorkflowNodeExecution.id.in_(workflow_node_execution_ids), | |||
| session.query(WorkflowNodeExecutionModel).filter( | |||
| WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids), | |||
| ).delete(synchronize_session=False) | |||
| session.commit() | |||
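The batch delete passes `synchronize_session=False`, which issues the DELETE directly and skips reconciling in-memory Session state; committing right afterwards keeps the session from going stale. A hedged, generic sketch of the same pattern (the helper name is illustrative):

```python
from sqlalchemy.orm import Session

from models.workflow import WorkflowNodeExecutionModel


def delete_node_execution_batch(session: Session, ids: list[str]) -> int:
    # Bulk DELETE issued against the table; loaded objects are not expired,
    # so commit (or expire_all) immediately afterwards to avoid stale state.
    deleted = (
        session.query(WorkflowNodeExecutionModel)
        .filter(WorkflowNodeExecutionModel.id.in_(ids))
        .delete(synchronize_session=False)
    )
    session.commit()
    return deleted
```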
| @@ -2,8 +2,11 @@ import logging | |||
| import time | |||
| from typing import Any | |||
| from core.app.app_config.entities import ModelConfig | |||
| from core.model_runtime.entities import LLMMode | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.models.document import Document | |||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| @@ -34,7 +37,29 @@ class HitTestingService: | |||
| # get the retrieval model; if it is not set, use the default | |||
| if not retrieval_model: | |||
| retrieval_model = dataset.retrieval_model or default_retrieval_model | |||
| document_ids_filter = None | |||
| metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {}) | |||
| if metadata_filtering_conditions: | |||
| dataset_retrieval = DatasetRetrieval() | |||
| from core.app.app_config.entities import MetadataFilteringCondition | |||
| metadata_filtering_conditions = MetadataFilteringCondition(**metadata_filtering_conditions) | |||
| metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition( | |||
| dataset_ids=[dataset.id], | |||
| query=query, | |||
| metadata_filtering_mode="manual", | |||
| metadata_filtering_conditions=metadata_filtering_conditions, | |||
| inputs={}, | |||
| tenant_id="", | |||
| user_id="", | |||
| metadata_model_config=ModelConfig(provider="", name="", mode=LLMMode.CHAT, completion_params={}), | |||
| ) | |||
| if metadata_filter_document_ids: | |||
| document_ids_filter = metadata_filter_document_ids.get(dataset.id, []) | |||
| if metadata_condition and not document_ids_filter: | |||
| return cls.compact_retrieve_response(query, []) | |||
| all_documents = RetrievalService.retrieve( | |||
| retrieval_method=retrieval_model.get("search_method", "semantic_search"), | |||
| dataset_id=dataset.id, | |||
| @@ -48,6 +73,7 @@ class HitTestingService: | |||
| else None, | |||
| reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", | |||
| weights=retrieval_model.get("weights", None), | |||
| document_ids_filter=document_ids_filter, | |||
| ) | |||
| end = time.perf_counter() | |||
| @@ -99,7 +125,7 @@ class HitTestingService: | |||
| return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) | |||
| @classmethod | |||
| def compact_retrieve_response(cls, query: str, documents: list[Document]): | |||
| def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]: | |||
| records = RetrievalService.format_retrieval_documents(documents) | |||
| return { | |||
| @@ -1,5 +1,6 @@ | |||
| from typing import Optional | |||
| from typing import Any, Optional | |||
| from core.ops.entities.config_entity import BaseTracingConfig | |||
| from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map | |||
| from extensions.ext_database import db | |||
| from models.model import App, TraceAppConfig | |||
| @@ -92,13 +93,12 @@ class OpsService: | |||
| except KeyError: | |||
| return {"error": f"Invalid tracing provider: {tracing_provider}"} | |||
| config_class, other_keys = ( | |||
| provider_config_map[tracing_provider]["config_class"], | |||
| provider_config_map[tracing_provider]["other_keys"], | |||
| ) | |||
| # FIXME: ignore type error | |||
| default_config_instance = config_class(**tracing_config) # type: ignore | |||
| for key in other_keys: # type: ignore | |||
| provider_config: dict[str, Any] = provider_config_map[tracing_provider] | |||
| config_class: type[BaseTracingConfig] = provider_config["config_class"] | |||
| other_keys: list[str] = provider_config["other_keys"] | |||
| default_config_instance: BaseTracingConfig = config_class(**tracing_config) | |||
| for key in other_keys: | |||
| if key in tracing_config and tracing_config[key] == "": | |||
| tracing_config[key] = getattr(default_config_instance, key, None) | |||
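The rewritten block trades the `# type: ignore` comments for explicit annotations, but the flow is unchanged: look up the provider entry, instantiate its config class from the stored settings, then backfill any keys the user left empty from the resulting defaults. A standalone sketch of that pattern with a made-up provider and config class (the blank-to-default behaviour of the real `BaseTracingConfig` subclasses is assumed):

```python
from pydantic import BaseModel, ValidationInfo, field_validator


class DummyTracingConfig(BaseModel):
    """Hypothetical stand-in for a BaseTracingConfig subclass."""

    endpoint: str = "https://tracing.example.com"
    project: str = "default"

    @field_validator("endpoint", "project", mode="before")
    @classmethod
    def _blank_means_default(cls, value: str, info: ValidationInfo):
        # Assumed behaviour: an empty string falls back to the field default.
        return value or cls.model_fields[info.field_name].default


provider_config_map = {"dummy": {"config_class": DummyTracingConfig, "other_keys": ["endpoint", "project"]}}

tracing_provider = "dummy"
tracing_config = {"endpoint": "", "project": "my-app"}

provider_config = provider_config_map[tracing_provider]
config_class = provider_config["config_class"]
other_keys = provider_config["other_keys"]

default_config_instance = config_class(**tracing_config)  # validates and fills defaults
for key in other_keys:
    if key in tracing_config and tracing_config[key] == "":
        tracing_config[key] = getattr(default_config_instance, key, None)

print(tracing_config)  # {'endpoint': 'https://tracing.example.com', 'project': 'my-app'}
```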