
feat: optimize db connection when llm invoking (#2774)

tags/0.5.9
takatost · 1 year ago
Parent commit: f073dca22a

+4 -0  api/core/app_runner/assistant_app_runner.py

  if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
      agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

+ db.session.refresh(conversation)
+ db.session.refresh(message)
+ db.session.close()

  # start agent runner
  if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
      assistant_cot_runner = AssistantCotApplicationRunner(
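
The idea in this hunk: eagerly re-load the rows the agent runner will still need, then close the session so the pooled connection is released before the long agent/LLM loop starts. A minimal sketch of that pattern, assuming a SQLAlchemy scoped session like Dify's db.session; run_with_released_connection and run_agent are illustrative stand-ins, not the runner's real API:

    # Sketch only: release the DB connection before long-running LLM/agent work.
    def run_with_released_connection(conversation_id: str, message_id: str):
        conversation = db.session.query(Conversation).filter(
            Conversation.id == conversation_id).first()
        message = db.session.query(Message).filter(Message.id == message_id).first()

        # refresh() re-loads current column values, so both instances keep usable
        # (non-expired) attributes once they are detached from the session.
        db.session.refresh(conversation)
        db.session.refresh(message)
        # close() returns the connection to the pool and detaches the instances;
        # reading already-loaded attributes still works, lazy loads do not.
        db.session.close()

        run_agent(conversation, message)  # hypothetical long-running agent loop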

+2 -0  api/core/app_runner/basic_app_runner.py

      model=app_orchestration_config.model_config.model
  )

+ db.session.close()

  invoke_result = model_instance.invoke_llm(
      prompt_messages=prompt_messages,
      model_parameters=app_orchestration_config.model_config.parameters,
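
Closing the session right before invoke_llm matters because a streaming completion can keep the request open for many seconds while doing no database work, and with a bounded connection pool that idle time directly caps concurrency. A rough sketch of the arithmetic and the pool knobs involved; the DSN and numbers are illustrative, not Dify's configuration:

    from sqlalchemy import create_engine

    # With pool_size=10 and an LLM call that streams for ~20 s, a session held across
    # the call limits the app to 10 in-flight generations (~0.5 generations/second).
    # Releasing the connection first (db.session.close()) removes that ceiling.
    engine = create_engine(
        "postgresql://user:pass@localhost/dify",  # placeholder DSN
        pool_size=10,      # connections kept open in the pool
        max_overflow=5,    # extra connections allowed under burst
        pool_timeout=30,   # seconds a caller waits for a free connection
    )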

+8 -0  api/core/app_runner/generate_task_pipeline.py

  Process generate task pipeline.
  :return:
  """
+ db.session.refresh(self._conversation)
+ db.session.refresh(self._message)
+ db.session.close()

  if stream:
      return self._process_stream_response()
  else:

      .first()
  )
  db.session.refresh(agent_thought)
+ db.session.close()

  if agent_thought:
      response = {

      .filter(MessageFile.id == event.message_file_id)
      .first()
  )
+ db.session.close()

  # get extension
  if '.' in message_file.url:
      extension = f'.{message_file.url.split(".")[-1]}'

  usage = llm_result.usage

  self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
+ self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()

  self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
  self._message.message_tokens = usage.prompt_tokens
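
Because the pipeline closes its session at the start of process(), the Message and Conversation it still references are detached by the time usage is saved, so both are re-selected by primary key to get session-bound instances before mutation (db.session.merge() would be the other way to re-attach a detached object). A condensed sketch of that step; the helper name is hypothetical and only fields visible in this hunk are used:

    def _save_message_usage(self, usage):
        # Sketch: fetch fresh, session-bound copies by primary key instead of
        # mutating the detached instances held on self.
        self._message = db.session.query(Message).filter(
            Message.id == self._message.id).first()
        self._conversation = db.session.query(Conversation).filter(
            Conversation.id == self._conversation.id).first()

        self._message.message = self._prompt_messages_to_prompt_for_saving(
            self._task_state.llm_result.prompt_messages)
        self._message.message_tokens = usage.prompt_tokens
        db.session.commit()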

+3 -3  api/core/application_manager.py

      logger.exception("Unknown Error when generating")
      queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
  finally:
-     db.session.close()
+     db.session.remove()

  def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
                       queue_manager: ApplicationQueueManager,

  else:
      logger.exception(e)
      raise e
+ finally:
+     db.session.remove()

  def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
          -> AppOrchestrationConfigEntity:

      db.session.add(conversation)
      db.session.commit()
-     db.session.refresh(conversation)
  else:
      conversation = (
          db.session.query(Conversation)

  db.session.add(message)
  db.session.commit()
- db.session.refresh(message)

  for file in application_generate_entity.files:
      message_file = MessageFile(
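
Two things happen in this file. Generation runs in a worker thread, and Flask-SQLAlchemy's db.session is a thread-local scoped session, so the finally: db.session.remove() tears down that thread's session (and its connection) even when generation fails, instead of leaking one session per thread. The refresh() calls right after commit() also go away here: the app runners now refresh conversation and message themselves, right before closing the session (first hunk above). A minimal sketch of the worker-thread shape; run_generation and the argument names are placeholders:

    def generate_worker(queue_manager, generate_entity):
        # Intended to run inside threading.Thread(target=generate_worker, ...).
        try:
            run_generation(generate_entity)   # placeholder for the real DB + LLM work
        except Exception as e:
            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
        finally:
            # remove() closes the thread-local session and drops it from the
            # scoped_session registry, so the next use in this thread starts clean.
            db.session.remove()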

+20 -2  api/core/features/assistant_base_runner.py

  self.agent_thought_count = db.session.query(MessageAgentThought).filter(
      MessageAgentThought.message_id == self.message.id,
  ).count()
+ db.session.close()

  # check if model supports stream tool call
  llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)

      created_by=self.user_id,
  )
  db.session.add(message_file)
+ db.session.commit()
+ db.session.refresh(message_file)

  result.append((
      message_file,
      message.save_as
  ))

- db.session.commit()
+ db.session.close()

  return result

  def create_agent_thought(self, message_id: str, message: str,

  db.session.add(thought)
  db.session.commit()
+ db.session.refresh(thought)
+ db.session.close()

  self.agent_thought_count += 1

  """
  Save agent thought
  """
+ agent_thought = db.session.query(MessageAgentThought).filter(
+     MessageAgentThought.id == agent_thought.id
+ ).first()

  if thought is not None:
      agent_thought.thought = thought

  agent_thought.tool_labels_str = json.dumps(labels)

  db.session.commit()
+ db.session.close()

  def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
      """

      """
      convert tool variables to db variables
      """
+     db_variables = db.session.query(ToolConversationVariables).filter(
+         ToolConversationVariables.conversation_id == self.message.conversation_id,
+     ).first()

      db_variables.updated_at = datetime.utcnow()
      db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
      db.session.commit()
+     db.session.close()

  def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
      """

  if message.answer:
      result.append(AssistantPromptMessage(content=message.answer))

+ db.session.close()

  return result
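
The runner closes its session after every write so that no connection is held while tools and the model run; the flip side is that rows created earlier (the agent thought, the tool-variable record) are detached by the time they need updating, so each update starts by re-selecting the row by its key in a fresh session. A condensed sketch of that update pattern; the observation parameter and column are assumptions, the rest mirrors the hunks above:

    def save_agent_thought(agent_thought, thought=None, observation=None):
        # The instance passed in is detached (its session was closed after creation),
        # so fetch a live copy by primary key before changing it.
        agent_thought = db.session.query(MessageAgentThought).filter(
            MessageAgentThought.id == agent_thought.id
        ).first()

        if thought is not None:
            agent_thought.thought = thought
        if observation is not None:
            agent_thought.observation = observation   # column name assumed

        db.session.commit()
        db.session.close()   # release the connection again before the next tool/LLM step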
