| from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate | from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate | ||||
| from core.prompt.simple_prompt_transform import ModelMode | from core.prompt.simple_prompt_transform import ModelMode | ||||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | from core.prompt.utils.prompt_message_util import PromptMessageUtil | ||||
| from core.prompt.utils.prompt_template_parser import PromptTemplateParser | |||||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | from core.workflow.entities.base_node_data_entities import BaseNodeData | ||||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType | from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType | ||||
| from core.workflow.entities.variable_pool import VariablePool | from core.workflow.entities.variable_pool import VariablePool | ||||
| QUESTION_CLASSIFIER_USER_PROMPT_2, | QUESTION_CLASSIFIER_USER_PROMPT_2, | ||||
| QUESTION_CLASSIFIER_USER_PROMPT_3, | QUESTION_CLASSIFIER_USER_PROMPT_3, | ||||
| ) | ) | ||||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||||
| from libs.json_in_md_parser import parse_and_check_json_markdown | from libs.json_in_md_parser import parse_and_check_json_markdown | ||||
| from models.workflow import WorkflowNodeExecutionStatus | from models.workflow import WorkflowNodeExecutionStatus | ||||
| model_instance, model_config = self._fetch_model_config(node_data.model) | model_instance, model_config = self._fetch_model_config(node_data.model) | ||||
| # fetch memory | # fetch memory | ||||
| memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) | memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) | ||||
| # fetch instruction | |||||
| instruction = self._format_instruction(node_data.instruction, variable_pool) | |||||
| node_data.instruction = instruction | |||||
| # fetch prompt messages | # fetch prompt messages | ||||
| prompt_messages, stop = self._fetch_prompt( | prompt_messages, stop = self._fetch_prompt( | ||||
| node_data=node_data, | node_data=node_data, | ||||
| node_data = node_data | node_data = node_data | ||||
| node_data = cast(cls._node_data_cls, node_data) | node_data = cast(cls._node_data_cls, node_data) | ||||
| variable_mapping = {'query': node_data.query_variable_selector} | variable_mapping = {'query': node_data.query_variable_selector} | ||||
| variable_selectors = [] | |||||
| if node_data.instruction: | |||||
| variable_template_parser = VariableTemplateParser(template=node_data.instruction) | |||||
| variable_selectors.extend(variable_template_parser.extract_variable_selectors()) | |||||
| for variable_selector in variable_selectors: | |||||
| variable_mapping[variable_selector.variable] = variable_selector.value_selector | |||||
| return variable_mapping | return variable_mapping | ||||
| @classmethod | @classmethod | ||||
| text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str, | text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str, | ||||
| input_text=input_text, | input_text=input_text, | ||||
| categories=json.dumps(categories), | categories=json.dumps(categories), | ||||
| classification_instructions=instruction, ensure_ascii=False) | |||||
| classification_instructions=instruction, | |||||
| ensure_ascii=False) | |||||
| ) | ) | ||||
| else: | else: | ||||
| raise ValueError(f"Model mode {model_mode} not support.") | raise ValueError(f"Model mode {model_mode} not support.") | ||||
| def _format_instruction(self, instruction: str, variable_pool: VariablePool) -> str: | |||||
| inputs = {} | |||||
| variable_selectors = [] | |||||
| variable_template_parser = VariableTemplateParser(template=instruction) | |||||
| variable_selectors.extend(variable_template_parser.extract_variable_selectors()) | |||||
| for variable_selector in variable_selectors: | |||||
| variable_value = variable_pool.get_variable_value(variable_selector.value_selector) | |||||
| if variable_value is None: | |||||
| raise ValueError(f'Variable {variable_selector.variable} not found') | |||||
| inputs[variable_selector.variable] = variable_value | |||||
| prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True) | |||||
| prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} | |||||
| instruction = prompt_template.format( | |||||
| prompt_inputs | |||||
| ) | |||||
| return instruction |