@@ -39,7 +39,8 @@ class Completion:
         memory = cls.get_memory_from_conversation(
             tenant_id=app.tenant_id,
             app_model_config=app_model_config,
-            conversation=conversation
+            conversation=conversation,
+            return_messages=False
         )

         inputs = conversation.inputs
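
The new return_messages=False flag follows the LangChain memory convention: the buffer is rendered as a single "Human:/AI:" transcript string instead of a list of message objects, which is what later lets the history be spliced into one prompt string. A minimal sketch of that behaviour with a plain ConversationBufferMemory (standing in for Dify's ReadOnlyConversationTokenDBBufferSharedMemory, so illustrative only):

    # Illustrative only: stock LangChain memory, not the Dify subclass used above.
    from langchain.memory import ConversationBufferMemory

    memory = ConversationBufferMemory(return_messages=False)
    memory.save_context({"input": "Hi"}, {"output": "Hello!"})

    print(memory.load_memory_variables({})["history"])
    # Human: Hi
    # AI: Hello!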
@@ -119,7 +120,8 @@ class Completion:
         return response

     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, chain_output: Optional[str],
+    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+                            chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Union[str | List[BaseMessage]]:
         pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
@@ -161,11 +163,19 @@ And answer according to the language of the user's question.
             "query": query
         }

-        human_message_prompt = "{query}"
+        human_message_prompt = ""
+
+        if pre_prompt:
+            pre_prompt_inputs = {k: inputs[k] for k in
+                                 OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
+                                 if k in inputs}
+
+            if pre_prompt_inputs:
+                human_inputs.update(pre_prompt_inputs)

         if chain_output:
             human_inputs['context'] = chain_output
-            human_message_instruction = """Use the following CONTEXT as your learned knowledge.
+            human_message_prompt += """Use the following CONTEXT as your learned knowledge.
 [CONTEXT]
 {context}
 [END CONTEXT]
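
The new pre_prompt handling copies only the variables the user's template actually declares. OutLinePromptTemplate is Dify's own template class; the sketch below uses a hypothetical input_variables list to show the same filtering idea:

    # Hypothetical values, only to illustrate the dict-comprehension filter above.
    pre_prompt = "Answer as {role}. Keep the focus on {topic}."
    inputs = {"role": "a tax advisor", "topic": "VAT", "unrelated_form_field": "ignored"}

    input_variables = ["role", "topic"]  # what the template would parse out of pre_prompt
    pre_prompt_inputs = {k: inputs[k] for k in input_variables if k in inputs}
    # {'role': 'a tax advisor', 'topic': 'VAT'} -- unrelated_form_field is dropped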
@@ -176,39 +186,33 @@ When answer to user:
 Avoid mentioning that you obtained the information from the context.
 And answer according to the language of the user's question.
 """
-            if pre_prompt:
-                extra_inputs = {k: inputs[k] for k in
-                                OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                if k in inputs}
-                if extra_inputs:
-                    human_inputs.update(extra_inputs)
-                human_message_instruction += pre_prompt + "\n"
-
-            human_message_prompt = human_message_instruction + "Q:{query}\nA:"
-        else:
-            if pre_prompt:
-                extra_inputs = {k: inputs[k] for k in
-                                OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                if k in inputs}
-                if extra_inputs:
-                    human_inputs.update(extra_inputs)
-
-                human_message_prompt = pre_prompt + "\n" + human_message_prompt
-
-        # construct main prompt
-        human_message = PromptBuilder.to_human_message(
-            prompt_content=human_message_prompt,
-            inputs=human_inputs
-        )
+
+        if pre_prompt:
+            human_message_prompt += pre_prompt
+
+        query_prompt = "\nHuman: {query}\nAI: "

         if memory:
             # append chat histories
-            tmp_messages = messages.copy() + [human_message]
-            curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages)
-            rest_tokens = llm_constant.max_context_token_length[
-                              memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens
+            tmp_human_message = PromptBuilder.to_human_message(
+                prompt_content=human_message_prompt + query_prompt,
+                inputs=human_inputs
+            )
+
+            curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
+            rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
+                          - memory.llm.max_tokens - curr_message_tokens
             rest_tokens = max(rest_tokens, 0)
             history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
-            messages += history_messages
+            human_message_prompt += "\n\n" + history_messages
+
+        human_message_prompt += query_prompt
+
+        # construct main prompt
+        human_message = PromptBuilder.to_human_message(
+            prompt_content=human_message_prompt,
+            inputs=human_inputs
+        )

         messages.append(human_message)
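
The reworked branch sizes the history against the model's context window before asking the memory for it: it first counts what is already committed to the prompt (the context instruction, pre_prompt, and the "\nHuman: {query}\nAI: " scaffold), reserves the completion budget, and hands the remainder to the memory as a cap. A rough sketch with made-up numbers (the real limits come from llm_constant and memory.llm):

    # Made-up numbers; the diff reads these from llm_constant and memory.llm.
    max_context_tokens = 4096    # llm_constant.max_context_token_length[model_name]
    max_completion_tokens = 512  # memory.llm.max_tokens
    curr_message_tokens = 900    # tokens of human_message_prompt + query_prompt so far

    rest_tokens = max(max_context_tokens - max_completion_tokens - curr_message_tokens, 0)
    # rest_tokens == 2684: the most history the memory may return, which is then
    # appended as plain text ahead of "\nHuman: {query}\nAI: ".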
@@ -216,7 +220,8 @@ And answer according to the language of the user's question.

     @classmethod
     def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
-                                 streaming: bool, conversation_message_task: ConversationMessageTask) -> CallbackManager:
+                                 streaming: bool,
+                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         if streaming:
             callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
@@ -228,7 +233,7 @@ And answer according to the language of the user's question.
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                          max_token_limit: int) -> \
-            List[BaseMessage]:
+            str:
         """Get memory messages."""
         memory.max_token_limit = max_token_limit
         memory_key = memory.memory_variables[0]
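
With return_messages=False set in the first hunk, the buffer this method reads back is already a flat transcript, so the return annotation changes from List[BaseMessage] to str. Roughly, the shapes before and after (contents invented for illustration):

    from langchain.schema import AIMessage, HumanMessage

    # Before: a list of message objects that was extended onto `messages`.
    old_history = [HumanMessage(content="What is Dify?"), AIMessage(content="An LLMOps platform.")]

    # After: one string that is concatenated into human_message_prompt.
    new_history = "Human: What is Dify?\nAI: An LLMOps platform."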