| @@ -195,6 +195,10 @@ class AssistantApplicationRunner(AppRunner): | |||
| if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []): | |||
| agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING | |||
| db.session.refresh(conversation) | |||
| db.session.refresh(message) | |||
| db.session.close() | |||
| # start agent runner | |||
| if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: | |||
| assistant_cot_runner = AssistantCotApplicationRunner( | |||
| @@ -192,6 +192,8 @@ class BasicApplicationRunner(AppRunner): | |||
| model=app_orchestration_config.model_config.model | |||
| ) | |||
| db.session.close() | |||
| invoke_result = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| model_parameters=app_orchestration_config.model_config.parameters, | |||
| @@ -89,6 +89,10 @@ class GenerateTaskPipeline: | |||
| Process generate task pipeline. | |||
| :return: | |||
| """ | |||
| db.session.refresh(self._conversation) | |||
| db.session.refresh(self._message) | |||
| db.session.close() | |||
| if stream: | |||
| return self._process_stream_response() | |||
| else: | |||
| @@ -303,6 +307,7 @@ class GenerateTaskPipeline: | |||
| .first() | |||
| ) | |||
| db.session.refresh(agent_thought) | |||
| db.session.close() | |||
| if agent_thought: | |||
| response = { | |||
| @@ -330,6 +335,8 @@ class GenerateTaskPipeline: | |||
| .filter(MessageFile.id == event.message_file_id) | |||
| .first() | |||
| ) | |||
| db.session.close() | |||
| # get extension | |||
| if '.' in message_file.url: | |||
| extension = f'.{message_file.url.split(".")[-1]}' | |||
| @@ -413,6 +420,7 @@ class GenerateTaskPipeline: | |||
| usage = llm_result.usage | |||
| self._message = db.session.query(Message).filter(Message.id == self._message.id).first() | |||
| self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() | |||
| self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages) | |||
| self._message.message_tokens = usage.prompt_tokens | |||
| @@ -201,7 +201,7 @@ class ApplicationManager: | |||
| logger.exception("Unknown Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| finally: | |||
| db.session.remove() | |||
| db.session.close() | |||
| def _handle_response(self, application_generate_entity: ApplicationGenerateEntity, | |||
| queue_manager: ApplicationQueueManager, | |||
| @@ -233,8 +233,6 @@ class ApplicationManager: | |||
| else: | |||
| logger.exception(e) | |||
| raise e | |||
| finally: | |||
| db.session.remove() | |||
| def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \ | |||
| -> AppOrchestrationConfigEntity: | |||
| @@ -651,6 +649,7 @@ class ApplicationManager: | |||
| db.session.add(conversation) | |||
| db.session.commit() | |||
| db.session.refresh(conversation) | |||
| else: | |||
| conversation = ( | |||
| db.session.query(Conversation) | |||
| @@ -689,6 +688,7 @@ class ApplicationManager: | |||
| db.session.add(message) | |||
| db.session.commit() | |||
| db.session.refresh(message) | |||
| for file in application_generate_entity.files: | |||
| message_file = MessageFile( | |||
| @@ -114,6 +114,7 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| self.agent_thought_count = db.session.query(MessageAgentThought).filter( | |||
| MessageAgentThought.message_id == self.message.id, | |||
| ).count() | |||
| db.session.close() | |||
| # check if model supports stream tool call | |||
| llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) | |||
| @@ -341,13 +342,16 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| created_by=self.user_id, | |||
| ) | |||
| db.session.add(message_file) | |||
| db.session.commit() | |||
| db.session.refresh(message_file) | |||
| result.append(( | |||
| message_file, | |||
| message.save_as | |||
| )) | |||
| db.session.commit() | |||
| db.session.close() | |||
| return result | |||
| def create_agent_thought(self, message_id: str, message: str, | |||
| @@ -384,6 +388,8 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| db.session.add(thought) | |||
| db.session.commit() | |||
| db.session.refresh(thought) | |||
| db.session.close() | |||
| self.agent_thought_count += 1 | |||
| @@ -401,6 +407,10 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| """ | |||
| Save agent thought | |||
| """ | |||
| agent_thought = db.session.query(MessageAgentThought).filter( | |||
| MessageAgentThought.id == agent_thought.id | |||
| ).first() | |||
| if thought is not None: | |||
| agent_thought.thought = thought | |||
| @@ -451,6 +461,7 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| agent_thought.tool_labels_str = json.dumps(labels) | |||
| db.session.commit() | |||
| db.session.close() | |||
| def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]: | |||
| """ | |||
| @@ -523,9 +534,14 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| """ | |||
| convert tool variables to db variables | |||
| """ | |||
| db_variables = db.session.query(ToolConversationVariables).filter( | |||
| ToolConversationVariables.conversation_id == self.message.conversation_id, | |||
| ).first() | |||
| db_variables.updated_at = datetime.utcnow() | |||
| db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) | |||
| db.session.commit() | |||
| db.session.close() | |||
| def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: | |||
| """ | |||
| @@ -581,4 +597,6 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| if message.answer: | |||
| result.append(AssistantPromptMessage(content=message.answer)) | |||
| db.session.close() | |||
| return result | |||