
Feat/optimize chat prompt (#158)

tags/0.2.2
John Wang 2 years ago
commit 90150a6ca9
1 changed file with 38 additions and 33 deletions

api/core/completion.py (+38, -33)

  memory = cls.get_memory_from_conversation(
      tenant_id=app.tenant_id,
      app_model_config=app_model_config,
-     conversation=conversation
+     conversation=conversation,
+     return_messages=False
  )

  inputs = conversation.inputs

  return response
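
The new return_messages=False argument asks the memory to hand back the chat history as one formatted transcript string rather than a list of message objects, which is what lets the history be spliced into a single prompt further down. A minimal sketch of the flag's effect, assuming the custom ReadOnlyConversationTokenDBBufferSharedMemory follows the stock LangChain buffer-memory convention:

    # Sketch using LangChain's stock buffer memory; dify's memory class is
    # assumed to share the same return_messages semantics.
    from langchain.memory import ConversationBufferMemory

    memory = ConversationBufferMemory(return_messages=True)
    memory.save_context({"input": "Hi"}, {"output": "Hello!"})
    memory.load_memory_variables({})["history"]
    # -> [HumanMessage(content='Hi'), AIMessage(content='Hello!')]

    memory = ConversationBufferMemory(return_messages=False)
    memory.save_context({"input": "Hi"}, {"output": "Hello!"})
    memory.load_memory_variables({})["history"]
    # -> 'Human: Hi\nAI: Hello!'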


  @classmethod
- def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, chain_output: Optional[str],
+ def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+                         chain_output: Optional[str],
                          memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
          Union[str | List[BaseMessage]]:
      pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
          "query": query
      }


- human_message_prompt = "{query}"
+ human_message_prompt = ""
+
+ if pre_prompt:
+     pre_prompt_inputs = {k: inputs[k] for k in
+                          OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
+                          if k in inputs}
+
+     if pre_prompt_inputs:
+         human_inputs.update(pre_prompt_inputs)
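
The added block keeps only the user inputs that the pre-prompt template actually references, so unused variables never reach the formatter. The same filtering idea can be sketched with the standard library alone (string.Formatter standing in for OutLinePromptTemplate; filter_inputs is a hypothetical helper name, not part of this commit):

    from string import Formatter

    def filter_inputs(template: str, inputs: dict) -> dict:
        # Collect the {placeholder} names the template uses, then keep only
        # the supplied inputs whose keys match one of them.
        fields = {f for _, f, _, _ in Formatter().parse(template) if f}
        return {k: v for k, v in inputs.items() if k in fields}

    filter_inputs("Translate {text} into {language}.",
                  {"text": "hola", "language": "English", "tone": "formal"})
    # -> {'text': 'hola', 'language': 'English'}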


  if chain_output:
      human_inputs['context'] = chain_output
-     human_message_instruction = """Use the following CONTEXT as your learned knowledge.
+     human_message_prompt += """Use the following CONTEXT as your learned knowledge.
  [CONTEXT]
  {context}
  [END CONTEXT]
  Avoid mentioning that you obtained the information from the context.
  And answer according to the language of the user's question.
  """
-
-     if pre_prompt:
-         extra_inputs = {k: inputs[k] for k in
-                         OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                         if k in inputs}
-         if extra_inputs:
-             human_inputs.update(extra_inputs)
-         human_message_instruction += pre_prompt + "\n"
-
-     human_message_prompt = human_message_instruction + "Q:{query}\nA:"
- else:
-     if pre_prompt:
-         extra_inputs = {k: inputs[k] for k in
-                         OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                         if k in inputs}
-         if extra_inputs:
-             human_inputs.update(extra_inputs)
-         human_message_prompt = pre_prompt + "\n" + human_message_prompt
-
- # construct main prompt
- human_message = PromptBuilder.to_human_message(
-     prompt_content=human_message_prompt,
-     inputs=human_inputs
- )
+
+ if pre_prompt:
+     human_message_prompt += pre_prompt
+
+ query_prompt = "\nHuman: {query}\nAI: "
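
With the duplicated branches above collapsed, the prompt is now built by plain concatenation: optional context block, then pre-prompt, then (further down) chat history and the Human/AI turn. Assembled by hand with placeholder values, the final string would look roughly like this (illustrative values only):

    # Hypothetical rendering; real values come from chain_output, pre_prompt,
    # history_messages and the user's query.
    final_prompt = (
        "Use the following CONTEXT as your learned knowledge.\n"
        "[CONTEXT]\n"
        "...chain output...\n"
        "[END CONTEXT]\n"
        "Avoid mentioning that you obtained the information from the context.\n"
        "And answer according to the language of the user's question.\n"
        "You are a helpful assistant."                 # pre_prompt (illustrative)
        "\n\n"
        "Human: earlier question\nAI: earlier answer"  # history_messages
        "\nHuman: {query}\nAI: "                       # query_prompt
    )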


  if memory:
      # append chat histories
-     tmp_messages = messages.copy() + [human_message]
-     curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages)
-     rest_tokens = llm_constant.max_context_token_length[
-         memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens
+     tmp_human_message = PromptBuilder.to_human_message(
+         prompt_content=human_message_prompt + query_prompt,
+         inputs=human_inputs
+     )
+
+     curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
+     rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
+         - memory.llm.max_tokens - curr_message_tokens
      rest_tokens = max(rest_tokens, 0)
      history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
-     messages += history_messages
+     human_message_prompt += "\n\n" + history_messages
+
+ human_message_prompt += query_prompt
+
+ # construct main prompt
+ human_message = PromptBuilder.to_human_message(
+     prompt_content=human_message_prompt,
+     inputs=human_inputs
+ )

  messages.append(human_message)
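
The history budget is now computed from a draft of the full prompt (including query_prompt) instead of the running message list, so the prompt itself and the tokens reserved for the completion are subtracted from the model's context window before any history is pulled in. The arithmetic, with illustrative numbers:

    # Illustrative figures; the real limits come from
    # llm_constant.max_context_token_length and memory.llm.max_tokens.
    max_context_tokens = 4096    # model's total context window
    max_completion_tokens = 512  # reserved for the model's answer
    curr_message_tokens = 900    # tokens in the draft prompt

    rest_tokens = max(max_context_tokens - max_completion_tokens - curr_message_tokens, 0)
    # -> 2684 tokens available for appended chat history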




  @classmethod
  def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
-                              streaming: bool, conversation_message_task: ConversationMessageTask) -> CallbackManager:
+                              streaming: bool,
+                              conversation_message_task: ConversationMessageTask) -> CallbackManager:
      llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
      if streaming:
          callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]

  @classmethod
  def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                       max_token_limit: int) -> \
-         List[BaseMessage]:
+         str:
      """Get memory messages."""
      memory.max_token_limit = max_token_limit
      memory_key = memory.memory_variables[0]
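
Changing the declared return type to str matches the return_messages=False switch above: the memory now yields a ready-made transcript. For reference, LangChain's get_buffer_string helper produces the same Human:/AI: format (a sketch; the rest of the method body is not shown in this diff):

    from langchain.schema import AIMessage, HumanMessage, get_buffer_string

    history = [HumanMessage(content="Hi"), AIMessage(content="Hello!")]
    get_buffer_string(history, human_prefix="Human", ai_prefix="AI")
    # -> 'Human: Hi\nAI: Hello!'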
