Browse Source

refactor: Replaces direct DB session usage with context managers (#20569)

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/1.4.2
-LAN- 5 months ago
parent
commit
72fdafc180
No account linked to committer's email address

+ 0
- 6
api/core/workflow/graph_engine/graph_engine.py View File

from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import WorkflowType from models.workflow import WorkflowType


error=str(e), error=str(e),
) )
) )
finally:
db.session.remove()


def _run_node( def _run_node(
self, self,
agent_strategy=agent_strategy, agent_strategy=agent_strategy,
) )


db.session.close()
max_retries = node_instance.node_data.retry_config.max_retries max_retries = node_instance.node_data.retry_config.max_retries
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
retries = 0 retries = 0
except Exception as e: except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed") logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e raise e
finally:
db.session.close()


def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
""" """

+ 8
- 8
api/core/workflow/nodes/agent/agent_node.py View File

from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast


from sqlalchemy import select
from sqlalchemy.orm import Session

from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
return None return None
conversation_id = conversation_id_variable.value conversation_id = conversation_id_variable.value


# get conversation
conversation = (
db.session.query(Conversation)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
.first()
)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)


if not conversation:
return None
if not conversation:
return None


memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)



+ 10
- 8
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py View File



from sqlalchemy import Float, and_, func, or_, text from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import Session


from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
redis_client.zremrangebyscore(key, 0, current_time - 60000) redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key) request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit: if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
db.session.add(rate_limit_log)
db.session.commit()
with Session(db.engine) as session:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
session.add(rate_limit_log)
session.commit()
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables, inputs=variables,

+ 25
- 25
api/core/workflow/nodes/llm/node.py View File

from typing import TYPE_CHECKING, Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast


import json_repair import json_repair
from sqlalchemy import select, update
from sqlalchemy.orm import Session


from configs import dify_config from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
prompt_messages: Sequence[PromptMessage], prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
) -> Generator[NodeEvent, None, None]: ) -> Generator[NodeEvent, None, None]:
db.session.close()

invoke_result = model_instance.invoke_llm( invoke_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), prompt_messages=list(prompt_messages),
model_parameters=node_data_model.completion_params, model_parameters=node_data_model.completion_params,
return None return None
conversation_id = conversation_id_variable.value conversation_id = conversation_id_variable.value


# get conversation
conversation = (
db.session.query(Conversation)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
.first()
)

if not conversation:
return None
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None


memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)


used_quota = 1 used_quota = 1


if used_quota is not None and system_configuration.current_quota_type is not None: if used_quota is not None and system_configuration.current_quota_type is not None:
db.session.query(Provider).filter(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
).update(
{
"quota_used": Provider.quota_used + used_quota,
"last_used": datetime.now(tz=UTC).replace(tzinfo=None),
}
)
db.session.commit()
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=datetime.now(tz=UTC).replace(tzinfo=None),
)
)
session.execute(stmt)
session.commit()


@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(

+ 0
- 3
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py View File

from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import LLMNode, ModelConfig from core.workflow.nodes.llm import LLMNode, ModelConfig
from core.workflow.utils import variable_template_parser from core.workflow.utils import variable_template_parser
from extensions.ext_database import db


from .entities import ParameterExtractorNodeData from .entities import ParameterExtractorNodeData
from .exc import ( from .exc import (
tools: list[PromptMessageTool], tools: list[PromptMessageTool],
stop: list[str], stop: list[str],
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
db.session.close()

invoke_result = model_instance.invoke_llm( invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
model_parameters=node_data_model.completion_params, model_parameters=node_data_model.completion_params,

Loading…
Cancel
Save