if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
    agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

db.session.refresh(conversation)
db.session.refresh(message)
db.session.close()
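# Why refresh before close (hedged sketch, not Dify code): commit() expires
# attribute state by default and close() returns the pooled connection, so the
# rows are reloaded first and only then is the session released. A minimal
# standalone illustration with plain SQLAlchemy; the Note model and the
# in-memory SQLite engine are stand-ins, not part of the source.
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Note(Base):
    __tablename__ = "notes"
    id = Column(Integer, primary_key=True)
    body = Column(String, default="")

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

session = Session(engine)
note = Note(body="hello")
session.add(note)
session.commit()       # expire_on_commit=True: attributes are expired now
session.refresh(note)  # reload while the session still holds a connection
session.close()        # give the connection back to the pool

print(note.id, note.body)  # detached, but loaded attributes stay readable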
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
    assistant_cot_runner = AssistantCotApplicationRunner(
    model=app_orchestration_config.model_config.model
)

db.session.close()

invoke_result = model_instance.invoke_llm(
    prompt_messages=prompt_messages,
    model_parameters=app_orchestration_config.model_config.parameters,
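# Sketch of the intent behind closing the session right before invoke_llm:
# release the pooled DB connection before a potentially long model call.
# `call_model` is an illustrative stand-in, not a Dify or SQLAlchemy API.
from sqlalchemy.orm import Session

def invoke_with_released_connection(session: Session, call_model, prompt_messages, parameters):
    # run any queries that still need the session before this point
    session.close()  # connection goes back to the pool
    # the slow model call now runs without holding a database connection
    return call_model(prompt_messages=prompt_messages, model_parameters=parameters)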
Process generate task pipeline.
:return:
"""
db.session.refresh(self._conversation)
db.session.refresh(self._message)
db.session.close()

if stream:
    return self._process_stream_response()
else:
    .first()
)

db.session.refresh(agent_thought)
db.session.close()

if agent_thought:
    response = {
    .filter(MessageFile.id == event.message_file_id)
    .first()
)

db.session.close()

# get extension
if '.' in message_file.url:
    extension = f'.{message_file.url.split(".")[-1]}'
usage = llm_result.usage

self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()

self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
self._message.message_tokens = usage.prompt_tokens
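# Sketch of why the message and conversation are re-queried before writing
# usage: the session was closed earlier in the pipeline, so the instances held
# by the pipeline are detached, and a fresh, session-bound row is needed before
# attributes are updated and committed. Reuses the stand-in Note model from the
# sketch further up; the helper name is illustrative, not from the source.
from sqlalchemy.orm import Session

def save_token_usage(engine, note_id: int, prompt_tokens: int) -> None:
    session = Session(engine)
    # re-fetch by primary key so the row is attached to *this* session
    note = session.query(Note).filter(Note.id == note_id).first()
    note.body = f"prompt_tokens={prompt_tokens}"
    session.commit()
    session.close()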
| logger.exception("Unknown Error when generating") | logger.exception("Unknown Error when generating") | ||||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | ||||
| finally: | finally: | ||||
| db.session.remove() | |||||
| db.session.close() | |||||
def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
                     queue_manager: ApplicationQueueManager,

        else:
            logger.exception(e)
            raise e
    finally:
        db.session.remove()
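# Sketch of the finally-block teardown above: db.session is a scoped session,
# and the generate worker runs in its own thread, so remove() disposes that
# worker's session (and its connection) on exit. Plain SQLAlchemy
# scoped_session stands in for flask-sqlalchemy's db.session here; this is an
# illustration of the pattern, not the application's actual wiring.
import threading
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

worker_engine = create_engine("sqlite:///:memory:")
SessionLocal = scoped_session(sessionmaker(bind=worker_engine))

def generate_worker() -> None:
    try:
        SessionLocal()  # each thread lazily gets its own session
        # ... queries and the long-running generation loop would run here ...
    finally:
        SessionLocal.remove()  # dispose this thread's session and connection

worker = threading.Thread(target=generate_worker)
worker.start()
worker.join()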
def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
        -> AppOrchestrationConfigEntity:
    db.session.add(conversation)
    db.session.commit()
    db.session.refresh(conversation)
else:
    conversation = (
        db.session.query(Conversation)

db.session.add(message)
db.session.commit()
db.session.refresh(message)

for file in application_generate_entity.files:
    message_file = MessageFile(
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
    MessageAgentThought.message_id == self.message.id,
).count()
db.session.close()

# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        created_by=self.user_id,
    )
    db.session.add(message_file)
    db.session.commit()
    db.session.refresh(message_file)

    result.append((
        message_file,
        message.save_as
    ))

db.session.commit()
db.session.close()

return result
def create_agent_thought(self, message_id: str, message: str,

    db.session.add(thought)
    db.session.commit()
    db.session.refresh(thought)
    db.session.close()

    self.agent_thought_count += 1
| """ | """ | ||||
| Save agent thought | Save agent thought | ||||
| """ | """ | ||||
| agent_thought = db.session.query(MessageAgentThought).filter( | |||||
| MessageAgentThought.id == agent_thought.id | |||||
| ).first() | |||||
| if thought is not None: | if thought is not None: | ||||
| agent_thought.thought = thought | agent_thought.thought = thought | ||||
| agent_thought.tool_labels_str = json.dumps(labels) | agent_thought.tool_labels_str = json.dumps(labels) | ||||
| db.session.commit() | db.session.commit() | ||||
| db.session.close() | |||||
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
    """

    """
    convert tool variables to db variables
    """
    db_variables = db.session.query(ToolConversationVariables).filter(
        ToolConversationVariables.conversation_id == self.message.conversation_id,
    ).first()

    db_variables.updated_at = datetime.utcnow()
    db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
    db.session.commit()
    db.session.close()
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
    """

        if message.answer:
            result.append(AssistantPromptMessage(content=message.answer))

    db.session.close()

    return result
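# Sketch of the shape organize_agent_history follows: copy what is needed out
# of the ORM rows into plain objects, close the session, and return only the
# plain objects, which no longer depend on the session. HistoryItem is an
# illustrative stand-in for PromptMessage, and Note (from the first sketch)
# stands in for the Message rows; none of these names come from the source.
from dataclasses import dataclass
from sqlalchemy.orm import Session

@dataclass
class HistoryItem:
    content: str

def load_history(engine) -> list[HistoryItem]:
    session = Session(engine)
    result = [HistoryItem(content=note.body) for note in session.query(Note).all()]
    session.close()  # safe: result holds plain dataclasses, no ORM state
    return result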