Bläddra i källkod

refactor: Replaces direct DB session usage with context managers (#20569)

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/1.4.2
-LAN- 5 månader sedan
förälder
incheckning
72fdafc180
Inget konto är kopplat till bidragsgivarens mejladress

+ 0
- 6
api/core/workflow/graph_engine/graph_engine.py Visa fil

@@ -53,7 +53,6 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType

@@ -607,8 +606,6 @@ class GraphEngine:
error=str(e),
)
)
finally:
db.session.remove()

def _run_node(
self,
@@ -646,7 +643,6 @@ class GraphEngine:
agent_strategy=agent_strategy,
)

db.session.close()
max_retries = node_instance.node_data.retry_config.max_retries
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
retries = 0
@@ -863,8 +859,6 @@ class GraphEngine:
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()

def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""

+ 8
- 8
api/core/workflow/nodes/agent/agent_node.py Visa fil

@@ -2,6 +2,9 @@ import json
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -320,15 +323,12 @@ class AgentNode(ToolNode):
return None
conversation_id = conversation_id_variable.value

# get conversation
conversation = (
db.session.query(Conversation)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
.first()
)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)

if not conversation:
return None
if not conversation:
return None

memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)


+ 10
- 8
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py Visa fil

@@ -8,6 +8,7 @@ from typing import Any, Optional, cast

from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import Session

from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -95,14 +96,15 @@ class KnowledgeRetrievalNode(LLMNode):
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
db.session.add(rate_limit_log)
db.session.commit()
with Session(db.engine) as session:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
session.add(rate_limit_log)
session.commit()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,

+ 25
- 25
api/core/workflow/nodes/llm/node.py Visa fil

@@ -7,6 +7,8 @@ from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, cast

import json_repair
from sqlalchemy import select, update
from sqlalchemy.orm import Session

from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -303,8 +305,6 @@ class LLMNode(BaseNode[LLMNodeData]):
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
) -> Generator[NodeEvent, None, None]:
db.session.close()

invoke_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters=node_data_model.completion_params,
@@ -603,15 +603,11 @@ class LLMNode(BaseNode[LLMNodeData]):
return None
conversation_id = conversation_id_variable.value

# get conversation
conversation = (
db.session.query(Conversation)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
.first()
)

if not conversation:
return None
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None

memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

@@ -847,20 +843,24 @@ class LLMNode(BaseNode[LLMNodeData]):
used_quota = 1

if used_quota is not None and system_configuration.current_quota_type is not None:
db.session.query(Provider).filter(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
).update(
{
"quota_used": Provider.quota_used + used_quota,
"last_used": datetime.now(tz=UTC).replace(tzinfo=None),
}
)
db.session.commit()
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=datetime.now(tz=UTC).replace(tzinfo=None),
)
)
session.execute(stmt)
session.commit()

@classmethod
def _extract_variable_selector_to_variable_mapping(

+ 0
- 3
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py Visa fil

@@ -31,7 +31,6 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import LLMNode, ModelConfig
from core.workflow.utils import variable_template_parser
from extensions.ext_database import db

from .entities import ParameterExtractorNodeData
from .exc import (
@@ -259,8 +258,6 @@ class ParameterExtractorNode(LLMNode):
tools: list[PromptMessageTool],
stop: list[str],
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
db.session.close()

invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=node_data_model.completion_params,

Laddar…
Avbryt
Spara