Co-authored-by: Joel <iamjoel007@gmail.com>
@@ -0,0 +1,17 @@
+from core.helper.code_executor.code_executor import CodeExecutor
+
+
+class Jinja2Formatter:
+    @classmethod
+    def format(cls, template: str, inputs: dict) -> str:
+        """
+        Render a Jinja2 template through the sandboxed code executor.
+        :param template: Jinja2 template string
+        :param inputs: template variables
+        :return: rendered string
+        """
+        result = CodeExecutor.execute_workflow_code_template(
+            language='jinja2', code=template, inputs=inputs
+        )
+
+        return result['result']
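For context on the new helper: rendering is delegated to the sandboxed code executor via the `jinja2` language rather than done in-process. A minimal usage sketch, assuming a running code-executor service (template and inputs invented for illustration):

from core.helper.code_executor.jinja2_formatter import Jinja2Formatter

# Illustrative values; the actual rendering happens inside the sandbox.
rendered = Jinja2Formatter.format(
    template='Hello, {{ name }}! You have {{ count }} new messages.',
    inputs={'name': 'Ada', 'count': 3},
)
print(rendered)  # Hello, Ada! You have 3 new messages.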
@@ -2,6 +2,7 @@ from typing import Optional, Union

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.file.file_obj import FileVar
+from core.helper.code_executor.jinja2_formatter import Jinja2Formatter
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -80,29 +81,35 @@ class AdvancedPromptTransform(PromptTransform):
         prompt_messages = []

-        prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+        if prompt_template.edition_type == 'basic' or not prompt_template.edition_type:
+            prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

-        prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)

-        if memory and memory_config:
-            role_prefix = memory_config.role_prefix
-            prompt_inputs = self._set_histories_variable(
-                memory=memory,
-                memory_config=memory_config,
-                raw_prompt=raw_prompt,
-                role_prefix=role_prefix,
-                prompt_template=prompt_template,
-                prompt_inputs=prompt_inputs,
-                model_config=model_config
-            )
+            if memory and memory_config:
+                role_prefix = memory_config.role_prefix
+                prompt_inputs = self._set_histories_variable(
+                    memory=memory,
+                    memory_config=memory_config,
+                    raw_prompt=raw_prompt,
+                    role_prefix=role_prefix,
+                    prompt_template=prompt_template,
+                    prompt_inputs=prompt_inputs,
+                    model_config=model_config
+                )

-        if query:
-            prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
+            if query:
+                prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)

-        prompt = prompt_template.format(
-            prompt_inputs
-        )
+            prompt = prompt_template.format(
+                prompt_inputs
+            )
+        else:
+            prompt = raw_prompt
+            prompt_inputs = inputs
+
+            prompt = Jinja2Formatter.format(prompt, prompt_inputs)

         if files:
             prompt_message_contents = [TextPromptMessageContent(data=prompt)]
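The two edition types diverge here: `basic` keeps Dify's `{{#variable#}}` placeholders and resolves them through `PromptTemplateParser`, while `jinja2` hands the raw prompt plus the full `inputs` dict to `Jinja2Formatter`. What the jinja2 path produces, approximated locally with the `jinja2` library (values invented; the real call goes through the sandbox):

from jinja2 import Template

# 'today is {{ weather }}' rendered with {'weather': 'sunny'}
print(Template('today is {{ weather }}').render({'weather': 'sunny'}))  # today is sunny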
@@ -135,14 +142,22 @@ class AdvancedPromptTransform(PromptTransform):
         for prompt_item in raw_prompt_list:
             raw_prompt = prompt_item.text

-            prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+            if prompt_item.edition_type == 'basic' or not prompt_item.edition_type:
+                prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

-            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+                prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)

-            prompt = prompt_template.format(
-                prompt_inputs
-            )
+                prompt = prompt_template.format(
+                    prompt_inputs
+                )
+            elif prompt_item.edition_type == 'jinja2':
+                prompt = raw_prompt
+                prompt_inputs = inputs
+
+                prompt = Jinja2Formatter.format(prompt, prompt_inputs)
+            else:
+                raise ValueError(f'Invalid edition type: {prompt_item.edition_type}')

             if prompt_item.role == PromptMessageRole.USER:
                 prompt_messages.append(UserPromptMessage(content=prompt))
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Literal, Optional

 from pydantic import BaseModel
@@ -11,6 +11,7 @@ class ChatModelMessage(BaseModel):
     """
     text: str
     role: PromptMessageRole
+    edition_type: Optional[Literal['basic', 'jinja2']]


 class CompletionModelPromptTemplate(BaseModel):
@@ -18,6 +19,7 @@ class CompletionModelPromptTemplate(BaseModel):
     Completion Model Prompt Template.
     """
     text: str
+    edition_type: Optional[Literal['basic', 'jinja2']]


 class MemoryConfig(BaseModel):
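A note on the new field: `Optional[Literal['basic', 'jinja2']]` both constrains the accepted values and, under the pydantic v1 semantics this codebase appears to use (an assumption worth verifying), defaults to `None`, so existing configs without the key keep validating. A quick sketch:

from typing import Literal, Optional
from pydantic import BaseModel

class Msg(BaseModel):
    edition_type: Optional[Literal['basic', 'jinja2']]

Msg()                       # ok: Optional fields default to None in pydantic v1
Msg(edition_type='jinja2')  # ok
# Msg(edition_type='xml')   # raises ValidationError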
@@ -4,6 +4,7 @@ from pydantic import BaseModel

 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector


 class ModelConfig(BaseModel):
@@ -37,13 +38,31 @@ class VisionConfig(BaseModel):
     enabled: bool
     configs: Optional[Configs] = None


+class PromptConfig(BaseModel):
+    """
+    Prompt Config.
+    """
+    jinja2_variables: Optional[list[VariableSelector]] = None
+
+
+class LLMNodeChatModelMessage(ChatModelMessage):
+    """
+    LLM Node Chat Model Message.
+    """
+    jinja2_text: Optional[str] = None
+
+
+class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
+    """
+    LLM Node Completion Model Prompt Template.
+    """
+    jinja2_text: Optional[str] = None
+
+
 class LLMNodeData(BaseNodeData):
     """
     LLM Node Data.
     """
     model: ModelConfig
-    prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
+    prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
+    prompt_config: Optional[PromptConfig] = None
     memory: Optional[MemoryConfig] = None
     context: ContextConfig
     vision: VisionConfig
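To make the shape of the new entities concrete, a hypothetical construction (field values invented, mirroring the integration test at the end of this diff):

from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, PromptConfig

# A jinja2-edition message keeps both texts; _transform_chat_messages swaps
# jinja2_text into text before prompting.
message = LLMNodeChatModelMessage(
    role=PromptMessageRole.SYSTEM,
    text="today's weather is {{#abc.output#}}",
    jinja2_text="today's weather is {{ output }}.",
    edition_type='jinja2',
)
config = PromptConfig(
    jinja2_variables=[VariableSelector(variable='output', value_selector=['abc', 'output'])],
)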
@@ -1,4 +1,6 @@
+import json
 from collections.abc import Generator
+from copy import deepcopy
 from typing import Optional, cast

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -17,11 +19,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
+from core.workflow.nodes.llm.entities import (
+    LLMNodeChatModelMessage,
+    LLMNodeCompletionModelPromptTemplate,
+    LLMNodeData,
+    ModelConfig,
+)
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from models.model import Conversation
@@ -39,16 +45,24 @@ class LLMNode(BaseNode):
         :param variable_pool: variable pool
         :return:
         """
-        node_data = self.node_data
-        node_data = cast(self._node_data_cls, node_data)
+        node_data = cast(LLMNodeData, deepcopy(self.node_data))

         node_inputs = None
         process_data = None

         try:
+            # init messages template
+            node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)
+
             # fetch variables and fetch values from variable pool
             inputs = self._fetch_inputs(node_data, variable_pool)

+            # fetch jinja2 inputs
+            jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)
+
+            # merge inputs
+            inputs.update(jinja_inputs)
+
             node_inputs = {}

             # fetch files
@@ -183,6 +197,86 @@ class LLMNode(BaseNode):
             usage = LLMUsage.empty_usage()

         return full_text, usage

+    def _transform_chat_messages(self,
+                                 messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
+                                 ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
+        """
+        Transform chat messages: swap jinja2_text into text for jinja2-edition prompts
+        :param messages: chat messages or completion prompt template
+        :return: transformed messages
+        """
+        if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
+            if messages.edition_type == 'jinja2':
+                messages.text = messages.jinja2_text
+
+            return messages
+
+        for message in messages:
+            if message.edition_type == 'jinja2':
+                message.text = message.jinja2_text
+
+        return messages
+
+    def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
+        """
+        Fetch jinja inputs
+        :param node_data: node data
+        :param variable_pool: variable pool
+        :return: mapping of jinja2 variable name to stringified value
+        """
+        variables = {}
+        if not node_data.prompt_config:
+            return variables
+
+        for variable_selector in node_data.prompt_config.jinja2_variables or []:
+            variable = variable_selector.variable
+            value = variable_pool.get_variable_value(
+                variable_selector=variable_selector.value_selector
+            )
+
+            def parse_dict(d: dict) -> str:
+                """
+                Parse dict into string
+                """
+                # check if it's a context structure
+                if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
+                    return d['content']
+
+                # else, parse the dict
+                try:
+                    return json.dumps(d, ensure_ascii=False)
+                except Exception:
+                    return str(d)
+
+            if isinstance(value, str):
+                pass  # already a string
+            elif isinstance(value, list):
+                result = ''
+                for item in value:
+                    if isinstance(item, dict):
+                        result += parse_dict(item)
+                    elif isinstance(item, str):
+                        result += item
+                    else:
+                        result += str(item)
+                    result += '\n'
+                value = result.strip()
+            elif isinstance(value, dict):
+                value = parse_dict(value)
+            else:
+                value = str(value)
+
+            variables[variable] = value
+
+        return variables
+
     def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
         """
@@ -531,25 +625,25 @@ class LLMNode(BaseNode):
         db.session.commit()

     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
         """
         Extract variable selector to variable mapping
         :param node_data: node data
         :return:
         """
-        node_data = node_data
-        node_data = cast(cls._node_data_cls, node_data)
-
         prompt_template = node_data.prompt_template

         variable_selectors = []
         if isinstance(prompt_template, list):
             for prompt in prompt_template:
-                variable_template_parser = VariableTemplateParser(template=prompt.text)
-                variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+                if prompt.edition_type != 'jinja2':
+                    variable_template_parser = VariableTemplateParser(template=prompt.text)
+                    variable_selectors.extend(variable_template_parser.extract_variable_selectors())
         else:
-            variable_template_parser = VariableTemplateParser(template=prompt_template.text)
-            variable_selectors = variable_template_parser.extract_variable_selectors()
+            if prompt_template.edition_type != 'jinja2':
+                variable_template_parser = VariableTemplateParser(template=prompt_template.text)
+                variable_selectors = variable_template_parser.extract_variable_selectors()

         variable_mapping = {}
         for variable_selector in variable_selectors:
@@ -571,6 +665,22 @@ class LLMNode(BaseNode):
         if node_data.memory:
             variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]

+        if node_data.prompt_config:
+            enable_jinja = False
+
+            if isinstance(prompt_template, list):
+                for prompt in prompt_template:
+                    if prompt.edition_type == 'jinja2':
+                        enable_jinja = True
+                        break
+            else:
+                if prompt_template.edition_type == 'jinja2':
+                    enable_jinja = True
+
+            if enable_jinja:
+                for variable_selector in node_data.prompt_config.jinja2_variables or []:
+                    variable_mapping[variable_selector.variable] = variable_selector.value_selector
+
         return variable_mapping

     @classmethod
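For the jinja2-enabled test config at the bottom of this diff, the resulting mapping would look roughly like this (illustrative; the `#sys.query#` key comes from the basic-edition user message, the other two from `prompt_config.jinja2_variables`):

variable_mapping = {
    '#sys.query#': ['sys', 'query'],
    'sys_query': ['sys', 'query'],
    'output': ['abc', 'output'],
}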
@@ -588,7 +698,8 @@ class LLMNode(BaseNode):
                 "prompts": [
                     {
                         "role": "system",
-                        "text": "You are a helpful AI assistant."
+                        "text": "You are a helpful AI assistant.",
+                        "edition_type": "basic"
                     }
                 ]
             },
@@ -600,7 +711,8 @@ class LLMNode(BaseNode):
                 "prompt": {
                     "text": "Here is the chat histories between human and assistant, inside "
                             "<histories></histories> XML tags.\n\n<histories>\n{{"
-                            "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:"
+                            "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
+                    "edition_type": "basic"
                 },
                 "stop": ["Human:"]
             }
@@ -3,3 +3,4 @@ pytest~=8.1.1
 pytest-benchmark~=4.0.0
 pytest-env~=1.1.3
 pytest-mock~=3.14.0
+jinja2~=3.1.2
@@ -3,6 +3,7 @@ from typing import Literal

 import pytest
 from _pytest.monkeypatch import MonkeyPatch
+from jinja2 import Template

 from core.helper.code_executor.code_executor import CodeExecutor
@@ -18,7 +19,7 @@ class MockedCodeExecutor:
                 }
             elif language == 'jinja2':
                 return {
-                    "result": "3"
+                    "result": Template(code).render(inputs)
                 }


 @pytest.fixture
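With the mock now rendering for real instead of returning a hard-coded "3", jinja2 templates in tests behave as they would in the sandbox. For the system prompt used in the test below:

from jinja2 import Template

code = "you are a helpful assistant.\ntoday's weather is {{output}}."
print(Template(code).render({'output': 'sunny'}))
# you are a helpful assistant.
# today's weather is sunny.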
@@ -1,3 +1,4 @@
+import json
 import os
 from unittest.mock import MagicMock
@@ -19,6 +20,7 @@ from models.workflow import WorkflowNodeExecutionStatus

 """FOR MOCK FIXTURES, DO NOT REMOVE"""
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
+from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock


 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@@ -116,3 +118,118 @@ def test_execute_llm(setup_openai_mock):
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
     assert result.outputs['text'] is not None
     assert result.outputs['usage']['total_tokens'] > 0

+
+@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True)
+@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
+    """
+    Test execute LLM node with jinja2
+    """
+    node = LLMNode(
+        tenant_id='1',
+        app_id='1',
+        workflow_id='1',
+        user_id='1',
+        user_from=UserFrom.ACCOUNT,
+        config={
+            'id': 'llm',
+            'data': {
+                'title': '123',
+                'type': 'llm',
+                'model': {
+                    'provider': 'openai',
+                    'name': 'gpt-3.5-turbo',
+                    'mode': 'chat',
+                    'completion_params': {}
+                },
+                'prompt_config': {
+                    'jinja2_variables': [{
+                        'variable': 'sys_query',
+                        'value_selector': ['sys', 'query']
+                    }, {
+                        'variable': 'output',
+                        'value_selector': ['abc', 'output']
+                    }]
+                },
+                'prompt_template': [
+                    {
+                        'role': 'system',
+                        'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}',
+                        'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.',
+                        'edition_type': 'jinja2'
+                    },
+                    {
+                        'role': 'user',
+                        'text': '{{#sys.query#}}',
+                        'jinja2_text': '{{sys_query}}',
+                        'edition_type': 'basic'
+                    }
+                ],
+                'memory': None,
+                'context': {
+                    'enabled': False
+                },
+                'vision': {
+                    'enabled': False
+                }
+            }
+        }
+    )
+
+    # construct variable pool
+    pool = VariablePool(system_variables={
+        SystemVariable.QUERY: 'what\'s the weather today?',
+        SystemVariable.FILES: [],
+        SystemVariable.CONVERSATION_ID: 'abababa',
+        SystemVariable.USER_ID: 'aaa'
+    }, user_inputs={})
+    pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
+
+    credentials = {
+        'openai_api_key': os.environ.get('OPENAI_API_KEY')
+    }
+
+    provider_instance = ModelProviderFactory().get_provider_instance('openai')
+    model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+    provider_model_bundle = ProviderModelBundle(
+        configuration=ProviderConfiguration(
+            tenant_id='1',
+            provider=provider_instance.get_provider_schema(),
+            preferred_provider_type=ProviderType.CUSTOM,
+            using_provider_type=ProviderType.CUSTOM,
+            system_configuration=SystemConfiguration(
+                enabled=False
+            ),
+            custom_configuration=CustomConfiguration(
+                provider=CustomProviderConfiguration(
+                    credentials=credentials
+                )
+            )
+        ),
+        provider_instance=provider_instance,
+        model_type_instance=model_type_instance
+    )
+    model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
+    model_config = ModelConfigWithCredentialsEntity(
+        model='gpt-3.5-turbo',
+        provider='openai',
+        mode='chat',
+        credentials=credentials,
+        parameters={},
+        model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
+        provider_model_bundle=provider_model_bundle
+    )
+
+    # Mock db.session.close()
+    db.session.close = MagicMock()
+
+    node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
+
+    # execute node
+    result = node.run(pool)
+
+    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert 'sunny' in json.dumps(result.process_data)
+    assert 'what\'s the weather today?' in json.dumps(result.process_data)