
Merge branch 'main' into feat/r2

# Conflicts:
#	api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
#	api/core/workflow/entities/node_entities.py
#	api/core/workflow/enums.py
Tag: 2.0.0-beta.1
Commit 309fffd1e4 · jyong · 5 months ago
100 changed files with 991 additions and 849 deletions
1. .devcontainer/post_create_command.sh (+1 -1)
2. api/commands.py (+3 -0)
3. api/configs/remote_settings_sources/nacos/http_request.py (+1 -2)
4. api/controllers/console/app/workflow_app_log.py (+2 -2)
5. api/controllers/service_api/app/workflow.py (+3 -2)
6. api/controllers/service_api/dataset/dataset.py (+133 -2)
7. api/controllers/service_api/dataset/segment.py (+22 -0)
8. api/core/app/app_config/easy_ui_based_app/model_config/converter.py (+1 -1)
9. api/core/app/apps/advanced_chat/app_generator.py (+2 -2)
10. api/core/app/apps/advanced_chat/app_runner.py (+1 -1)
11. api/core/app/apps/advanced_chat/generate_task_pipeline.py (+34 -43)
12. api/core/app/apps/common/workflow_response_converter.py (+16 -19)
13. api/core/app/apps/workflow/app_generator.py (+5 -5)
14. api/core/app/apps/workflow/app_runner.py (+1 -1)
15. api/core/app/apps/workflow/generate_task_pipeline.py (+21 -24)
16. api/core/app/apps/workflow_app_runner.py (+2 -2)
17. api/core/app/entities/app_invoke_entities.py (+3 -4)
18. api/core/app/entities/queue_entities.py (+11 -9)
19. api/core/app/entities/task_entities.py (+24 -9)
20. api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py (+22 -23)
21. api/core/app/task_pipeline/message_cycle_manager.py (+17 -12)
22. api/core/callback_handler/index_tool_callback_handler.py (+4 -1)
23. api/core/helper/code_executor/code_executor.py (+2 -1)
24. api/core/helper/marketplace.py (+7 -8)
25. api/core/llm_generator/prompts.py (+13 -54)
26. api/core/model_runtime/entities/llm_entities.py (+0 -13)
27. api/core/model_runtime/utils/encoders.py (+12 -11)
28. api/core/ops/base_trace_instance.py (+39 -0)
29. api/core/ops/entities/trace_entity.py (+8 -5)
30. api/core/ops/langfuse_trace/langfuse_trace.py (+7 -18)
31. api/core/ops/langsmith_trace/langsmith_trace.py (+8 -19)
32. api/core/ops/opik_trace/opik_trace.py (+8 -19)
33. api/core/ops/ops_trace_manager.py (+3 -2)
34. api/core/ops/weave_trace/weave_trace.py (+8 -19)
35. api/core/plugin/impl/base.py (+3 -4)
36. api/core/rag/datasource/vdb/baidu/baidu_vector.py (+0 -1)
37. api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py (+1 -1)
38. api/core/rag/datasource/vdb/milvus/milvus_vector.py (+4 -0)
39. api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py (+1 -1)
40. api/core/rag/entities/citation_metadata.py (+23 -0)
41. api/core/rag/extractor/entity/extract_setting.py (+2 -6)
42. api/core/rag/models/document.py (+2 -3)
43. api/core/rag/retrieval/dataset_retrieval.py (+35 -31)
44. api/core/repositories/sqlalchemy_workflow_execution_repository.py (+22 -8)
45. api/core/repositories/sqlalchemy_workflow_node_execution_repository.py (+38 -39)
46. api/core/tools/custom_tool/tool.py (+1 -1)
47. api/core/tools/entities/tool_entities.py (+0 -1)
48. api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py (+20 -19)
49. api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py (+37 -36)
50. api/core/tools/utils/message_transformer.py (+0 -1)
51. api/core/tools/utils/parser.py (+7 -0)
52. api/core/workflow/entities/node_entities.py (+2 -29)
53. api/core/workflow/entities/workflow_execution.py (+6 -10)
54. api/core/workflow/entities/workflow_node_execution.py (+31 -7)
55. api/core/workflow/enums.py (+1 -1)
56. api/core/workflow/graph_engine/entities/event.py (+3 -2)
57. api/core/workflow/graph_engine/entities/runtime_route_state.py (+1 -1)
58. api/core/workflow/graph_engine/graph_engine.py (+17 -16)
59. api/core/workflow/nodes/agent/agent_node.py (+9 -9)
60. api/core/workflow/nodes/answer/answer_node.py (+1 -1)
61. api/core/workflow/nodes/base/node.py (+1 -1)
62. api/core/workflow/nodes/code/code_node.py (+1 -1)
63. api/core/workflow/nodes/document_extractor/node.py (+1 -1)
64. api/core/workflow/nodes/end/end_node.py (+1 -1)
65. api/core/workflow/nodes/event/event.py (+4 -2)
66. api/core/workflow/nodes/http_request/node.py (+1 -1)
67. api/core/workflow/nodes/if_else/if_else_node.py (+1 -1)
68. api/core/workflow/nodes/iteration/iteration_node.py (+7 -8)
69. api/core/workflow/nodes/iteration/iteration_start_node.py (+1 -1)
70. api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py (+16 -19)
71. api/core/workflow/nodes/list_operator/node.py (+1 -1)
72. api/core/workflow/nodes/llm/node.py (+51 -50)
73. api/core/workflow/nodes/loop/loop_end_node.py (+1 -1)
74. api/core/workflow/nodes/loop/loop_node.py (+28 -20)
75. api/core/workflow/nodes/loop/loop_start_node.py (+1 -1)
76. api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py (+5 -9)
77. api/core/workflow/nodes/question_classifier/question_classifier_node.py (+13 -9)
78. api/core/workflow/nodes/start/start_node.py (+1 -1)
79. api/core/workflow/nodes/template_transform/template_transform_node.py (+1 -1)
80. api/core/workflow/nodes/tool/tool_node.py (+9 -9)
81. api/core/workflow/nodes/variable_aggregator/entities.py (+3 -2)
82. api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py (+1 -1)
83. api/core/workflow/nodes/variable_assigner/v1/node.py (+1 -1)
84. api/core/workflow/nodes/variable_assigner/v2/node.py (+1 -1)
85. api/core/workflow/repositories/__init__.py (+1 -1)
86. api/core/workflow/repositories/workflow_execution_repository.py (+1 -1)
87. api/core/workflow/repositories/workflow_node_execution_repository.py (+5 -5)
88. api/core/workflow/workflow_cycle_manager.py (+48 -63)
89. api/factories/variable_factory.py (+2 -2)
90. api/libs/smtp.py (+2 -1)
91. api/models/__init__.py (+4 -8)
92. api/models/model.py (+10 -10)
93. api/models/workflow.py (+3 -27)
94. api/pyproject.toml (+9 -7)
95. api/pytest.ini (+0 -1)
96. api/schedule/clean_messages.py (+1 -2)
97. api/services/clear_free_plan_tenant_expired_logs.py (+7 -6)
98. api/services/hit_testing_service.py (+28 -2)
99. api/services/ops_service.py (+8 -8)
100. api/services/tag_service.py (+0 -0)

.devcontainer/post_create_command.sh (+1 -1)

@@ -1,6 +1,6 @@
#!/bin/bash

npm add -g pnpm@10.8.0
npm add -g pnpm@10.11.1
cd web && pnpm install
pipx install uv


api/commands.py (+3 -0)

@@ -846,6 +846,9 @@ def clear_orphaned_file_records(force: bool):
{"type": "text", "table": "workflow_node_executions", "column": "outputs"},
{"type": "text", "table": "conversations", "column": "introduction"},
{"type": "text", "table": "conversations", "column": "system_instruction"},
{"type": "text", "table": "accounts", "column": "avatar"},
{"type": "text", "table": "apps", "column": "icon"},
{"type": "text", "table": "sites", "column": "icon"},
{"type": "json", "table": "messages", "column": "inputs"},
{"type": "json", "table": "messages", "column": "message"},
]

api/configs/remote_settings_sources/nacos/http_request.py (+1 -2)

@@ -60,8 +60,7 @@ class NacosHttpClient:
sign_str = tenant + "+"
if group:
sign_str = sign_str + group + "+"
if sign_str:
sign_str += ts
sign_str += ts  # Directly concatenate ts without a conditional check; the Nacos auth header requires it.
return sign_str

def get_access_token(self, force_refresh=False):
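
Note: the removed "if sign_str:" guard was effectively dead, since sign_str always contains at least the trailing "+" from the tenant concatenation. A minimal standalone sketch of the resulting signing-string logic, with names taken from the diff:

def build_sign_str(tenant: str, group: str, ts: str) -> str:
    # Mirrors NacosHttpClient's signing string after this change
    sign_str = tenant + "+"
    if group:
        sign_str = sign_str + group + "+"
    sign_str += ts  # always appended; the Nacos auth header requires it
    return sign_str

assert build_sign_str("tenant", "", "1700000000") == "tenant+1700000000"
assert build_sign_str("tenant", "grp", "1700000000") == "tenant+grp+1700000000"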

api/controllers/console/app/workflow_app_log.py (+2 -2)

@@ -6,12 +6,12 @@ from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs.login import login_required
from models import App
from models.model import AppMode
from models.workflow import WorkflowRunStatus
from services.workflow_app_service import WorkflowAppService


@@ -38,7 +38,7 @@ class WorkflowAppLogApi(Resource):
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()

args.status = WorkflowRunStatus(args.status) if args.status else None
args.status = WorkflowExecutionStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = isoparse(args.created_at__before)


api/controllers/service_api/app/workflow.py (+3 -2)

@@ -24,12 +24,13 @@ from core.errors.error import (
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs import helper
from libs.helper import TimestampField
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun, WorkflowRunStatus
from models.workflow import WorkflowRun
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
from services.workflow_app_service import WorkflowAppService
@@ -138,7 +139,7 @@ class WorkflowAppLogApi(Resource):
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()

args.status = WorkflowRunStatus(args.status) if args.status else None
args.status = WorkflowExecutionStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = isoparse(args.created_at__before)


api/controllers/service_api/dataset/dataset.py (+133 -2)

@@ -1,19 +1,21 @@
from flask import request
from flask_restful import marshal, reqparse
from flask_restful import marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound

import services.dataset_service
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource
from controllers.service_api.wraps import DatasetApiResource, validate_dataset_token
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import tag_fields
from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService


def _validate_name(name):
@@ -320,5 +322,134 @@ class DatasetApi(DatasetApiResource):
raise DatasetInUseError()


class DatasetTagsApi(DatasetApiResource):
@validate_dataset_token
@marshal_with(tag_fields)
def get(self, _, dataset_id):
"""Get all knowledge type tags."""
tags = TagService.get_tags("knowledge", current_user.current_tenant_id)

return tags, 200

@validate_dataset_token
def post(self, _, dataset_id):
"""Add a knowledge type tag."""
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()

parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=DatasetTagsApi._validate_tag_name,
)

args = parser.parse_args()
args["type"] = "knowledge"
tag = TagService.save_tags(args)

response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}

return response, 200

@validate_dataset_token
def patch(self, _, dataset_id):
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()

parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=DatasetTagsApi._validate_tag_name,
)
parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
args = parser.parse_args()
tag = TagService.update_tags(args, args.get("tag_id"))

binding_count = TagService.get_tag_binding_count(args.get("tag_id"))

response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}

return response, 200

@validate_dataset_token
def delete(self, _, dataset_id):
"""Delete a knowledge type tag."""
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
args = parser.parse_args()
TagService.delete_tag(args.get("tag_id"))

return 204

@staticmethod
def _validate_tag_name(name):
if not name or len(name) < 1 or len(name) > 50:
raise ValueError("Name must be between 1 to 50 characters.")
return name


class DatasetTagBindingApi(DatasetApiResource):
@validate_dataset_token
def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()

parser = reqparse.RequestParser()
parser.add_argument(
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
)
parser.add_argument(
"target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
)

args = parser.parse_args()
args["type"] = "knowledge"
TagService.save_tag_binding(args)

return 204


class DatasetTagUnbindingApi(DatasetApiResource):
@validate_dataset_token
def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()

parser = reqparse.RequestParser()
parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")

args = parser.parse_args()
args["type"] = "knowledge"
TagService.delete_tag_binding(args)

return 204


class DatasetTagsBindingStatusApi(DatasetApiResource):
@validate_dataset_token
def get(self, _, *args, **kwargs):
"""Get all knowledge type tags."""
dataset_id = kwargs.get("dataset_id")
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
response = {"data": tags_list, "total": len(tags)}
return response, 200


api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
api.add_resource(DatasetTagsApi, "/datasets/tags")
api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding")
api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding")
api.add_resource(DatasetTagsBindingStatusApi, "/datasets/<uuid:dataset_id>/tags")
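
A hedged usage sketch for the new tag endpoints with the requests library, assuming the service API is mounted under /v1 and authorized with a dataset API key (the base URL and token below are placeholders, not from this diff):

import requests

BASE = "https://api.dify.ai/v1"  # assumed deployment URL
HEADERS = {"Authorization": "Bearer <dataset-api-key>"}

# Create a knowledge-type tag
tag = requests.post(f"{BASE}/datasets/tags", headers=HEADERS, json={"name": "finance"}).json()
# -> {"id": "...", "name": "finance", "type": "knowledge", "binding_count": 0}

# Bind the tag to a dataset, then list the dataset's bound tags
requests.post(
    f"{BASE}/datasets/tags/binding",
    headers=HEADERS,
    json={"tag_ids": [tag["id"]], "target_id": "<dataset-id>"},
)
status = requests.get(f"{BASE}/datasets/<dataset-id>/tags", headers=HEADERS).json()
# -> {"data": [{"id": "...", "name": "finance"}], "total": 1}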

api/controllers/service_api/dataset/segment.py (+22 -0)

@@ -208,6 +208,28 @@ class DatasetSegmentApi(DatasetApiResource):
)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200

def get(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
if not segment:
raise NotFound("Segment not found.")

return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200


class ChildChunkApi(DatasetApiResource):
"""Resource for child chunks."""

api/core/app/app_config/easy_ui_based_app/model_config/converter.py (+1 -1)

@@ -70,7 +70,7 @@ class ModelConfigConverter:
if not model_mode:
model_mode = LLMMode.CHAT.value
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value

if not model_schema:
raise ValueError(f"Model {model_name} not exist.")

api/core/app/apps/advanced_chat/app_generator.py (+2 -2)

@@ -27,8 +27,8 @@ from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom

api/core/app/apps/advanced_chat/app_runner.py (+1 -1)

@@ -140,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
SystemVariableKey.APP_ID: app_config.app_id,
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id,
}

# init variable pool

api/core/app/apps/advanced_chat/generate_task_pipeline.py (+34 -43)

@@ -1,4 +1,3 @@
import json
import logging
import time
from collections.abc import Generator, Mapping
@@ -57,26 +56,23 @@ from core.app.entities.task_entities import (
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from models import Conversation, EndUser, Message, MessageFile
from models.account import Account
from models.enums import CreatorUserRole
from models.workflow import (
Workflow,
WorkflowRunStatus,
)
from models.workflow import Workflow

logger = logging.getLogger(__name__)

@@ -126,8 +122,14 @@ class AdvancedChatAppGenerateTaskPipeline:
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id,
},
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
@@ -137,7 +139,7 @@ class AdvancedChatAppGenerateTaskPipeline:
)

self._task_state = WorkflowTaskState()
self._message_cycle_manager = MessageCycleManage(
self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity, task_state=self._task_state
)

@@ -158,7 +160,7 @@ class AdvancedChatAppGenerateTaskPipeline:
:return:
"""
# start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)

@@ -302,15 +304,12 @@ class AdvancedChatAppGenerateTaskPipeline:

with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
)
self._workflow_run_id = workflow_execution.id
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
message = self._get_message(session=session)
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_execution.id
message.workflow_run_id = workflow_execution.id_
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
@@ -550,7 +549,7 @@ class AdvancedChatAppGenerateTaskPipeline:
workflow_run_id=self._workflow_run_id,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED,
status=WorkflowExecutionStatus.FAILED,
error_message=event.error,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
@@ -576,7 +575,7 @@ class AdvancedChatAppGenerateTaskPipeline:
workflow_run_id=self._workflow_run_id,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED,
status=WorkflowExecutionStatus.STOPPED,
error_message=event.get_stop_reason(),
conversation_id=self._conversation_id,
trace_manager=trace_manager,
@@ -604,22 +603,18 @@ class AdvancedChatAppGenerateTaskPipeline:
yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
self._message_cycle_manager._handle_retriever_resources(event)
self._message_cycle_manager.handle_retriever_resources(event)

with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message.message_metadata = self._task_state.metadata.model_dump_json()
session.commit()
elif isinstance(event, QueueAnnotationReplyEvent):
self._message_cycle_manager._handle_annotation_reply(event)
self._message_cycle_manager.handle_annotation_reply(event)

with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message.message_metadata = self._task_state.metadata.model_dump_json()
session.commit()
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
@@ -636,12 +631,12 @@ class AdvancedChatAppGenerateTaskPipeline:
tts_publisher.publish(queue_message)

self._task_state.answer += delta_text
yield self._message_cycle_manager._message_to_stream_response(
yield self._message_cycle_manager.message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_cycle_manager._message_replace_to_stream_response(
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=event.text, reason=event.reason
)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
@@ -653,7 +648,7 @@ class AdvancedChatAppGenerateTaskPipeline:
)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_cycle_manager._message_replace_to_stream_response(
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=output_moderation_answer,
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
)
@@ -682,9 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline:
message = self._get_message(session=session)
message.answer = self._task_state.answer
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message.message_metadata = self._task_state.metadata.model_dump_json()
message_files = [
MessageFile(
message_id=message.id,
@@ -712,9 +705,9 @@ class AdvancedChatAppGenerateTaskPipeline:
message.answer_price_unit = usage.completion_price_unit
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.metadata["usage"] = jsonable_encoder(usage)
self._task_state.metadata.usage = usage
else:
self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage())
self._task_state.metadata.usage = LLMUsage.empty_usage()
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
@@ -725,18 +718,16 @@ class AdvancedChatAppGenerateTaskPipeline:
Message end to stream response.
:return:
"""
extras = {}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.copy()
extras = self._task_state.metadata.model_dump()

if "annotation_reply" in extras["metadata"]:
del extras["metadata"]["annotation_reply"]
if self._task_state.metadata.annotation_reply:
del extras["annotation_reply"]

return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
files=self._recorded_files,
metadata=extras.get("metadata", {}),
metadata=extras,
)

def _handle_output_moderation_chunk(self, text: str) -> bool:

api/core/app/apps/common/workflow_response_converter.py (+16 -19)

@@ -44,15 +44,14 @@ from core.app.entities.task_entities import (
)
from core.file import FILE_MODEL_IDENTITY, File
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_execution_entities import NodeExecution
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from models import (
Account,
CreatorUserRole,
EndUser,
WorkflowNodeExecutionStatus,
WorkflowRun,
)

@@ -73,11 +72,10 @@ class WorkflowResponseConverter:
) -> WorkflowStartStreamResponse:
return WorkflowStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution.id,
workflow_run_id=workflow_execution.id_,
data=WorkflowStartStreamResponse.Data(
id=workflow_execution.id,
id=workflow_execution.id_,
workflow_id=workflow_execution.workflow_id,
sequence_number=workflow_execution.sequence_number,
inputs=workflow_execution.inputs,
created_at=int(workflow_execution.started_at.timestamp()),
),
@@ -91,7 +89,7 @@ class WorkflowResponseConverter:
workflow_execution: WorkflowExecution,
) -> WorkflowFinishStreamResponse:
created_by = None
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
assert workflow_run is not None
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
stmt = select(Account).where(Account.id == workflow_run.created_by)
@@ -122,11 +120,10 @@ class WorkflowResponseConverter:

return WorkflowFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution.id,
workflow_run_id=workflow_execution.id_,
data=WorkflowFinishStreamResponse.Data(
id=workflow_execution.id,
id=workflow_execution.id_,
workflow_id=workflow_execution.workflow_id,
sequence_number=workflow_execution.sequence_number,
status=workflow_execution.status,
outputs=workflow_execution.outputs,
error=workflow_execution.error_message,
@@ -146,16 +143,16 @@ class WorkflowResponseConverter:
*,
event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: NodeExecution,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeStartStreamResponse]:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_run_id:
if not workflow_node_execution.workflow_execution_id:
return None

response = NodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
data=NodeStartStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
@@ -196,18 +193,18 @@ class WorkflowResponseConverter:
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: NodeExecution,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_run_id:
if not workflow_node_execution.workflow_execution_id:
return None
if not workflow_node_execution.finished_at:
return None

return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
data=NodeFinishStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
@@ -239,18 +236,18 @@ class WorkflowResponseConverter:
*,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: NodeExecution,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_run_id:
if not workflow_node_execution.workflow_execution_id:
return None
if not workflow_node_execution.finished_at:
return None

return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
data=NodeRetryStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,

api/core/app/apps/workflow/app_generator.py (+5 -5)

@@ -25,8 +25,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
@@ -132,7 +132,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from=invoke_from,
call_depth=call_depth,
trace_manager=trace_manager,
workflow_run_id=workflow_run_id,
workflow_execution_id=workflow_run_id,
)

contexts.plugin_tool_providers.set({})
@@ -279,7 +279,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
workflow_run_id=str(uuid.uuid4()),
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
@@ -355,7 +355,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_run_id=str(uuid.uuid4()),
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())

api/core/app/apps/workflow/app_runner.py (+1 -1)

@@ -95,7 +95,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.APP_ID: app_config.app_id,
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
}

variable_pool = VariablePool(

api/core/app/apps/workflow/generate_task_pipeline.py (+21 -24)

@@ -50,16 +50,15 @@ from core.app.entities.task_entities import (
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatorUserRole
@@ -69,7 +68,6 @@ from models.workflow import (
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowRun,
WorkflowRunStatus,
)

logger = logging.getLogger(__name__)
@@ -114,8 +112,14 @@ class WorkflowAppGenerateTaskPipeline:
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id,
},
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
@@ -125,9 +129,7 @@ class WorkflowAppGenerateTaskPipeline:
)

self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._task_state = WorkflowTaskState()
self._workflow_run_id = ""

def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@@ -266,17 +268,13 @@ class WorkflowAppGenerateTaskPipeline:
# override graph runtime state
graph_runtime_state = event.graph_runtime_state

with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
)
self._workflow_run_id = workflow_execution.id
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
# init workflow run
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)

yield start_resp
elif isinstance(
@@ -511,9 +509,9 @@ class WorkflowAppGenerateTaskPipeline:
workflow_run_id=self._workflow_run_id,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
status=WorkflowExecutionStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
else WorkflowExecutionStatus.STOPPED,
error_message=event.error
if isinstance(event, QueueWorkflowFailedEvent)
else event.get_stop_reason(),
@@ -542,7 +540,6 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
tts_publisher.publish(queue_message)

self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)
@@ -557,7 +554,7 @@ class WorkflowAppGenerateTaskPipeline:
tts_publisher.publish(None)

def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
assert workflow_run is not None
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:

api/core/app/apps/workflow_app_runner.py (+2 -2)

@@ -29,8 +29,8 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import (
AgentLogEvent,
GraphEngineEvent,
@@ -295,7 +295,7 @@ class WorkflowBasedAppRunner(AppRunner):
inputs: Mapping[str, Any] | None = {}
process_data: Mapping[str, Any] | None = {}
outputs: Mapping[str, Any] | None = {}
execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {}
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data

api/core/app/entities/app_invoke_entities.py (+3 -4)

@@ -77,6 +77,8 @@ class AppGenerateEntity(BaseModel):
App Generate Entity.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

task_id: str

# app config
@@ -100,9 +102,6 @@ class AppGenerateEntity(BaseModel):
# tracing instance
trace_manager: Optional[TraceQueueManager] = None

class Config:
arbitrary_types_allowed = True


class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
"""
@@ -206,7 +205,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):

# app config
app_config: WorkflowUIBasedAppConfig
workflow_run_id: str
workflow_execution_id: str

class SingleIterationRunEntity(BaseModel):
"""

api/core/app/entities/queue_entities.py (+11 -9)

@@ -1,4 +1,4 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Optional
@@ -6,7 +6,9 @@ from typing import Any, Optional
from pydantic import BaseModel

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
@@ -282,7 +284,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""

event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict]
retriever_resources: Sequence[RetrievalSourceMetadata]
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
@@ -412,7 +414,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None

error: Optional[str] = None
"""single iteration duration map"""
@@ -446,7 +448,7 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None

error: str
retry_index: int # retry index
@@ -480,7 +482,7 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None

error: str

@@ -513,7 +515,7 @@ class QueueNodeInLoopFailedEvent(AppQueueEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None

error: str

@@ -546,7 +548,7 @@ class QueueNodeExceptionEvent(AppQueueEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None

error: str

@@ -579,7 +581,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None

error: str


api/core/app/entities/task_entities.py (+24 -9)

@@ -2,12 +2,29 @@ from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
from models.workflow import WorkflowNodeExecutionStatus
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus


class AnnotationReplyAccount(BaseModel):
id: str
name: str


class AnnotationReply(BaseModel):
id: str
account: AnnotationReplyAccount


class TaskStateMetadata(BaseModel):
annotation_reply: AnnotationReply | None = None
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list)
usage: LLMUsage | None = None


class TaskState(BaseModel):
@@ -15,7 +32,7 @@ class TaskState(BaseModel):
TaskState entity
"""

metadata: dict = {}
metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata)


class EasyUITaskState(TaskState):
@@ -189,7 +206,6 @@ class WorkflowStartStreamResponse(StreamResponse):

id: str
workflow_id: str
sequence_number: int
inputs: Mapping[str, Any]
created_at: int

@@ -210,7 +226,6 @@ class WorkflowFinishStreamResponse(StreamResponse):

id: str
workflow_id: str
sequence_number: int
status: str
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
@@ -307,7 +322,7 @@ class NodeFinishStreamResponse(StreamResponse):
status: str
error: Optional[str] = None
elapsed_time: float
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
@@ -376,7 +391,7 @@ class NodeRetryStreamResponse(StreamResponse):
status: str
error: Optional[str] = None
elapsed_time: float
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
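
With metadata now a typed model instead of a bare dict, the pipelines can serialize it with model_dump_json() rather than json.dumps(jsonable_encoder(...)). A standalone sketch of the pattern, with stand-in field types (the real usage field is LLMUsage and retriever_resources holds RetrievalSourceMetadata):

from pydantic import BaseModel, Field

class Usage(BaseModel):  # stand-in for LLMUsage
    total_tokens: int = 0

class TaskStateMetadata(BaseModel):  # shape taken from the diff above
    annotation_reply: dict | None = None
    retriever_resources: list = Field(default_factory=list)
    usage: Usage | None = None

meta = TaskStateMetadata(usage=Usage(total_tokens=42))
print(meta.model_dump_json())
# {"annotation_reply":null,"retriever_resources":[],"usage":{"total_tokens":42}}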

api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py (+22 -23)

@@ -1,4 +1,3 @@
import json
import logging
import time
from collections.abc import Generator
@@ -43,7 +42,7 @@ from core.app.entities.task_entities import (
StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -51,7 +50,6 @@ from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -63,7 +61,7 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)


class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage):
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
"""
EasyUIBasedGenerateTaskPipeline is a class that generates stream output and manages state for an application.
"""
@@ -104,6 +102,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
)
)

self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity,
task_state=self._task_state,
)

self._conversation_name_generate_thread: Optional[Thread] = None

def process(
@@ -115,7 +118,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
]:
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
)

@@ -136,9 +139,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata
extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse(
@@ -277,7 +280,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=output_moderation_answer
)

with Session(db.engine) as session:
# Save message
@@ -286,9 +291,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_end_resp = self._message_end_to_stream_response()
yield message_end_resp
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
self._message_cycle_manager.handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent):
annotation = self._handle_annotation_reply(event)
annotation = self._message_cycle_manager.handle_annotation_reply(event)
if annotation:
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
@@ -296,7 +301,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if agent_thought_response is not None:
yield agent_thought_response
elif isinstance(event, QueueMessageFileEvent):
response = self._message_file_to_stream_response(event)
response = self._message_cycle_manager.message_file_to_stream_response(event)
if response:
yield response
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
@@ -318,7 +323,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._task_state.llm_result.message.content = current_content

if isinstance(event, QueueLLMChunkEvent):
yield self._message_to_stream_response(
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
)
@@ -328,7 +333,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_id=self._message_id,
)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_replace_to_stream_response(answer=event.text)
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
else:
@@ -372,9 +377,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message.provider_response_latency = time.perf_counter() - self._start_at
message.total_price = usage.total_price
message.currency = usage.currency
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message.message_metadata = self._task_state.metadata.model_dump_json()

if trace_manager:
trace_manager.add_trace_task(
@@ -423,16 +426,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
Message end to stream response.
:return:
"""
self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)

extras = {}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata

self._task_state.metadata.usage = self._task_state.llm_result.usage
metadata_dict = self._task_state.metadata.model_dump()
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
metadata=extras.get("metadata", {}),
metadata=metadata_dict,
)

def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
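
The pipeline no longer inherits from the cycle manager; it owns one, which is why the manager's helpers could drop their leading underscores. A minimal sketch of the composition pattern (simplified signatures, not the real classes):

class MessageCycleManager:
    def message_to_stream_response(self, answer: str, message_id: str) -> dict:
        return {"id": message_id, "answer": answer}

class EasyUIBasedGenerateTaskPipeline:
    def __init__(self) -> None:
        # owned collaborator instead of a mixin base class
        self._message_cycle_manager = MessageCycleManager()

    def handle_chunk(self, text: str, message_id: str) -> dict:
        return self._message_cycle_manager.message_to_stream_response(answer=text, message_id=message_id)

print(EasyUIBasedGenerateTaskPipeline().handle_chunk("hi", "msg-1"))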

api/core/app/task_pipeline/message_cycle_manage.py → api/core/app/task_pipeline/message_cycle_manager.py (+17 -12)

@@ -17,6 +17,8 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent,
)
from core.app.entities.task_entities import (
AnnotationReply,
AnnotationReplyAccount,
EasyUITaskState,
MessageFileStreamResponse,
MessageReplaceStreamResponse,
@@ -30,7 +32,7 @@ from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
from services.annotation_service import AppAnnotationService


class MessageCycleManage:
class MessageCycleManager:
def __init__(
self,
*,
@@ -45,7 +47,7 @@ class MessageCycleManage:
self._application_generate_entity = application_generate_entity
self._task_state = task_state

def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
"""
Generate conversation name.
:param conversation_id: conversation id
@@ -102,7 +104,7 @@ class MessageCycleManage:
db.session.commit()
db.session.close()

def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
"""
Handle annotation reply.
:param event: event
@@ -111,25 +113,28 @@ class MessageCycleManage:
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation:
account = annotation.account
self._task_state.metadata["annotation_reply"] = {
"id": annotation.id,
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
}
self._task_state.metadata.annotation_reply = AnnotationReply(
id=annotation.id,
account=AnnotationReplyAccount(
id=annotation.account_id,
name=account.name if account else "Dify user",
),
)

return annotation

return None

def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
"""
Handle retriever resources.
:param event: event
:return:
"""
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata["retriever_resources"] = event.retriever_resources
self._task_state.metadata.retriever_resources = event.retriever_resources

def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
"""
Message file to stream response.
:param event: event
@@ -166,7 +171,7 @@ class MessageCycleManage:

return None

def _message_to_stream_response(
def message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
) -> MessageStreamResponse:
"""
@@ -182,7 +187,7 @@ class MessageCycleManage:
from_variable_selector=from_variable_selector,
)

def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
"""
Message replace to stream response.
:param answer: answer

+ 4
- 1
api/core/callback_handler/index_tool_callback_handler.py View File

@@ -1,8 +1,10 @@
import logging
from collections.abc import Sequence

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from extensions.ext_database import db
@@ -85,7 +87,8 @@ class DatasetIndexToolCallbackHandler:

db.session.commit()

def return_retriever_resource_info(self, resource: list):
# TODO(-LAN-): Improve type check
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
"""Handle return_retriever_resource_info."""
self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER

+ 2
- 1
api/core/helper/code_executor/code_executor.py View File

@@ -15,6 +15,7 @@ from core.helper.code_executor.python3.python3_transformer import Python3Templat
from core.helper.code_executor.template_transformer import TemplateTransformer

logger = logging.getLogger(__name__)
code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT))


class CodeExecutionError(Exception):
@@ -64,7 +65,7 @@ class CodeExecutor:
:param code: code
:return:
"""
url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run"
url = code_execution_endpoint_url / "v1" / "sandbox" / "run"

headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY}


+ 7
- 8
api/core/helper/marketplace.py View File

@@ -7,29 +7,28 @@ from configs import dify_config
from core.helper.download import download_with_size_limit
from core.plugin.entities.marketplace import MarketplacePluginDeclaration

marketplace_api_url = URL(str(dify_config.MARKETPLACE_API_URL))

def get_plugin_pkg_url(plugin_unique_identifier: str):
return (URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/download").with_query(
unique_identifier=plugin_unique_identifier
)

def get_plugin_pkg_url(plugin_unique_identifier: str) -> str:
return str((marketplace_api_url / "api/v1/plugins/download").with_query(unique_identifier=plugin_unique_identifier))


def download_plugin_pkg(plugin_unique_identifier: str):
url = str(get_plugin_pkg_url(plugin_unique_identifier))
return download_with_size_limit(url, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
return download_with_size_limit(get_plugin_pkg_url(plugin_unique_identifier), dify_config.PLUGIN_MAX_PACKAGE_SIZE)


def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
if len(plugin_ids) == 0:
return []

url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/batch")
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]


def record_install_plugin_event(plugin_unique_identifier: str):
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/stats/plugins/install_count")
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
response.raise_for_status()

+ 13
- 54
api/core/llm_generator/prompts.py View File

@@ -1,61 +1,20 @@
# Written by YORKI MINAKO🤡, Edited by Xiaoyi
CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is.
Notice: the language type user uses could be diverse, which can be English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.
ENSURE your output is in the SAME language as the user's input!
Your output is restricted only to: (Input language) Intention + Subject(short as possible)
Your output MUST be a valid JSON.
# Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh
CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”.

Tip: When the user's question is directed at you (the language model), you can add an emoji to make it more fun.
1. Detect Input Language
Automatically identify the language of the user’s input (e.g. English, Chinese, Italian, Español, Arabic, Japanese, French, etc.).

2. Generate Title
- Combine Intention + Subject into a single, as-short-as-possible phrase.
- The title must be natural, friendly, and in the same language as the input.
- If the input is a direct question to the model, you may add an emoji at the end.

example 1:
User Input: hi, yesterday i had some burgers.
3. Output Format
Return **only** a valid JSON object with these exact keys and no additional text:
{
"Language Type": "The user's input is pure English",
"Your Reasoning": "The language of my output must be pure English.",
"Your Output": "sharing yesterday's food"
}

example 2:
User Input: hello
{
"Language Type": "The user's input is pure English",
"Your Reasoning": "The language of my output must be pure English.",
"Your Output": "Greeting myself☺️"
}


example 3:
User Input: why mmap file: oom
{
"Language Type": "The user's input is written in pure English",
"Your Reasoning": "The language of my output must be pure English.",
"Your Output": "Asking about the reason for mmap file: oom"
}


example 4:
User Input: www.convinceme.yesterday-you-ate-seafood.tv讲了什么?
{
"Language Type": "The user's input English-Chinese mixed",
"Your Reasoning": "The English-part is an URL, the main intention is still written in Chinese, so the language of my output must be using Chinese.",
"Your Output": "询问网站www.convinceme.yesterday-you-ate-seafood.tv"
}

example 5:
User Input: why小红的年龄is老than小明?
{
"Language Type": "The user's input is English-Chinese mixed",
"Your Reasoning": "The English parts are filler words, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.",
"Your Output": "询问小红和小明的年龄"
}

example 6:
User Input: yo, 你今天咋样?
{
"Language Type": "The user's input is English-Chinese mixed",
"Your Reasoning": "The English-part is a subjective particle, the main intention is written in Chinese, so the language of my output must be using Chinese.",
"Your Output": "查询今日我的状态☺️"
"Language Type": "<Detected language>",
"Your Reasoning": "<Brief explanation in that language>",
"Your Output": "<Intention + Subject>"
}

User Input:
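Since the rewritten prompt demands a bare JSON object with fixed keys, the caller can consume it with a plain `json.loads`; a hedged sketch (the fallback behavior is an assumption, not part of this commit):

    import json

    def parse_title(raw: str) -> str:
        # the prompt promises keys "Language Type", "Your Reasoning", "Your Output"
        try:
            return json.loads(raw)["Your Output"]
        except (json.JSONDecodeError, KeyError):
            return raw.strip()  # fall back to raw text if the model strays

    print(parse_title('{"Language Type": "English", "Your Reasoning": "...", "Your Output": "Greeting myself"}'))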

+ 0
- 13
api/core/model_runtime/entities/llm_entities.py View File

@@ -17,19 +17,6 @@ class LLMMode(StrEnum):
COMPLETION = "completion"
CHAT = "chat"

@classmethod
def value_of(cls, value: str) -> "LLMMode":
"""
Get value of given mode.

:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")


class LLMUsage(ModelUsage):
"""

+ 12
- 11
api/core/model_runtime/utils/encoders.py View File

@@ -129,17 +129,18 @@ def jsonable_encoder(
sqlalchemy_safe=sqlalchemy_safe,
)
if dataclasses.is_dataclass(obj):
# FIXME: mypy error, try to fix it instead of using type: ignore
obj_dict = dataclasses.asdict(obj) # type: ignore
return jsonable_encoder(
obj_dict,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
# Ensure obj is a dataclass instance, not a dataclass type
if not isinstance(obj, type):
obj_dict = dataclasses.asdict(obj)
return jsonable_encoder(
obj_dict,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, PurePath):
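Background for the guard: `dataclasses.is_dataclass` is true for both dataclass *types* and *instances*, while `dataclasses.asdict` only accepts instances, so the extra `isinstance(obj, type)` check is what satisfies mypy and the runtime alike:

    import dataclasses

    @dataclasses.dataclass
    class Point:
        x: int = 0

    assert dataclasses.is_dataclass(Point)        # the class itself passes
    assert dataclasses.is_dataclass(Point())      # and so does an instance
    assert isinstance(Point, type)                # only the class is a type
    print(dataclasses.asdict(Point()))            # asdict requires an instance: {'x': 0}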

+ 39
- 0
api/core/ops/base_trace_instance.py View File

@@ -1,7 +1,11 @@
from abc import ABC, abstractmethod

from sqlalchemy.orm import Session

from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.entities.trace_entity import BaseTraceInfo
from extensions.ext_database import db
from models import Account, App, TenantAccountJoin


class BaseTraceInstance(ABC):
@@ -24,3 +28,38 @@ class BaseTraceInstance(ABC):
Subclasses must implement specific tracing logic for activities.
"""
...

def get_service_account_with_tenant(self, app_id: str) -> Account:
"""
Get service account for an app and set up its tenant.

Args:
app_id: The ID of the app

Returns:
Account: The service account with tenant set up

Raises:
ValueError: If app, creator account or tenant cannot be found
"""
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app = session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")

if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")

service_account = session.query(Account).filter(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")

current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")
service_account.set_tenant_id(current_tenant.tenant_id)

return service_account

+ 8
- 5
api/core/ops/entities/trace_entity.py View File

@@ -3,7 +3,7 @@ from datetime import datetime
from enum import StrEnum
from typing import Any, Optional, Union

from pydantic import BaseModel, ConfigDict, field_validator
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator


class BaseTraceInfo(BaseModel):
@@ -24,10 +24,13 @@ class BaseTraceInfo(BaseModel):
return v
return ""

class Config:
json_encoders = {
datetime: lambda v: v.isoformat(),
}
model_config = ConfigDict(protected_namespaces=())

@field_serializer("start_time", "end_time")
def serialize_datetime(self, dt: datetime | None) -> str | None:
if dt is None:
return None
return dt.isoformat()


class WorkflowTraceInfo(BaseTraceInfo):
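`json_encoders` is a Pydantic v1 idiom that v2 deprecates; a `@field_serializer` expresses the same conversion explicitly and only for the named fields. A self-contained version of the pattern:

    from datetime import datetime
    from pydantic import BaseModel, field_serializer

    class Event(BaseModel):  # toy model, not the real BaseTraceInfo
        start_time: datetime | None = None

        @field_serializer("start_time")
        def serialize_datetime(self, dt: datetime | None) -> str | None:
            return dt.isoformat() if dt else None

    print(Event(start_time=datetime(2025, 1, 1)).model_dump_json())
    # {"start_time":"2025-01-01T00:00:00"}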

+ 7
- 18
api/core/ops/langfuse_trace/langfuse_trace.py View File

@@ -4,7 +4,7 @@ from datetime import datetime, timedelta
from typing import Optional

from langfuse import Langfuse # type: ignore
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig
@@ -31,7 +31,7 @@ from core.ops.utils import filter_none_values
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom
from models import EndUser, WorkflowNodeExecutionTriggeredFrom

logger = logging.getLogger(__name__)

@@ -114,22 +114,11 @@ class LangFuseDataTrace(BaseTraceInstance):
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")

app = session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")

if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")

service_account = session.query(Account).filter(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")

service_account = self.get_service_account_with_tenant(app_id)

workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,

+ 8
- 19
api/core/ops/langsmith_trace/langsmith_trace.py View File

@@ -6,7 +6,7 @@ from typing import Optional, cast

from langsmith import Client
from langsmith.schemas import RunBase
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangSmithConfig
@@ -28,10 +28,10 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

logger = logging.getLogger(__name__)

@@ -139,22 +139,11 @@ class LangSmithDataTrace(BaseTraceInstance):
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")

app = session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")

if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")

service_account = session.query(Account).filter(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
service_account = self.get_service_account_with_tenant(app_id)

workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
@@ -185,7 +174,7 @@ class LangSmithDataTrace(BaseTraceInstance):
finished_at = created_at + timedelta(seconds=elapsed_time)

execution_metadata = node_execution.metadata if node_execution.metadata else {}
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
metadata = {str(key): value for key, value in execution_metadata.items()}
metadata.update(
{

+ 8
- 19
api/core/ops/opik_trace/opik_trace.py View File

@@ -6,7 +6,7 @@ from typing import Optional, cast

from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig
@@ -22,10 +22,10 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

logger = logging.getLogger(__name__)

@@ -154,22 +154,11 @@ class OpikDataTrace(BaseTraceInstance):
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")

app = session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")

if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")

service_account = session.query(Account).filter(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
service_account = self.get_service_account_with_tenant(app_id)

workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
@@ -246,7 +235,7 @@ class OpikDataTrace(BaseTraceInstance):
parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id

if not total_tokens:
total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0

span_data = {
"trace_id": opik_trace_id,

+ 3
- 2
api/core/ops/ops_trace_manager.py View File

@@ -30,7 +30,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
from core.workflow.entities.workflow_execution import WorkflowExecution
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
@@ -386,7 +386,7 @@ class TraceTask:
):
self.trace_type = trace_type
self.message_id = message_id
self.workflow_run_id = workflow_execution.id if workflow_execution else None
self.workflow_run_id = workflow_execution.id_ if workflow_execution else None
self.conversation_id = conversation_id
self.user_id = user_id
self.timer = timer
@@ -487,6 +487,7 @@ class TraceTask:
"file_list": file_list,
"triggered_from": workflow_run.triggered_from,
"user_id": user_id,
"app_id": workflow_run.app_id,
}

workflow_trace_info = WorkflowTraceInfo(

+ 8
- 19
api/core/ops/weave_trace/weave_trace.py View File

@@ -6,7 +6,7 @@ from typing import Any, Optional, cast

import wandb
import weave
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import WeaveConfig
@@ -23,10 +23,10 @@ from core.ops.entities.trace_entity import (
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

logger = logging.getLogger(__name__)

@@ -133,22 +133,11 @@ class WeaveDataTrace(BaseTraceInstance):
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")

app = session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")

if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")

service_account = session.query(Account).filter(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
service_account = self.get_service_account_with_tenant(app_id)

workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
@@ -179,7 +168,7 @@ class WeaveDataTrace(BaseTraceInstance):
finished_at = created_at + timedelta(seconds=elapsed_time)

execution_metadata = node_execution.metadata if node_execution.metadata else {}
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
attributes = {str(k): v for k, v in execution_metadata.items()}
attributes.update(
{

+ 3
- 4
api/core/plugin/impl/base.py View File

@@ -31,8 +31,7 @@ from core.plugin.impl.exc import (
PluginUniqueIdentifierError,
)

plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_DAEMON_URL
plugin_daemon_inner_api_key = dify_config.PLUGIN_DAEMON_KEY
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))

T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))

@@ -53,9 +52,9 @@ class BasePluginClient:
"""
Make a request to the plugin daemon inner API.
"""
url = URL(str(plugin_daemon_inner_api_baseurl)) / path
url = plugin_daemon_inner_api_baseurl / path
headers = headers or {}
headers["X-Api-Key"] = plugin_daemon_inner_api_key
headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
headers["Accept-Encoding"] = "gzip, deflate, br"

if headers.get("Content-Type") == "application/json" and isinstance(data, dict):

+ 0
- 1
api/core/rag/datasource/vdb/baidu/baidu_vector.py View File

@@ -85,7 +85,6 @@ class BaiduVector(BaseVector):
end = min(start + batch_size, total_count)
rows = []
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
# FIXME do you need this assert?
for i in range(start, end, 1):
row = Row(
id=metadatas[i].get("doc_id", str(uuid.uuid4())),

+ 1
- 1
api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py View File

@@ -142,7 +142,7 @@ class ElasticSearchVector(BaseVector):
if score > score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
docs.append(doc)

return docs


+ 4
- 0
api/core/rag/datasource/vdb/milvus/milvus_vector.py View File

@@ -97,6 +97,10 @@ class MilvusVector(BaseVector):

try:
milvus_version = self._client.get_server_version()
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
if "Zilliz Cloud" in milvus_version:
return True
# For standard Milvus installations, check version number
# compare parsed release tuples, not base_version strings: lexicographic string
# comparison would sort "2.10.0" before "2.5.0"
return version.parse(milvus_version).release >= version.parse("2.5.0").release
except Exception as e:
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")

+ 1
- 1
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py View File

@@ -245,4 +245,4 @@ class TidbService:
return cluster_infos
else:
response.raise_for_status()
return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception
return []

+ 23
- 0
api/core/rag/entities/citation_metadata.py View File

@@ -0,0 +1,23 @@
from typing import Any, Optional

from pydantic import BaseModel


class RetrievalSourceMetadata(BaseModel):
position: Optional[int] = None
dataset_id: Optional[str] = None
dataset_name: Optional[str] = None
document_id: Optional[str] = None
document_name: Optional[str] = None
data_source_type: Optional[str] = None
segment_id: Optional[str] = None
retriever_from: Optional[str] = None
score: Optional[float] = None
hit_count: Optional[int] = None
word_count: Optional[int] = None
segment_position: Optional[int] = None
index_node_hash: Optional[str] = None
content: Optional[str] = None
page: Optional[int] = None
doc_metadata: Optional[dict[str, Any]] = None
title: Optional[str] = None
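With every field optional, call sites can build the model incrementally, exactly as the retrieval code further down does; for example, using the class defined above:

    source = RetrievalSourceMetadata(
        dataset_id="d1",
        document_name="handbook.pdf",
        score=0.87,
    )
    source.position = 1  # fields can still be filled in after construction
    print(source.model_dump(exclude_none=True))
    # {'position': 1, 'dataset_id': 'd1', 'document_name': 'handbook.pdf', 'score': 0.87}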

+ 2
- 6
api/core/rag/extractor/entity/extract_setting.py View File

@@ -27,6 +27,8 @@ class WebsiteInfo(BaseModel):
website import info.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

provider: str
job_id: str
url: str
@@ -34,12 +36,6 @@ class WebsiteInfo(BaseModel):
tenant_id: str
only_main_content: bool = False

class Config:
arbitrary_types_allowed = True

def __init__(self, **data) -> None:
super().__init__(**data)


class ExtractSetting(BaseModel):
"""

+ 2
- 3
api/core/rag/models/document.py View File

@@ -70,13 +70,12 @@ class BaseDocumentTransformer(ABC):
.. code-block:: python

class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

embeddings: Embeddings
similarity_fn: Callable = cosine_similarity
similarity_threshold: float = 0.95

class Config:
arbitrary_types_allowed = True

def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:

+ 35
- 31
api/core/rag/retrieval/dataset_retrieval.py View File

@@ -35,6 +35,7 @@ from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
@@ -198,21 +199,21 @@ class DatasetRetrieval:

dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
document_context_list = []
retrieval_resource_list = []
document_context_list: list[DocumentContext] = []
retrieval_resource_list: list[RetrievalSourceMetadata] = []
# deal with external documents
for item in external_documents:
document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
source = {
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": invoke_from.to_source(),
"score": item.metadata.get("score"),
"content": item.page_content,
}
source = RetrievalSourceMetadata(
dataset_id=item.metadata.get("dataset_id"),
dataset_name=item.metadata.get("dataset_name"),
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
document_name=item.metadata.get("title"),
data_source_type="external",
retriever_from=invoke_from.to_source(),
score=item.metadata.get("score"),
content=item.page_content,
)
retrieval_resource_list.append(source)
# deal with dify documents
if dify_documents:
@@ -248,32 +249,32 @@ class DatasetRetrieval:
.first()
)
if dataset and document:
source = {
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": invoke_from.to_source(),
"score": record.score or 0.0,
"doc_metadata": document.doc_metadata,
}
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from=invoke_from.to_source(),
score=record.score or 0.0,
doc_metadata=document.doc_metadata,
)

if invoke_from.to_source() == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
source.hit_count = segment.hit_count
source.word_count = segment.word_count
source.segment_position = segment.position
source.index_node_hash = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
source.content = segment.content
retrieval_resource_list.append(source)
if hit_callback and retrieval_resource_list:
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True)
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
for position, item in enumerate(retrieval_resource_list, start=1):
item["position"] = position
item.position = position
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
@@ -936,6 +937,9 @@ class DatasetRetrieval:
return metadata_filter_document_ids, metadata_condition

def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
if not inputs:
return text

def replacer(match):
key = match.group(1)
return str(inputs.get(key, f"{{{{{key}}}}}"))
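`_replace_metadata_filter_value` pairs this `replacer` with a `re.sub` call; standalone, and assuming the `{{variable}}` placeholder syntax used by these templates, the mechanism is:

    import re

    def replace_placeholders(text: str, inputs: dict) -> str:
        if not inputs:
            return text

        def replacer(match):
            key = match.group(1)
            # unknown keys stay as literal {{key}} instead of being dropped
            return str(inputs.get(key, f"{{{{{key}}}}}"))

        return re.sub(r"\{\{(\w+)\}\}", replacer, text)

    print(replace_placeholders("author = {{name}}", {"name": "alice"}))  # author = alice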

+ 22
- 8
api/core/repositories/sqlalchemy_workflow_execution_repository.py View File

@@ -10,12 +10,12 @@ from sqlalchemy import select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from core.workflow.entities.workflow_execution_entities import (
from core.workflow.entities.workflow_execution import (
WorkflowExecution,
WorkflowExecutionStatus,
WorkflowType,
)
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from models import (
Account,
CreatorUserRole,
@@ -104,10 +104,9 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
status = WorkflowExecutionStatus(db_model.status)

return WorkflowExecution(
id=db_model.id,
id_=db_model.id,
workflow_id=db_model.workflow_id,
sequence_number=db_model.sequence_number,
type=WorkflowType(db_model.type),
workflow_type=WorkflowType(db_model.type),
workflow_version=db_model.version,
graph=graph,
inputs=inputs,
@@ -140,14 +139,29 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
raise ValueError("created_by_role is required in repository constructor")

db_model = WorkflowRun()
db_model.id = domain_model.id
db_model.id = domain_model.id_
db_model.tenant_id = self._tenant_id
if self._app_id is not None:
db_model.app_id = self._app_id
db_model.workflow_id = domain_model.workflow_id
db_model.triggered_from = self._triggered_from
db_model.sequence_number = domain_model.sequence_number
db_model.type = domain_model.type

# Check if this is a new record
with self._session_factory() as session:
existing = session.scalar(select(WorkflowRun).where(WorkflowRun.id == domain_model.id_))
if not existing:
# For new records, get the next sequence number
stmt = select(WorkflowRun.sequence_number).where(
WorkflowRun.app_id == self._app_id,
WorkflowRun.tenant_id == self._tenant_id,
)
max_sequence = session.scalar(stmt.order_by(WorkflowRun.sequence_number.desc()))
db_model.sequence_number = (max_sequence or 0) + 1
else:
# For updates, keep the existing sequence number
db_model.sequence_number = existing.sequence_number

db_model.type = domain_model.workflow_type
db_model.version = domain_model.workflow_version
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
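The save path now owns `sequence_number`: reuse the stored value on update, otherwise take the per-app/tenant maximum plus one. A condensed, runnable sketch of the new-record branch against a toy schema (note the max+1 read is not atomic, so concurrent first saves would still need a database-level guard):

    from sqlalchemy import Integer, String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class WorkflowRun(Base):
        __tablename__ = "workflow_runs"
        id: Mapped[str] = mapped_column(String, primary_key=True)
        app_id: Mapped[str] = mapped_column(String)
        sequence_number: Mapped[int] = mapped_column(Integer)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        stmt = (
            select(WorkflowRun.sequence_number)
            .where(WorkflowRun.app_id == "app-1")
            .order_by(WorkflowRun.sequence_number.desc())
        )
        max_sequence = session.scalar(stmt)   # None while no runs exist
        print((max_sequence or 0) + 1)        # 1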

+ 38
- 39
api/core/repositories/sqlalchemy_workflow_node_execution_repository.py View File

@@ -12,19 +12,18 @@ from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_execution_entities import (
NodeExecution,
NodeExecutionStatus,
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from models import (
Account,
CreatorUserRole,
EndUser,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom,
)

@@ -87,9 +86,9 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)

# Initialize in-memory cache for node executions
# Key: node_execution_id, Value: WorkflowNodeExecution (DB model)
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
self._node_execution_cache: dict[str, WorkflowNodeExecutionModel] = {}

def _to_domain_model(self, db_model: WorkflowNodeExecution) -> NodeExecution:
def _to_domain_model(self, db_model: WorkflowNodeExecutionModel) -> WorkflowNodeExecution:
"""
Convert a database model to a domain model.

@@ -103,16 +102,16 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
inputs = db_model.inputs_dict
process_data = db_model.process_data_dict
outputs = db_model.outputs_dict
metadata = {NodeRunMetadataKey(k): v for k, v in db_model.execution_metadata_dict.items()}
metadata = {WorkflowNodeExecutionMetadataKey(k): v for k, v in db_model.execution_metadata_dict.items()}

# Convert status to domain enum
status = NodeExecutionStatus(db_model.status)
status = WorkflowNodeExecutionStatus(db_model.status)

return NodeExecution(
return WorkflowNodeExecution(
id=db_model.id,
node_execution_id=db_model.node_execution_id,
workflow_id=db_model.workflow_id,
workflow_run_id=db_model.workflow_run_id,
workflow_execution_id=db_model.workflow_run_id,
index=db_model.index,
predecessor_node_id=db_model.predecessor_node_id,
node_id=db_model.node_id,
@@ -129,7 +128,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
finished_at=db_model.finished_at,
)

def to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution:
def to_db_model(self, domain_model: WorkflowNodeExecution) -> WorkflowNodeExecutionModel:
"""
Convert a domain model to a database model.

@@ -147,14 +146,14 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")

db_model = WorkflowNodeExecution()
db_model = WorkflowNodeExecutionModel()
db_model.id = domain_model.id
db_model.tenant_id = self._tenant_id
if self._app_id is not None:
db_model.app_id = self._app_id
db_model.workflow_id = domain_model.workflow_id
db_model.triggered_from = self._triggered_from
db_model.workflow_run_id = domain_model.workflow_run_id
db_model.workflow_run_id = domain_model.workflow_execution_id
db_model.index = domain_model.index
db_model.predecessor_node_id = domain_model.predecessor_node_id
db_model.node_execution_id = domain_model.node_execution_id
@@ -176,7 +175,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
db_model.finished_at = domain_model.finished_at
return db_model

def save(self, execution: NodeExecution) -> None:
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update a NodeExecution domain entity to the database.

@@ -208,7 +207,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
self._node_execution_cache[db_model.node_execution_id] = db_model

def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a NodeExecution by its node_execution_id.

@@ -231,13 +230,13 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
# If not in cache, query the database
logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.node_execution_id == node_execution_id,
WorkflowNodeExecution.tenant_id == self._tenant_id,
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.node_execution_id == node_execution_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
)

if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)

db_model = session.scalar(stmt)
if db_model:
@@ -254,7 +253,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
) -> Sequence[WorkflowNodeExecution]:
) -> Sequence[WorkflowNodeExecutionModel]:
"""
Retrieve all WorkflowNodeExecution database models for a specific workflow run.

@@ -272,20 +271,20 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
A list of WorkflowNodeExecution database models
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
WorkflowNodeExecution.tenant_id == self._tenant_id,
WorkflowNodeExecution.triggered_from == triggered_from,
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
WorkflowNodeExecutionModel.triggered_from == triggered_from,
)

if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)

# Apply ordering if provided
if order_config and order_config.order_by:
order_columns: list[UnaryExpression] = []
for field in order_config.order_by:
column = getattr(WorkflowNodeExecution, field, None)
column = getattr(WorkflowNodeExecutionModel, field, None)
if not column:
continue
if order_config.order_direction == "desc":
@@ -310,7 +309,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
) -> Sequence[NodeExecution]:
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.

@@ -337,7 +336,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)

return domain_models

def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running NodeExecution instances for a specific workflow run.

@@ -351,15 +350,15 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
A list of running NodeExecution instances
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
WorkflowNodeExecution.tenant_id == self._tenant_id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING,
WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)

if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)

db_models = session.scalars(stmt).all()
domain_models = []
@@ -384,10 +383,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
It also clears the in-memory cache.
"""
with self._session_factory() as session:
stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id)

if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)

result = session.execute(stmt)
session.commit()

+ 1
- 1
api/core/tools/custom_tool/tool.py View File

@@ -168,7 +168,7 @@ class ApiTool(Tool):
cookies[parameter["name"]] = value

elif parameter["in"] == "header":
headers[parameter["name"]] = value
headers[parameter["name"]] = str(value)

# check if there is a request body and handle it
if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None:

+ 0
- 1
api/core/tools/entities/tool_entities.py View File

@@ -279,7 +279,6 @@ class ToolParameter(PluginParameter):
:param options: the options of the parameter
"""
# convert options to ToolParameterOption
# FIXME fix the type error
if options:
option_objs = [
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))

+ 20
- 19
api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py View File

@@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -107,7 +108,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
else:
document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list = []
context_list: list[RetrievalSourceMetadata] = []
resource_number = 1
for segment in sorted_segments:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
@@ -121,28 +122,28 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
.first()
)
if dataset and document:
source = {
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None),
"doc_metadata": document.doc_metadata,
}
source = RetrievalSourceMetadata(
position=resource_number,
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from=self.retriever_from,
score=document_score_list.get(segment.index_node_id, None),
doc_metadata=document.doc_metadata,
)

if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
source.hit_count = segment.hit_count
source.word_count = segment.word_count
source.segment_position = segment.position
source.index_node_hash = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
source.content = segment.content
context_list.append(source)
resource_number += 1


+ 37
- 36
api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py View File

@@ -4,6 +4,7 @@ from pydantic import BaseModel, Field

from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@@ -14,7 +15,7 @@ from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService

default_retrieval_model = {
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@@ -79,7 +80,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
else:
document_ids_filter = None
if dataset.provider == "external":
results = []
results: list[RetrievalDocument] = []
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
@@ -100,21 +101,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
document.metadata["dataset_name"] = dataset.name
results.append(document)
# deal with external documents
context_list = []
context_list: list[RetrievalSourceMetadata] = []
for position, item in enumerate(results, start=1):
if item.metadata is not None:
source = {
"position": position,
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": self.retriever_from,
"score": item.metadata.get("score"),
"title": item.metadata.get("title"),
"content": item.page_content,
}
source = RetrievalSourceMetadata(
position=position,
dataset_id=item.metadata.get("dataset_id"),
dataset_name=item.metadata.get("dataset_name"),
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
document_name=item.metadata.get("title"),
data_source_type="external",
retriever_from=self.retriever_from,
score=item.metadata.get("score"),
title=item.metadata.get("title"),
content=item.page_content,
)
context_list.append(source)
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
@@ -125,7 +126,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
return ""
# get retrieval model , if the model is not setting , using default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
retrieval_resource_list = []
retrieval_resource_list: list[RetrievalSourceMetadata] = []
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
@@ -163,7 +164,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
for item in documents:
if item.metadata is not None and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
document_context_list: list[DocumentContext] = []
records = RetrievalService.format_retrieval_documents(documents)
if records:
for record in records:
@@ -197,37 +198,37 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
.first()
)
if dataset and document:
source = {
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id, # type: ignore
"document_name": document.name, # type: ignore
"data_source_type": document.data_source_type, # type: ignore
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": record.score or 0.0,
"doc_metadata": document.doc_metadata, # type: ignore
}
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id, # type: ignore
document_name=document.name, # type: ignore
data_source_type=document.data_source_type, # type: ignore
segment_id=segment.id,
retriever_from=self.retriever_from,
score=record.score or 0.0,
doc_metadata=document.doc_metadata, # type: ignore
)

if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
source.hit_count = segment.hit_count
source.word_count = segment.word_count
source.segment_position = segment.position
source.index_node_hash = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
source.content = segment.content
retrieval_resource_list.append(source)

if self.return_resource and retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list,
key=lambda x: x.get("score") or 0.0,
key=lambda x: x.score or 0.0,
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
item["position"] = position # type: ignore
item.position = position # type: ignore
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:

+ 0
- 1
api/core/tools/utils/message_transformer.py View File

@@ -66,7 +66,6 @@ class ToolFileMessageTransformer:
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
raise ValueError("unexpected message type")

# FIXME: should do a type check here.
assert isinstance(message.message.blob, bytes)
tool_file_manager = ToolFileManager()
file = tool_file_manager.create_file_by_raw(

+ 7
- 0
api/core/tools/utils/parser.py View File

@@ -55,6 +55,13 @@ class ApiBasedToolSchemaParser:
# convert parameters
parameters = []
if "parameters" in interface["operation"]:
for i, parameter in enumerate(interface["operation"]["parameters"]):
if "$ref" in parameter:
root = openapi
reference = parameter["$ref"].split("/")[1:]
for ref in reference:
root = root[ref]
interface["operation"]["parameters"][i] = root
for parameter in interface["operation"]["parameters"]:
tool_parameter = ToolParameter(
name=parameter["name"],

+ 2
- 29
api/core/workflow/entities/node_entities.py View File

@@ -1,37 +1,10 @@
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Optional

from pydantic import BaseModel

from core.model_runtime.entities.llm_entities import LLMUsage
from models.workflow import WorkflowNodeExecutionStatus


class NodeRunMetadataKey(StrEnum):
"""
Node Run Metadata Key.
"""

TOTAL_TOKENS = "total_tokens"
TOTAL_PRICE = "total_price"
CURRENCY = "currency"
TOOL_INFO = "tool_info"
DATASOURCE_INFO = "datasource_info"
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
LOOP_ID = "loop_id"
LOOP_INDEX = "loop_index"
PARALLEL_ID = "parallel_id"
PARALLEL_START_NODE_ID = "parallel_start_node_id"
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus


class NodeRunResult(BaseModel):
@@ -44,7 +17,7 @@ class NodeRunResult(BaseModel):
inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[Mapping[str, Any]] = None # process data
outputs: Optional[Mapping[str, Any]] = None # node outputs
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # node metadata
llm_usage: Optional[LLMUsage] = None # llm usage

edge_source_handle: Optional[str] = None # source handle id of node with multiple branches

api/core/workflow/entities/workflow_execution_entities.py → api/core/workflow/entities/workflow_execution.py View File

@@ -37,12 +37,10 @@ class WorkflowExecution(BaseModel):
user, tenant, and app attributes.
"""

id: str = Field(...)
id_: str = Field(...)
workflow_id: str = Field(...)
workflow_version: str = Field(...)
sequence_number: int = Field(...)

type: WorkflowType = Field(...)
workflow_type: WorkflowType = Field(...)
graph: Mapping[str, Any] = Field(...)

inputs: Mapping[str, Any] = Field(...)
@@ -70,20 +68,18 @@ class WorkflowExecution(BaseModel):
def new(
cls,
*,
id: str,
id_: str,
workflow_id: str,
sequence_number: int,
type: WorkflowType,
workflow_type: WorkflowType,
workflow_version: str,
graph: Mapping[str, Any],
inputs: Mapping[str, Any],
started_at: datetime,
) -> "WorkflowExecution":
return WorkflowExecution(
id=id,
id_=id_,
workflow_id=workflow_id,
sequence_number=sequence_number,
type=type,
workflow_type=workflow_type,
workflow_version=workflow_version,
graph=graph,
inputs=inputs,

api/core/workflow/entities/node_execution_entities.py → api/core/workflow/entities/workflow_node_execution.py View File

@@ -13,11 +13,35 @@ from typing import Any, Optional

from pydantic import BaseModel, Field

from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.nodes.enums import NodeType


class NodeExecutionStatus(StrEnum):
class WorkflowNodeExecutionMetadataKey(StrEnum):
"""
Node Run Metadata Key.
"""

TOTAL_TOKENS = "total_tokens"
TOTAL_PRICE = "total_price"
CURRENCY = "currency"
TOOL_INFO = "tool_info"
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
LOOP_ID = "loop_id"
LOOP_INDEX = "loop_index"
PARALLEL_ID = "parallel_id"
PARALLEL_START_NODE_ID = "parallel_start_node_id"
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output


class WorkflowNodeExecutionStatus(StrEnum):
"""
Node Execution Status Enum.
"""
@@ -29,7 +53,7 @@ class NodeExecutionStatus(StrEnum):
RETRY = "retry"


class NodeExecution(BaseModel):
class WorkflowNodeExecution(BaseModel):
"""
Domain model for workflow node execution.

@@ -46,7 +70,7 @@ class NodeExecution(BaseModel):
id: str # Unique identifier for this execution record
node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing
workflow_id: str # ID of the workflow this node belongs to
workflow_run_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)

# Execution positioning and flow
index: int # Sequence number for ordering in trace visualization
@@ -61,12 +85,12 @@ class NodeExecution(BaseModel):
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node

# Execution state
status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status
error: Optional[str] = None # Error message if execution failed
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds

# Additional metadata
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)

# Timing information
created_at: datetime # When execution started
@@ -77,7 +101,7 @@ class NodeExecution(BaseModel):
inputs: Optional[Mapping[str, Any]] = None,
process_data: Optional[Mapping[str, Any]] = None,
outputs: Optional[Mapping[str, Any]] = None,
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None,
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None,
) -> None:
"""
Update the model from mappings.
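
Note: `WorkflowNodeExecutionMetadataKey` is a `StrEnum`, so its members behave as plain strings in mappings and serialized payloads. A small sketch of that property (assumes Python 3.11+ `StrEnum`; illustrative only):

```python
from enum import StrEnum


class WorkflowNodeExecutionMetadataKey(StrEnum):
    TOTAL_TOKENS = "total_tokens"


metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 42}

# StrEnum members hash and compare like their string values, so enum-keyed
# mappings stay compatible with code that looks entries up by plain string:
assert metadata["total_tokens"] == 42
assert str(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) == "total_tokens"
```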

+ 1
- 1
api/core/workflow/enums.py

@@ -13,7 +13,7 @@ class SystemVariableKey(StrEnum):
DIALOGUE_COUNT = "dialogue_count"
APP_ID = "app_id"
WORKFLOW_ID = "workflow_id"
WORKFLOW_RUN_ID = "workflow_run_id"
WORKFLOW_EXECUTION_ID = "workflow_run_id"
# RAG Pipeline
DOCUMENT_ID = "document_id"
BATCH = "batch"
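
Note: the member is renamed while its string value stays `"workflow_run_id"`, so anything already persisted or transmitted under the old key keeps resolving. A sketch of that round-trip (illustrative only):

```python
from enum import StrEnum


class SystemVariableKey(StrEnum):
    # Renamed member, unchanged value: old payloads still map to the new name
    WORKFLOW_EXECUTION_ID = "workflow_run_id"


assert SystemVariableKey("workflow_run_id") is SystemVariableKey.WORKFLOW_EXECUTION_ID
```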

+ 3
- 2
api/core/workflow/graph_engine/entities/event.py

@@ -1,9 +1,10 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any, Optional

from pydantic import BaseModel, Field

from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType
@@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent):


class NodeRunRetrieverResourceEvent(BaseNodeEvent):
retriever_resources: list[dict] = Field(..., description="retriever resources")
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
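
Note: replacing `list[dict]` with `Sequence[RetrievalSourceMetadata]` makes Pydantic validate each citation entry instead of accepting arbitrary dicts. A hedged sketch with a trimmed stand-in model (the real one lives in `core.rag.entities.citation_metadata`):

```python
from collections.abc import Sequence

from pydantic import BaseModel, Field


class RetrievalSourceMetadata(BaseModel):  # simplified stand-in
    dataset_id: str | None = None
    score: float | None = None


class NodeRunRetrieverResourceEvent(BaseModel):
    # Typed models reject malformed entries that list[dict] would let through
    retriever_resources: Sequence[RetrievalSourceMetadata] = Field(...)
    context: str = Field(...)


event = NodeRunRetrieverResourceEvent(
    retriever_resources=[{"dataset_id": "ds-1", "score": 0.87}],  # coerced into the model
    context="retrieved context",
)
```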



+ 1
- 1
api/core/workflow/graph_engine/entities/runtime_route_state.py

@@ -6,7 +6,7 @@ from typing import Optional
from pydantic import BaseModel, Field

from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus


class RouteNodeState(BaseModel):

+ 17
- 16
api/core/workflow/graph_engine/graph_engine.py

@@ -14,8 +14,9 @@ from flask import Flask, current_app, has_request_context
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
BaseAgentEvent,
@@ -52,9 +53,8 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
from models.workflow import WorkflowType

logger = logging.getLogger(__name__)

@@ -606,8 +606,6 @@ class GraphEngine:
error=str(e),
)
)
finally:
db.session.remove()

def _run_node(
self,
@@ -645,7 +643,6 @@ class GraphEngine:
agent_strategy=agent_strategy,
)

db.session.close()
max_retries = node_instance.node_data.retry_config.max_retries
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
retries = 0
@@ -759,10 +756,12 @@ class GraphEngine:
and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
if run_result.metadata and run_result.metadata.get(
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS
):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
run_result.metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)

if run_result.llm_usage:
@@ -785,13 +784,17 @@ class GraphEngine:

if parallel_id and parallel_start_node_id:
metadata_dict = dict(run_result.metadata)
metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
metadata_dict[WorkflowNodeExecutionMetadataKey.PARALLEL_ID] = parallel_id
metadata_dict[WorkflowNodeExecutionMetadataKey.PARALLEL_START_NODE_ID] = (
parallel_start_node_id
)
if parent_parallel_id and parent_parallel_start_node_id:
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
metadata_dict[WorkflowNodeExecutionMetadataKey.PARENT_PARALLEL_ID] = (
parent_parallel_id
)
metadata_dict[
WorkflowNodeExecutionMetadataKey.PARENT_PARALLEL_START_NODE_ID
] = parent_parallel_start_node_id
run_result.metadata = metadata_dict

yield NodeRunSucceededEvent(
@@ -856,8 +859,6 @@ class GraphEngine:
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()

def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""
@@ -923,7 +924,7 @@ class GraphEngine:
"error": error_result.error,
"inputs": error_result.inputs,
"metadata": {
NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
},
}
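
Note: the engine copies the node's metadata mapping before annotating it with parallel-branch context. A compact sketch of that copy-then-annotate pattern (enum keys abbreviated to strings here for brevity):

```python
from collections.abc import Mapping
from typing import Any


def annotate_parallel_metadata(
    metadata: Mapping[Any, Any],
    *,
    parallel_id: str | None,
    parallel_start_node_id: str | None,
) -> Mapping[Any, Any]:
    # Copy first: run_result.metadata is a Mapping that may be shared or immutable
    enriched = dict(metadata)
    if parallel_id and parallel_start_node_id:
        enriched["parallel_id"] = parallel_id
        enriched["parallel_start_node_id"] = parallel_start_node_id
    return enriched
```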


+ 9
- 9
api/core/workflow/nodes/agent/agent_node.py

@@ -2,6 +2,9 @@ import json
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -15,6 +18,7 @@ from core.tools.tool_manager import ToolManager
from core.variables.segments import StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base.entities import BaseNodeData
@@ -25,7 +29,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories.agent_factory import get_plugin_agent_strategy
from models.model import Conversation
from models.workflow import WorkflowNodeExecutionStatus


class AgentNode(ToolNode):
@@ -320,15 +323,12 @@ class AgentNode(ToolNode):
return None
conversation_id = conversation_id_variable.value

# get conversation
conversation = (
db.session.query(Conversation)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
.first()
)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)

if not conversation:
return None
if not conversation:
return None

memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
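
Note: the conversation lookup now uses a short-lived 2.0-style session instead of the request-global `db.session`. A minimal sketch of the pattern, assuming an `engine` and a mapped `Conversation` class as in the surrounding code:

```python
from sqlalchemy import select
from sqlalchemy.orm import Session


def fetch_conversation(engine, Conversation, app_id: str, conversation_id: str):
    # expire_on_commit=False keeps the loaded object usable after the
    # session closes, since callers read its attributes outside this block
    with Session(engine, expire_on_commit=False) as session:
        stmt = select(Conversation).where(
            Conversation.app_id == app_id,
            Conversation.id == conversation_id,
        )
        return session.scalar(stmt)  # first result, or None if no match
```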


+ 1
- 1
api/core/workflow/nodes/answer/answer_node.py

@@ -3,6 +3,7 @@ from typing import Any, cast

from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
@@ -13,7 +14,6 @@ from core.workflow.nodes.answer.entities import (
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus


class AnswerNode(BaseNode[AnswerNodeData]):

+ 1
- 1
api/core/workflow/nodes/base/node.py

@@ -4,9 +4,9 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from models.workflow import WorkflowNodeExecutionStatus

from .entities import BaseNodeData


+ 1
- 1
api/core/workflow/nodes/code/code_node.py

@@ -8,10 +8,10 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus

from .exc import (
CodeNodeError,

+ 1
- 1
api/core/workflow/nodes/document_extractor/node.py

@@ -26,9 +26,9 @@ from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus

from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError

+ 1
- 1
api/core/workflow/nodes/end/end_node.py

@@ -1,8 +1,8 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus


class EndNode(BaseNode[EndNodeData]):

+ 4
- 2
api/core/workflow/nodes/event/event.py

@@ -1,10 +1,12 @@
from collections.abc import Sequence
from datetime import datetime

from pydantic import BaseModel, Field

from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus


class RunCompletedEvent(BaseModel):
@@ -17,7 +19,7 @@ class RunStreamChunkEvent(BaseModel):


class RunRetrieverResourceEvent(BaseModel):
retriever_resources: list[dict] = Field(..., description="retriever resources")
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")



+ 1
- 1
api/core/workflow/nodes/http_request/node.py

@@ -8,12 +8,12 @@ from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.utils import variable_template_parser
from factories import file_factory
from models.workflow import WorkflowNodeExecutionStatus

from .entities import (
HttpRequestNodeData,

+ 1
- 1
api/core/workflow/nodes/if_else/if_else_node.py

@@ -4,12 +4,12 @@ from typing_extensions import deprecated

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor
from models.workflow import WorkflowNodeExecutionStatus


class IfElseNode(BaseNode[IfElseNodeData]):

+ 7
- 8
api/core/workflow/nodes/iteration/iteration_node.py

@@ -12,10 +12,10 @@ from flask import Flask, current_app, has_request_context
from configs import dify_config
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeRunResult,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
@@ -37,7 +37,6 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from models.workflow import WorkflowNodeExecutionStatus

from .exc import (
InvalidIteratorValueError,
@@ -249,8 +248,8 @@ class IterationNode(BaseNode[IterationNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs},
metadata={
NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
},
)
)
@@ -361,16 +360,16 @@ class IterationNode(BaseNode[IterationNodeData]):
event.parallel_mode_run_id = parallel_mode_run_id

iter_metadata = {
NodeRunMetadataKey.ITERATION_ID: self.node_id,
NodeRunMetadataKey.ITERATION_INDEX: iter_run_index,
WorkflowNodeExecutionMetadataKey.ITERATION_ID: self.node_id,
WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index,
}
if parallel_mode_run_id:
# for parallel, the specific branch ID is more important than the sequential index
iter_metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
iter_metadata[WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id

if event.route_node_state.node_run_result:
current_metadata = event.route_node_state.node_run_result.metadata or {}
if NodeRunMetadataKey.ITERATION_ID not in current_metadata:
if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata:
event.route_node_state.node_run_result.metadata = {**current_metadata, **iter_metadata}

return event

+ 1
- 1
api/core/workflow/nodes/iteration/iteration_start_node.py

@@ -1,8 +1,8 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.iteration.entities import IterationStartNodeData
from models.workflow import WorkflowNodeExecutionStatus


class IterationStartNode(BaseNode[IterationStartNodeData]):

+ 16
- 19
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

@@ -8,6 +8,7 @@ from typing import Any, Optional, cast

from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import Session

from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -24,6 +25,7 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
@@ -41,7 +43,6 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from models.workflow import WorkflowNodeExecutionStatus
from services.feature_service import FeatureService

from .entities import KnowledgeRetrievalNodeData, ModelConfig
@@ -95,14 +96,15 @@ class KnowledgeRetrievalNode(LLMNode):
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
db.session.add(rate_limit_log)
db.session.commit()
with Session(db.engine) as session:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
session.add(rate_limit_log)
session.commit()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
@@ -173,7 +175,9 @@ class KnowledgeRetrievalNode(LLMNode):
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model) # type: ignore
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model)
# check model is support tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -424,7 +428,7 @@ class KnowledgeRetrievalNode(LLMNode):
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config) # type: ignore
model_instance, model_config = self.get_model_config(metadata_model_config)
# fetch prompt messages
prompt_template = self._get_prompt_template(
node_data=node_data,
@@ -550,14 +554,7 @@ class KnowledgeRetrievalNode(LLMNode):
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
return variable_mapping

def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: # type: ignore
"""
Fetch model config
:param model: model
:return:
"""
if model is None:
raise ValueError("model is required")
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model_name = model.name
provider_name = model.provider
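
Note: `get_model_config` drops the old `# type: ignore` in favor of an explicit guard at the call site, which also narrows the `Optional` type for the checker. A small illustration of the idiom (names hypothetical):

```python
from typing import Optional


def run_single_retrieval(single_retrieval_config: Optional[dict]) -> dict:
    # Raising on None narrows the type: past this line the type checker
    # (and the reader) knows the config is present, no suppression needed
    if single_retrieval_config is None:
        raise ValueError("single_retrieval_config is required")
    return single_retrieval_config
```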


+ 1
- 1
api/core/workflow/nodes/list_operator/node.py

@@ -4,9 +4,9 @@ from typing import Any, Literal, Union
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus

from .entities import ListOperatorNodeData
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError

+ 51
- 50
api/core/workflow/nodes/llm/node.py

@@ -7,6 +7,8 @@ from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, cast

import json_repair
from sqlalchemy import select, update
from sqlalchemy.orm import Session

from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -43,6 +45,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.variables import (
ArrayAnySegment,
ArrayFileSegment,
@@ -53,9 +56,10 @@ from core.variables import (
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode
@@ -77,7 +81,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.workflow import WorkflowNodeExecutionStatus

from .entities import (
LLMNodeChatModelMessage,
@@ -267,9 +270,9 @@ class LLMNode(BaseNode[LLMNodeData]):
process_data=process_data,
outputs=outputs,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
@@ -302,8 +305,6 @@ class LLMNode(BaseNode[LLMNodeData]):
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
) -> Generator[NodeEvent, None, None]:
db.session.close()

invoke_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters=node_data_model.completion_params,
@@ -474,7 +475,7 @@ class LLMNode(BaseNode[LLMNodeData]):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
elif isinstance(context_value_variable, ArraySegment):
context_str = ""
original_retriever_resource = []
original_retriever_resource: list[RetrievalSourceMetadata] = []
for item in context_value_variable.value:
if isinstance(item, str):
context_str += item + "\n"
@@ -492,7 +493,7 @@ class LLMNode(BaseNode[LLMNodeData]):
retriever_resources=original_retriever_resource, context=context_str.strip()
)

def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
def _convert_to_original_retriever_resource(self, context_dict: dict):
if (
"metadata" in context_dict
and "_source" in context_dict["metadata"]
@@ -500,24 +501,24 @@ class LLMNode(BaseNode[LLMNodeData]):
):
metadata = context_dict.get("metadata", {})

source = {
"position": metadata.get("position"),
"dataset_id": metadata.get("dataset_id"),
"dataset_name": metadata.get("dataset_name"),
"document_id": metadata.get("document_id"),
"document_name": metadata.get("document_name"),
"data_source_type": metadata.get("data_source_type"),
"segment_id": metadata.get("segment_id"),
"retriever_from": metadata.get("retriever_from"),
"score": metadata.get("score"),
"hit_count": metadata.get("segment_hit_count"),
"word_count": metadata.get("segment_word_count"),
"segment_position": metadata.get("segment_position"),
"index_node_hash": metadata.get("segment_index_node_hash"),
"content": context_dict.get("content"),
"page": metadata.get("page"),
"doc_metadata": metadata.get("doc_metadata"),
}
source = RetrievalSourceMetadata(
position=metadata.get("position"),
dataset_id=metadata.get("dataset_id"),
dataset_name=metadata.get("dataset_name"),
document_id=metadata.get("document_id"),
document_name=metadata.get("document_name"),
data_source_type=metadata.get("data_source_type"),
segment_id=metadata.get("segment_id"),
retriever_from=metadata.get("retriever_from"),
score=metadata.get("score"),
hit_count=metadata.get("segment_hit_count"),
word_count=metadata.get("segment_word_count"),
segment_position=metadata.get("segment_position"),
index_node_hash=metadata.get("segment_index_node_hash"),
content=context_dict.get("content"),
page=metadata.get("page"),
doc_metadata=metadata.get("doc_metadata"),
)

return source

@@ -602,15 +603,11 @@ class LLMNode(BaseNode[LLMNodeData]):
return None
conversation_id = conversation_id_variable.value

# get conversation
conversation = (
db.session.query(Conversation)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
.first()
)

if not conversation:
return None
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None

memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

@@ -846,20 +843,24 @@ class LLMNode(BaseNode[LLMNodeData]):
used_quota = 1

if used_quota is not None and system_configuration.current_quota_type is not None:
db.session.query(Provider).filter(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
).update(
{
"quota_used": Provider.quota_used + used_quota,
"last_used": datetime.now(tz=UTC).replace(tzinfo=None),
}
)
db.session.commit()
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=datetime.now(tz=UTC).replace(tzinfo=None),
)
)
session.execute(stmt)
session.commit()

@classmethod
def _extract_variable_selector_to_variable_mapping(
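
Note: the quota deduction moves from the legacy `Query.update()` to an explicit 2.0-style `update()` statement executed in its own session. A sketch of the pattern, with `Provider` and `engine` assumed as in the surrounding code:

```python
from datetime import UTC, datetime

from sqlalchemy import update
from sqlalchemy.orm import Session


def consume_quota(engine, Provider, tenant_id: str, used: int) -> None:
    with Session(engine) as session:
        stmt = (
            update(Provider)
            .where(
                Provider.tenant_id == tenant_id,
                Provider.quota_limit > Provider.quota_used,  # guard against overdraw
            )
            .values(
                quota_used=Provider.quota_used + used,  # server-side increment, no read-modify-write race
                last_used=datetime.now(tz=UTC).replace(tzinfo=None),
            )
        )
        session.execute(stmt)
        session.commit()
```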

+ 1
- 1
api/core/workflow/nodes/loop/loop_end_node.py

@@ -1,8 +1,8 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.loop.entities import LoopEndNodeData
from models.workflow import WorkflowNodeExecutionStatus


class LoopEndNode(BaseNode[LoopEndNodeData]):

+ 28
- 20
api/core/workflow/nodes/loop/loop_node.py

@@ -15,7 +15,8 @@ from core.variables import (
SegmentType,
StringSegment,
)
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
@@ -37,7 +38,6 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
from models.workflow import WorkflowNodeExecutionStatus

if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
@@ -187,10 +187,10 @@ class LoopNode(BaseNode[LoopNodeData]):
outputs=self.node_data.outputs,
steps=loop_count,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "loop_break" if check_break_result else "loop_completed",
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
)

@@ -198,9 +198,9 @@ class LoopNode(BaseNode[LoopNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
outputs=self.node_data.outputs,
inputs=inputs,
@@ -221,8 +221,8 @@ class LoopNode(BaseNode[LoopNodeData]):
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
error=str(e),
)
@@ -232,9 +232,9 @@ class LoopNode(BaseNode[LoopNodeData]):
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
)
)
@@ -322,7 +322,9 @@ class LoopNode(BaseNode[LoopNodeData]):
inputs=inputs,
steps=current_index,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
graph_engine.graph_runtime_state.total_tokens
),
"completed_reason": "error",
},
error=event.error,
@@ -331,7 +333,11 @@ class LoopNode(BaseNode[LoopNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
graph_engine.graph_runtime_state.total_tokens
)
},
)
)
return {"check_break_result": True}
@@ -347,7 +353,7 @@ class LoopNode(BaseNode[LoopNodeData]):
inputs=inputs,
steps=current_index,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
},
error=event.error,
@@ -356,7 +362,9 @@ class LoopNode(BaseNode[LoopNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
},
)
)
return {"check_break_result": True}
@@ -411,11 +419,11 @@ class LoopNode(BaseNode[LoopNodeData]):
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.LOOP_ID not in metadata:
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in metadata:
metadata = {
**metadata,
NodeRunMetadataKey.LOOP_ID: self.node_id,
NodeRunMetadataKey.LOOP_INDEX: iter_run_index,
WorkflowNodeExecutionMetadataKey.LOOP_ID: self.node_id,
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: iter_run_index,
}
event.route_node_state.node_run_result.metadata = metadata
return event

+ 1
- 1
api/core/workflow/nodes/loop/loop_start_node.py

@@ -1,8 +1,8 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.loop.entities import LoopStartNodeData
from models.workflow import WorkflowNodeExecutionStatus


class LoopStartNode(BaseNode[LoopStartNodeData]):

+ 5
- 9
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py

@@ -25,13 +25,12 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import LLMNode, ModelConfig
from core.workflow.utils import variable_template_parser
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus

from .entities import ParameterExtractorNodeData
from .exc import (
@@ -244,9 +243,9 @@ class ParameterExtractorNode(LLMNode):
process_data=process_data,
outputs={"__is_success": 1 if not error else 0, "__reason": error, **result},
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
@@ -259,8 +258,6 @@ class ParameterExtractorNode(LLMNode):
tools: list[PromptMessageTool],
stop: list[str],
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
db.session.close()

invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=node_data_model.completion_params,
@@ -816,7 +813,6 @@ class ParameterExtractorNode(LLMNode):
:param node_data: node data
:return:
"""
# FIXME: fix the type error later
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}

if node_data.instruction:

+ 13
- 9
api/core/workflow/nodes/question_classifier/question_classifier_node.py

@@ -10,7 +10,8 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import ModelInvokeCompletedEvent
from core.workflow.nodes.llm import (
@@ -20,7 +21,6 @@ from core.workflow.nodes.llm import (
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.workflow import WorkflowNodeExecutionStatus

from .entities import QuestionClassifierNodeData
from .exc import InvalidModelTypeError
@@ -79,9 +79,13 @@ class QuestionClassifierNode(LLMNode):
memory=memory,
max_token_limit=rest_token,
)
# Some models (e.g. Gemma, Mistral) enforce strict role alternation (user/assistant/user/assistant...).
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
# two consecutive user prompts would be generated, causing a model error.
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template,
sys_query=query,
sys_query="",
memory=memory,
model_config=model_config,
sys_files=files,
@@ -142,9 +146,9 @@ class QuestionClassifierNode(LLMNode):
outputs=outputs,
edge_source_handle=category_id,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
@@ -154,9 +158,9 @@ class QuestionClassifierNode(LLMNode):
inputs=variables,
error=str(e),
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)

+ 1
- 1
api/core/workflow/nodes/start/start_node.py

@@ -1,9 +1,9 @@
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus


class StartNode(BaseNode[StartNodeData]):

+ 1
- 1
api/core/workflow/nodes/template_transform/template_transform_node.py

@@ -4,10 +4,10 @@ from typing import Any, Optional

from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from models.workflow import WorkflowNodeExecutionStatus

MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))


+ 9
- 9
api/core/workflow/nodes/tool/tool_node.py

@@ -14,8 +14,9 @@ from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.base import BaseNode
@@ -25,7 +26,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from models.workflow import WorkflowNodeExecutionStatus
from services.tools.builtin_tools_manage_service import BuiltinToolManageService

from .entities import ToolNodeData
@@ -70,7 +70,7 @@ class ToolNode(BaseNode[ToolNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to get tool runtime: {str(e)}",
error_type=type(e).__name__,
)
@@ -110,7 +110,7 @@ class ToolNode(BaseNode[ToolNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool: {str(e)}",
error_type=type(e).__name__,
)
@@ -125,7 +125,7 @@ class ToolNode(BaseNode[ToolNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to transform tool message: {str(e)}",
error_type=type(e).__name__,
)
@@ -201,7 +201,7 @@ class ToolNode(BaseNode[ToolNodeData]):
json: list[dict] = []

agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}

variables: dict[str, Any] = {}

@@ -274,7 +274,7 @@ class ToolNode(BaseNode[ToolNodeData]):
agent_execution_metadata = {
key: value
for key, value in msg_metadata.items()
if key in NodeRunMetadataKey.__members__.values()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
@@ -366,8 +366,8 @@ class ToolNode(BaseNode[ToolNodeData]):
outputs={"text": text, "files": files, "json": json, **variables},
metadata={
**agent_execution_metadata,
NodeRunMetadataKey.TOOL_INFO: tool_info,
NodeRunMetadataKey.AGENT_LOG: agent_logs,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
)
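
Note: filtering agent metadata against `WorkflowNodeExecutionMetadataKey.__members__.values()` keeps only keys the enum knows about; plain strings compare equal to `StrEnum` members, so the check works on untyped tool output. A small sketch (illustrative values):

```python
from enum import StrEnum
from typing import Any


class WorkflowNodeExecutionMetadataKey(StrEnum):
    TOTAL_TOKENS = "total_tokens"
    TOOL_INFO = "tool_info"


def keep_known_metadata(msg_metadata: dict[str, Any]) -> dict[str, Any]:
    # `in __members__.values()` compares each plain-string key against the
    # enum members, which succeeds because StrEnum members are strings
    return {
        key: value
        for key, value in msg_metadata.items()
        if key in WorkflowNodeExecutionMetadataKey.__members__.values()
    }


assert keep_known_metadata({"total_tokens": 10, "unknown": 1}) == {"total_tokens": 10}
```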

+ 3
- 2
api/core/workflow/nodes/variable_aggregator/entities.py

@@ -1,7 +1,8 @@
from typing import Literal, Optional
from typing import Optional

from pydantic import BaseModel

from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData


@@ -17,7 +18,7 @@ class AdvancedSettings(BaseModel):
Group.
"""

output_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
output_type: SegmentType
variables: list[list[str]]
group_name: str
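
Note: swapping the inline `Literal[...]` for the shared `SegmentType` keeps the accepted output types defined in one place. A sketch with a trimmed stand-in enum (the real one lives in `core.variables.types`):

```python
from enum import StrEnum

from pydantic import BaseModel


class SegmentType(StrEnum):  # trimmed stand-in
    STRING = "string"
    NUMBER = "number"
    ARRAY_STRING = "array[string]"


class AdvancedSettings(BaseModel):
    # One shared enum instead of a Literal union repeated at every use site
    output_type: SegmentType
    group_name: str


settings = AdvancedSettings(output_type="array[string]", group_name="g1")
assert settings.output_type is SegmentType.ARRAY_STRING  # validated by value
```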


+ 1
- 1
api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py

@@ -1,8 +1,8 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
from models.workflow import WorkflowNodeExecutionStatus


class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):

+ 1
- 1
api/core/workflow/nodes/variable_assigner/v1/node.py

@@ -1,11 +1,11 @@
from core.variables import SegmentType, Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory
from models.workflow import WorkflowNodeExecutionStatus

from .node_data import VariableAssignerData, WriteMode


+ 1
- 1
api/core/workflow/nodes/variable_assigner/v2/node.py

@@ -6,11 +6,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from models.workflow import WorkflowNodeExecutionStatus

from . import helpers
from .constants import EMPTY_VALUE_MAPPING

api/core/workflow/repository/__init__.py → api/core/workflow/repositories/__init__.py

@@ -6,7 +6,7 @@ for accessing and manipulating data, regardless of the underlying
storage mechanism.
"""

from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository

__all__ = [
"OrderConfig",

api/core/workflow/repository/workflow_execution_repository.py → api/core/workflow/repositories/workflow_execution_repository.py

@@ -1,6 +1,6 @@
from typing import Optional, Protocol

from core.workflow.entities.workflow_execution_entities import WorkflowExecution
from core.workflow.entities.workflow_execution import WorkflowExecution


class WorkflowExecutionRepository(Protocol):

api/core/workflow/repository/workflow_node_execution_repository.py → api/core/workflow/repositories/workflow_node_execution_repository.py

@@ -2,7 +2,7 @@ from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, Optional, Protocol

from core.workflow.entities.node_execution_entities import NodeExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution


@dataclass
@@ -26,7 +26,7 @@ class WorkflowNodeExecutionRepository(Protocol):
application domains or deployment scenarios.
"""

def save(self, execution: NodeExecution) -> None:
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update a NodeExecution instance.

@@ -39,7 +39,7 @@ class WorkflowNodeExecutionRepository(Protocol):
"""
...

def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a NodeExecution by its node_execution_id.

@@ -55,7 +55,7 @@ class WorkflowNodeExecutionRepository(Protocol):
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[NodeExecution]:
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.

@@ -70,7 +70,7 @@ class WorkflowNodeExecutionRepository(Protocol):
"""
...

def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running NodeExecution instances for a specific workflow run.
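
Note: these repositories are `typing.Protocol`s, so storage backends satisfy them structurally, without inheriting from them. A compact sketch of an implementation type-checking against the protocol (the in-memory backend is hypothetical):

```python
from typing import Optional, Protocol


class WorkflowNodeExecution:  # placeholder for the domain model
    def __init__(self, node_execution_id: str) -> None:
        self.node_execution_id = node_execution_id


class WorkflowNodeExecutionRepository(Protocol):
    def save(self, execution: WorkflowNodeExecution) -> None: ...
    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: ...


class InMemoryRepository:
    """Satisfies the protocol by shape alone; no inheritance needed."""

    def __init__(self) -> None:
        self._store: dict[str, WorkflowNodeExecution] = {}

    def save(self, execution: WorkflowNodeExecution) -> None:
        self._store[execution.node_execution_id] = execution

    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
        return self._store.get(node_execution_id)


repo: WorkflowNodeExecutionRepository = InMemoryRepository()  # structural match
```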


+ 48
- 63
api/core/workflow/workflow_cycle_manager.py

@@ -1,11 +1,9 @@
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any, Optional, Union
from uuid import uuid4

from sqlalchemy import func, select
from sqlalchemy.orm import Session

from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
@@ -19,21 +17,24 @@ from core.app.entities.queue_entities import (
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_execution_entities import (
NodeExecution,
NodeExecutionStatus,
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_entry import WorkflowEntry
from models import (
Workflow,
WorkflowRun,
WorkflowRunStatus,
)


@dataclass
class CycleManagerWorkflowInfo:
workflow_id: str
workflow_type: WorkflowType
version: str
graph_data: Mapping[str, Any]


class WorkflowCycleManager:
@@ -42,32 +43,17 @@ class WorkflowCycleManager:
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
workflow_info: CycleManagerWorkflowInfo,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
self._workflow_info = workflow_info
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository

def handle_workflow_run_start(
self,
*,
session: Session,
workflow_id: str,
) -> WorkflowExecution:
workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
workflow = session.scalar(workflow_stmt)
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")

max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where(
WorkflowRun.tenant_id == workflow.tenant_id,
WorkflowRun.app_id == workflow.app_id,
)
max_sequence = session.scalar(max_sequence_stmt) or 0
new_sequence_number = max_sequence + 1

def handle_workflow_run_start(self) -> WorkflowExecution:
inputs = {**self._application_generate_entity.inputs}
for key, value in (self._workflow_system_variables or {}).items():
if key.value == "conversation":
@@ -79,14 +65,13 @@ class WorkflowCycleManager:

# init workflow run
# TODO: This workflow_run_id should never be None; maybe we can find a more elegant way to handle this
execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_EXECUTION_ID) or uuid4())
execution = WorkflowExecution.new(
id=execution_id,
workflow_id=workflow.id,
sequence_number=new_sequence_number,
type=WorkflowType(workflow.type),
workflow_version=workflow.version,
graph=workflow.graph_dict,
id_=execution_id,
workflow_id=self._workflow_info.workflow_id,
workflow_type=self._workflow_info.workflow_type,
workflow_version=self._workflow_info.version,
graph=self._workflow_info.graph_data,
inputs=inputs,
started_at=datetime.now(UTC).replace(tzinfo=None),
)
@@ -168,7 +153,7 @@ class WorkflowCycleManager:
workflow_run_id: str,
total_tokens: int,
total_steps: int,
status: WorkflowRunStatus,
status: WorkflowExecutionStatus,
error_message: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
@@ -185,7 +170,7 @@ class WorkflowCycleManager:

# Use the instance repository to find running executions for a workflow run
running_node_executions = self._workflow_node_execution_repository.get_running_executions(
workflow_run_id=workflow_execution.id
workflow_run_id=workflow_execution.id_
)

# Update the domain models
@@ -193,7 +178,7 @@ class WorkflowCycleManager:
for node_execution in running_node_executions:
if node_execution.node_execution_id:
# Update the domain model
node_execution.status = NodeExecutionStatus.FAILED
node_execution.status = WorkflowNodeExecutionStatus.FAILED
node_execution.error = error_message
node_execution.finished_at = now
node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
@@ -219,28 +204,28 @@ class WorkflowCycleManager:
*,
workflow_execution_id: str,
event: QueueNodeStartedEvent,
) -> NodeExecution:
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)

# Create a domain model
created_at = datetime.now(UTC).replace(tzinfo=None)
metadata = {
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}

domain_execution = NodeExecution(
domain_execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id=workflow_execution.workflow_id,
workflow_run_id=workflow_execution.id,
workflow_execution_id=workflow_execution.id_,
predecessor_node_id=event.predecessor_node_id,
index=event.node_run_index,
node_execution_id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
title=event.node_data.title,
status=NodeExecutionStatus.RUNNING,
status=WorkflowNodeExecutionStatus.RUNNING,
metadata=metadata,
created_at=created_at,
)
@@ -250,7 +235,7 @@ class WorkflowCycleManager:

return domain_execution

def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution:
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
# Get the domain model from repository
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
if not domain_execution:
@@ -271,7 +256,7 @@ class WorkflowCycleManager:
elapsed_time = (finished_at - event.start_at).total_seconds()

# Update domain model
domain_execution.status = NodeExecutionStatus.SUCCEEDED
domain_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
domain_execution.update_from_mapping(
inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
)
@@ -290,7 +275,7 @@ class WorkflowCycleManager:
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
) -> NodeExecution:
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
@@ -317,9 +302,9 @@ class WorkflowCycleManager:

# Update domain model
domain_execution.status = (
NodeExecutionStatus.FAILED
WorkflowNodeExecutionStatus.FAILED
if not isinstance(event, QueueNodeExceptionEvent)
else NodeExecutionStatus.EXCEPTION
else WorkflowNodeExecutionStatus.EXCEPTION
)
domain_execution.error = event.error
domain_execution.update_from_mapping(
@@ -335,7 +320,7 @@ class WorkflowCycleManager:

def handle_workflow_node_execution_retried(
self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
) -> NodeExecution:
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
created_at = event.start_at
finished_at = datetime.now(UTC).replace(tzinfo=None)
@@ -345,13 +330,13 @@ class WorkflowCycleManager:

# Convert metadata keys to strings
origin_metadata = {
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}

# Convert execution metadata keys to strings
execution_metadata_dict: dict[NodeRunMetadataKey, str | None] = {}
execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[key] = value
@@ -359,16 +344,16 @@ class WorkflowCycleManager:
merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata

# Create a domain model
domain_execution = NodeExecution(
domain_execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id=workflow_execution.workflow_id,
workflow_run_id=workflow_execution.id,
workflow_execution_id=workflow_execution.id_,
predecessor_node_id=event.predecessor_node_id,
node_execution_id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
title=event.node_data.title,
status=NodeExecutionStatus.RETRY,
status=WorkflowNodeExecutionStatus.RETRY,
created_at=created_at,
finished_at=finished_at,
elapsed_time=elapsed_time,
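
A note on the retry handler above: per-event metadata (iteration, loop, and parallel-run IDs) is layered over whatever execution metadata arrived on the queue event, and dict-unpacking order decides which wins. A minimal sketch of that merge, using the renamed key enum from this diff (the event object here is a stand-in):

# Sketch of the metadata merge in handle_workflow_node_execution_retried.
# origin_metadata is unpacked last, so its keys override duplicates coming
# from the event's execution metadata.
origin_metadata = {
    WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
    WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
    WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}
execution_metadata_dict = dict(event.execution_metadata or {})
merged = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata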

+ 2
- 2
api/factories/variable_factory.py View File

@@ -93,8 +93,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
raise VariableError("missing value type")
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
# FIXME: using Any here, fix it later
result: Any
result: Variable
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
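
The annotation tightening above (result: Any → result: Variable) makes every match arm type-check against the factory's actual return type instead of silently widening to Any. A self-contained sketch of the pattern, assuming an illustrative two-member SegmentType and pydantic v2 models (not Dify's full variable hierarchy):

from collections.abc import Mapping
from enum import StrEnum
from typing import Any

from pydantic import BaseModel


class SegmentType(StrEnum):  # illustrative subset
    STRING = "string"
    NUMBER = "number"


class Variable(BaseModel):
    value_type: SegmentType
    value: Any


class StringVariable(Variable):
    pass


class NumberVariable(Variable):
    pass


def build_variable(mapping: Mapping[str, Any]) -> Variable:
    result: Variable  # each arm must now produce a Variable, not Any
    match mapping.get("value_type"):
        case SegmentType.STRING:  # StrEnum members compare equal to raw strings
            result = StringVariable.model_validate(mapping)
        case SegmentType.NUMBER:
            result = NumberVariable.model_validate(mapping)
        case _:
            raise ValueError("unsupported value type")
    return result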

+ 2
- 1
api/libs/smtp.py View File

@@ -28,7 +28,8 @@ class SMTPClient:
else:
smtp = smtplib.SMTP(self.server, self.port, timeout=10)

if self.username and self.password:
# Only authenticate if both username and password are non-empty
if self.username and self.password and self.username.strip() and self.password.strip():
smtp.login(self.username, self.password)

msg = MIMEMultipart()
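
The tightened guard above also rejects whitespace-only credentials, which would otherwise reach smtp.login() and fail with an authentication error on most servers. A hedged sketch with the check factored into a helper (the helper name is ours, not the PR's):

def _has_credentials(username: str | None, password: str | None) -> bool:
    # both values must be present and non-blank once whitespace is stripped
    return bool(username and password and username.strip() and password.strip())


assert _has_credentials("user", "secret")
assert not _has_credentials("  ", "secret")  # whitespace-only username: no login attempted
assert not _has_credentials("user", None)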

+ 4
- 8
api/models/__init__.py View File

@@ -85,11 +85,9 @@ from .workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
WorkflowRunStatus,
WorkflowType,
)

@@ -101,14 +99,14 @@ __all__ = [
"AccountStatus",
"ApiRequest",
"ApiToken",
"ApiToolProvider", # Added
"ApiToolProvider",
"App",
"AppAnnotationHitHistory",
"AppAnnotationSetting",
"AppDatasetJoin",
"AppMode",
"AppModelConfig",
"BuiltinToolProvider", # Added
"BuiltinToolProvider",
"CeleryTask",
"CeleryTaskSet",
"Conversation",
@@ -174,11 +172,9 @@ __all__ = [
"Workflow",
"WorkflowAppLog",
"WorkflowAppLogCreatedFrom",
"WorkflowNodeExecution",
"WorkflowNodeExecutionStatus",
"WorkflowNodeExecutionModel",
"WorkflowNodeExecutionTriggeredFrom",
"WorkflowRun",
"WorkflowRunStatus",
"WorkflowRunTriggeredFrom",
"WorkflowToolProvider",
"WorkflowType",

+ 10
- 10
api/models/model.py View File

@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from core.plugin.entities.plugin import GenericProviderID
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.signature import sign_tool_file
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from services.plugin.plugin_service import PluginService

if TYPE_CHECKING:
@@ -31,7 +32,6 @@ from .base import Base
from .engine import db
from .enums import CreatorUserRole
from .types import StringUUID
from .workflow import WorkflowRunStatus

if TYPE_CHECKING:
from .workflow import Workflow
@@ -795,22 +795,22 @@ class Conversation(Base):
def status_count(self):
messages = db.session.query(Message).filter(Message.conversation_id == self.id).all()
status_counts = {
WorkflowRunStatus.RUNNING: 0,
WorkflowRunStatus.SUCCEEDED: 0,
WorkflowRunStatus.FAILED: 0,
WorkflowRunStatus.STOPPED: 0,
WorkflowRunStatus.PARTIAL_SUCCEEDED: 0,
WorkflowExecutionStatus.RUNNING: 0,
WorkflowExecutionStatus.SUCCEEDED: 0,
WorkflowExecutionStatus.FAILED: 0,
WorkflowExecutionStatus.STOPPED: 0,
WorkflowExecutionStatus.PARTIAL_SUCCEEDED: 0,
}

for message in messages:
if message.workflow_run:
status_counts[WorkflowRunStatus(message.workflow_run.status)] += 1
status_counts[WorkflowExecutionStatus(message.workflow_run.status)] += 1

return (
{
"success": status_counts[WorkflowRunStatus.SUCCEEDED],
"failed": status_counts[WorkflowRunStatus.FAILED],
"partial_success": status_counts[WorkflowRunStatus.PARTIAL_SUCCEEDED],
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
"failed": status_counts[WorkflowExecutionStatus.FAILED],
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
}
if messages
else None
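
Because status_count pre-seeds every WorkflowExecutionStatus with 0, absent statuses read as zero without a KeyError. collections.Counter gives the same behavior with less setup; a sketch assuming workflow_run.status stores the enum's string value (the sample statuses are invented):

from collections import Counter

from core.workflow.entities.workflow_execution import WorkflowExecutionStatus

# stand-in for message.workflow_run.status across a conversation's messages
statuses = ["succeeded", "failed", "succeeded", "partial-succeeded"]
counts = Counter(WorkflowExecutionStatus(s) for s in statuses)

summary = {
    "success": counts[WorkflowExecutionStatus.SUCCEEDED],                  # 2
    "failed": counts[WorkflowExecutionStatus.FAILED],                      # 1
    "partial_success": counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],  # 1
}
# a status that never occurs, e.g. RUNNING, simply counts as 0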

+ 3
- 27
api/models/workflow.py View File

@@ -401,18 +401,6 @@ class Workflow(Base):
)


class WorkflowRunStatus(StrEnum):
"""
Workflow Run Status Enum
"""

RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
STOPPED = "stopped"
PARTIAL_SUCCEEDED = "partial-succeeded"


class WorkflowRun(Base):
"""
Workflow Run
@@ -473,12 +461,12 @@ class WorkflowRun(Base):
error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)

@property
def created_by_account(self):
@@ -578,19 +566,7 @@ class WorkflowNodeExecutionTriggeredFrom(StrEnum):
RAG_PIPELINE_RUN = "rag-pipeline-run"


class WorkflowNodeExecutionStatus(StrEnum):
"""
Workflow Node Execution Status Enum
"""

RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
EXCEPTION = "exception"
RETRY = "retry"


class WorkflowNodeExecution(Base):
class WorkflowNodeExecutionModel(Base):
"""
Workflow Node Execution
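
Deleting the model-layer WorkflowRunStatus and WorkflowNodeExecutionStatus enums works because StrEnum members are strings: equality with raw column values and lookup by value both keep working when call sites switch to the domain-layer WorkflowExecutionStatus. A quick demonstration (member values copied from the removed class above; the real enum may carry more members):

from enum import StrEnum


class WorkflowExecutionStatus(StrEnum):  # values mirror the deleted WorkflowRunStatus
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    STOPPED = "stopped"
    PARTIAL_SUCCEEDED = "partial-succeeded"


assert WorkflowExecutionStatus.SUCCEEDED == "succeeded"  # StrEnum compares equal to str
assert WorkflowExecutionStatus("partial-succeeded") is WorkflowExecutionStatus.PARTIAL_SUCCEEDED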


+ 9
- 7
api/pyproject.toml View File

@@ -14,7 +14,7 @@ dependencies = [
"chardet~=5.1.0",
"flask~=3.1.0",
"flask-compress~=1.17",
"flask-cors~=5.0.0",
"flask-cors~=6.0.0",
"flask-login~=0.6.3",
"flask-migrate~=4.0.7",
"flask-restful~=0.3.10",
@@ -36,7 +36,6 @@ dependencies = [
"mailchimp-transactional~=1.0.50",
"markdown~=3.5.1",
"numpy~=1.26.4",
"oci~=2.135.1",
"openai~=1.61.0",
"openpyxl~=3.1.5",
"opik~=1.7.25",
@@ -143,13 +142,16 @@ dev = [
"types-requests~=2.32.0",
"types-requests-oauthlib~=2.0.0",
"types-shapely~=2.0.0",
"types-simplejson~=3.20.0",
"types-six~=1.17.0",
"types-tensorflow~=2.18.0",
"types-tqdm~=4.67.0",
"types-ujson~=5.10.0",
"types-simplejson>=3.20.0",
"types-six>=1.17.0",
"types-tensorflow>=2.18.0",
"types-tqdm>=4.67.0",
"types-ujson>=5.10.0",
"boto3-stubs>=1.38.20",
"types-jmespath>=1.0.2.20240106",
"types_pyOpenSSL>=24.1.0",
"types_cffi>=1.17.0",
"types_setuptools>=80.9.0",
]

############################################################
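
On the specifier change above: ~= pins a "compatible release" window while >= only sets a floor, so the stub packages can now track newer minor and major releases. The packaging library (which pip itself uses) can verify what each specifier admits; a small check, assuming packaging is installed:

from packaging.specifiers import SpecifierSet

compat = SpecifierSet("~=4.67.0")  # equivalent to >=4.67.0, ==4.67.*
floor = SpecifierSet(">=4.67.0")   # lower bound only

print(compat.contains("4.67.9"))   # True: inside the compatible-release window
print(compat.contains("4.68.0"))   # False: window ends at the next minor
print(floor.contains("5.0.0"))     # True: nothing above the floor is excluded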

+ 0
- 1
api/pytest.ini View File

@@ -1,5 +1,4 @@
[pytest]
continue-on-collection-errors = true
addopts = --cov=./api --cov-report=json --cov-report=xml
env =
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz

+ 1
- 2
api/schedule/clean_messages.py View File

@@ -34,9 +34,8 @@ def clean_messages():
while True:
try:
# Main query with join and filter
# FIXME:for mypy no paginate method error
messages = (
db.session.query(Message) # type: ignore
db.session.query(Message)
.filter(Message.created_at < plan_sandbox_clean_message_day)
.order_by(Message.created_at.desc())
.limit(100)

+ 7
- 6
api/services/clear_free_plan_tenant_expired_logs.py View File

@@ -14,7 +14,7 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
from models.model import App, Conversation, Message
from models.workflow import WorkflowNodeExecution, WorkflowRun
from models.workflow import WorkflowNodeExecutionModel, WorkflowRun
from services.billing_service import BillingService

logger = logging.getLogger(__name__)
@@ -108,10 +108,11 @@ class ClearFreePlanTenantExpiredLogs:
while True:
with Session(db.engine).no_autoflush as session:
workflow_node_executions = (
session.query(WorkflowNodeExecution)
session.query(WorkflowNodeExecutionModel)
.filter(
WorkflowNodeExecution.tenant_id == tenant_id,
WorkflowNodeExecution.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.created_at
< datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
@@ -135,8 +136,8 @@ class ClearFreePlanTenantExpiredLogs:
]

# delete workflow node executions
session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id.in_(workflow_node_execution_ids),
session.query(WorkflowNodeExecutionModel).filter(
WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids),
).delete(synchronize_session=False)
session.commit()


+ 28
- 2
api/services/hit_testing_service.py View File

@@ -2,8 +2,11 @@ import logging
import time
from typing import Any

from core.app.app_config.entities import ModelConfig
from core.model_runtime.entities import LLMMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.account import Account
@@ -34,7 +37,29 @@ class HitTestingService:
# get the retrieval model; if it is not set, use the default
if not retrieval_model:
retrieval_model = dataset.retrieval_model or default_retrieval_model

document_ids_filter = None
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
if metadata_filtering_conditions:
dataset_retrieval = DatasetRetrieval()

from core.app.app_config.entities import MetadataFilteringCondition

metadata_filtering_conditions = MetadataFilteringCondition(**metadata_filtering_conditions)

metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
dataset_ids=[dataset.id],
query=query,
metadata_filtering_mode="manual",
metadata_filtering_conditions=metadata_filtering_conditions,
inputs={},
tenant_id="",
user_id="",
metadata_model_config=ModelConfig(provider="", name="", mode=LLMMode.CHAT, completion_params={}),
)
if metadata_filter_document_ids:
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
if metadata_condition and not document_ids_filter:
return cls.compact_retrieve_response(query, [])
all_documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
@@ -48,6 +73,7 @@ class HitTestingService:
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
document_ids_filter=document_ids_filter,
)

end = time.perf_counter()
@@ -99,7 +125,7 @@ class HitTestingService:
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))

@classmethod
def compact_retrieve_response(cls, query: str, documents: list[Document]):
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]:
records = RetrievalService.format_retrieval_documents(documents)

return {
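
The new block above follows a fail-fast shape: resolve manual metadata-filtering conditions to per-dataset document IDs first, then return an empty response when a condition exists but matched nothing, skipping retrieval entirely. The guard, condensed, with the return shape of get_metadata_filter_condition inferred from its use here:

# metadata_filter_document_ids maps dataset_id -> [document_id, ...]
if metadata_filter_document_ids:
    document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
if metadata_condition and not document_ids_filter:
    # a filter was requested but nothing matched: no point querying the index
    return cls.compact_retrieve_response(query, [])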

+ 8
- 8
api/services/ops_service.py View File

@@ -1,5 +1,6 @@
from typing import Optional
from typing import Any, Optional

from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
from extensions.ext_database import db
from models.model import App, TraceAppConfig
@@ -92,13 +93,12 @@ class OpsService:
except KeyError:
return {"error": f"Invalid tracing provider: {tracing_provider}"}

config_class, other_keys = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["other_keys"],
)
# FIXME: ignore type error
default_config_instance = config_class(**tracing_config) # type: ignore
for key in other_keys: # type: ignore
provider_config: dict[str, Any] = provider_config_map[tracing_provider]
config_class: type[BaseTracingConfig] = provider_config["config_class"]
other_keys: list[str] = provider_config["other_keys"]

default_config_instance: BaseTracingConfig = config_class(**tracing_config)
for key in other_keys:
if key in tracing_config and tracing_config[key] == "":
tracing_config[key] = getattr(default_config_instance, key, None)


+ 0
- 0
api/services/tag_service.py View File


Some files were not shown because too many files changed in this diff
