| @@ -12,6 +12,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform | |||
| from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate | |||
| from core.prompt.simple_prompt_transform import ModelMode | |||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| from core.prompt.utils.prompt_template_parser import PromptTemplateParser | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| @@ -26,6 +27,7 @@ from core.workflow.nodes.question_classifier.template_prompts import ( | |||
| QUESTION_CLASSIFIER_USER_PROMPT_2, | |||
| QUESTION_CLASSIFIER_USER_PROMPT_3, | |||
| ) | |||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| from libs.json_in_md_parser import parse_and_check_json_markdown | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| @@ -47,6 +49,9 @@ class QuestionClassifierNode(LLMNode): | |||
| model_instance, model_config = self._fetch_model_config(node_data.model) | |||
| # fetch memory | |||
| memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) | |||
| # fetch instruction | |||
| instruction = self._format_instruction(node_data.instruction, variable_pool) | |||
| node_data.instruction = instruction | |||
| # fetch prompt messages | |||
| prompt_messages, stop = self._fetch_prompt( | |||
| node_data=node_data, | |||
| @@ -122,6 +127,12 @@ class QuestionClassifierNode(LLMNode): | |||
| node_data = node_data | |||
| node_data = cast(cls._node_data_cls, node_data) | |||
| variable_mapping = {'query': node_data.query_variable_selector} | |||
| variable_selectors = [] | |||
| if node_data.instruction: | |||
| variable_template_parser = VariableTemplateParser(template=node_data.instruction) | |||
| variable_selectors.extend(variable_template_parser.extract_variable_selectors()) | |||
| for variable_selector in variable_selectors: | |||
| variable_mapping[variable_selector.variable] = variable_selector.value_selector | |||
| return variable_mapping | |||
| @classmethod | |||
| @@ -269,8 +280,30 @@ class QuestionClassifierNode(LLMNode): | |||
| text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str, | |||
| input_text=input_text, | |||
| categories=json.dumps(categories), | |||
| classification_instructions=instruction, ensure_ascii=False) | |||
| classification_instructions=instruction, | |||
| ensure_ascii=False) | |||
| ) | |||
| else: | |||
| raise ValueError(f"Model mode {model_mode} not support.") | |||
| def _format_instruction(self, instruction: str, variable_pool: VariablePool) -> str: | |||
| inputs = {} | |||
| variable_selectors = [] | |||
| variable_template_parser = VariableTemplateParser(template=instruction) | |||
| variable_selectors.extend(variable_template_parser.extract_variable_selectors()) | |||
| for variable_selector in variable_selectors: | |||
| variable_value = variable_pool.get_variable_value(variable_selector.value_selector) | |||
| if variable_value is None: | |||
| raise ValueError(f'Variable {variable_selector.variable} not found') | |||
| inputs[variable_selector.variable] = variable_value | |||
| prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True) | |||
| prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} | |||
| instruction = prompt_template.format( | |||
| prompt_inputs | |||
| ) | |||
| return instruction | |||