
feat: support LLM jinja2 template prompt (#3968)

Co-authored-by: Joel <iamjoel007@gmail.com>
tags/0.6.8
Yeuoly · 1 year ago
Parent commit: 8578ee0864

api/core/helper/code_executor/jinja2_formatter.py (new file, +17 -0)

from core.helper.code_executor.code_executor import CodeExecutor


class Jinja2Formatter:
    @classmethod
    def format(cls, template: str, inputs: str) -> str:
        """
        Format template
        :param template: template
        :param inputs: inputs
        :return:
        """
        result = CodeExecutor.execute_workflow_code_template(
            language='jinja2', code=template, inputs=inputs
        )

        return result['result']
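A hedged usage sketch of the new formatter (note the annotation says inputs: str, but every caller in this commit passes a dict, which the executor forwards to the template):

    from core.helper.code_executor.jinja2_formatter import Jinja2Formatter

    # Rendering happens in the sandboxed code executor service, not in-process;
    # this assumes a configured executor (the test mock below simulates it with
    # Template(code).render(inputs)).
    rendered = Jinja2Formatter.format(
        template='Hello, {{ name }}!',
        inputs={'name': 'Dify'},
    )
    assert rendered == 'Hello, Dify!'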

api/core/prompt/advanced_prompt_transform.py (+40 -25)



 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.file.file_obj import FileVar
+from core.helper.code_executor.jinja2_formatter import Jinja2Formatter
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
 ...
         prompt_messages = []

-        prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+        if prompt_template.edition_type == 'basic' or not prompt_template.edition_type:
+            prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

-        prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)

-        if memory and memory_config:
-            role_prefix = memory_config.role_prefix
-            prompt_inputs = self._set_histories_variable(
-                memory=memory,
-                memory_config=memory_config,
-                raw_prompt=raw_prompt,
-                role_prefix=role_prefix,
-                prompt_template=prompt_template,
-                prompt_inputs=prompt_inputs,
-                model_config=model_config
-            )
+            if memory and memory_config:
+                role_prefix = memory_config.role_prefix
+                prompt_inputs = self._set_histories_variable(
+                    memory=memory,
+                    memory_config=memory_config,
+                    raw_prompt=raw_prompt,
+                    role_prefix=role_prefix,
+                    prompt_template=prompt_template,
+                    prompt_inputs=prompt_inputs,
+                    model_config=model_config
+                )

-        if query:
-            prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
+            if query:
+                prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)

-        prompt = prompt_template.format(
-            prompt_inputs
-        )
+            prompt = prompt_template.format(
+                prompt_inputs
+            )
+        else:
+            prompt = raw_prompt
+            prompt_inputs = inputs
+
+            prompt = Jinja2Formatter.format(prompt, prompt_inputs)

         if files:
             prompt_message_contents = [TextPromptMessageContent(data=prompt)]
 ...
         for prompt_item in raw_prompt_list:
             raw_prompt = prompt_item.text

-            prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+            if prompt_item.edition_type == 'basic' or not prompt_item.edition_type:
+                prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

-            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+                prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)

-            prompt = prompt_template.format(
-                prompt_inputs
-            )
+                prompt = prompt_template.format(
+                    prompt_inputs
+                )
+            elif prompt_item.edition_type == 'jinja2':
+                prompt = raw_prompt
+                prompt_inputs = inputs
+
+                prompt = Jinja2Formatter.format(prompt, prompt_inputs)
+            else:
+                raise ValueError(f'Invalid edition type: {prompt_item.edition_type}')

             if prompt_item.role == PromptMessageRole.USER:
                 prompt_messages.append(UserPromptMessage(content=prompt))
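Both the completion and chat branches now share the same dispatch on edition_type: 'basic' (or unset) keeps the existing PromptTemplateParser path, with context, histories, and query variables injected into prompt_inputs, while 'jinja2' skips all of that and hands the raw template plus the full inputs dict to Jinja2Formatter. A self-contained sketch of the dispatch (the regex substitution is a crude stand-in for PromptTemplateParser, not the real implementation):

    import re
    from jinja2 import Template

    def render_prompt(text: str, edition_type: str | None, inputs: dict) -> str:
        if edition_type == 'basic' or not edition_type:
            # stand-in for PromptTemplateParser's {{#key#}} substitution
            return re.sub(r'\{\{#([^#}]+)#\}\}', lambda m: str(inputs.get(m.group(1), '')), text)
        elif edition_type == 'jinja2':
            return Template(text).render(**inputs)  # Dify runs this inside the sandboxed executor
        raise ValueError(f'Invalid edition type: {edition_type}')

    print(render_prompt('Hi {{#name#}}', 'basic', {'name': 'Ann'}))   # Hi Ann
    print(render_prompt('Hi {{ name }}', 'jinja2', {'name': 'Ben'}))  # Hi Ben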

api/core/prompt/entities/advanced_prompt_entities.py (+3 -1)

-from typing import Optional
+from typing import Literal, Optional

 from pydantic import BaseModel
 ...
     """
     text: str
     role: PromptMessageRole
+    edition_type: Optional[Literal['basic', 'jinja2']]


 class CompletionModelPromptTemplate(BaseModel):
     """
     Completion Model Prompt Template.
     """
     text: str
+    edition_type: Optional[Literal['basic', 'jinja2']]


 class MemoryConfig(BaseModel):
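Since edition_type is an Optional[Literal['basic', 'jinja2']], pydantic enforces the allowed values at construction time; roughly:

    from pydantic import ValidationError
    from core.model_runtime.entities.message_entities import PromptMessageRole
    from core.prompt.entities.advanced_prompt_entities import ChatModelMessage

    ChatModelMessage(text='hi', role=PromptMessageRole.USER, edition_type='jinja2')  # accepted
    try:
        ChatModelMessage(text='hi', role=PromptMessageRole.USER, edition_type='liquid')
    except ValidationError:
        pass  # rejected: not one of the allowed literals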

api/core/workflow/nodes/llm/entities.py (+20 -1)



 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector
 ...
 class ModelConfig(BaseModel):
     enabled: bool
     configs: Optional[Configs] = None


+class PromptConfig(BaseModel):
+    """
+    Prompt Config.
+    """
+    jinja2_variables: Optional[list[VariableSelector]] = None
+
+
+class LLMNodeChatModelMessage(ChatModelMessage):
+    """
+    LLM Node Chat Model Message.
+    """
+    jinja2_text: Optional[str] = None
+
+
+class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
+    """
+    LLM Node Completion Model Prompt Template.
+    """
+    jinja2_text: Optional[str] = None


 class LLMNodeData(BaseNodeData):
     """
     LLM Node Data.
     """
     model: ModelConfig
-    prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
+    prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
+    prompt_config: Optional[PromptConfig] = None
     memory: Optional[MemoryConfig] = None
     context: ContextConfig
     vision: VisionConfig
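A jinja2-edition message therefore carries both texts: text holds what the basic editor shows, while jinja2_text is what actually gets rendered after the swap in _transform_chat_messages (see llm_node.py below). A minimal message dict, mirroring the integration test at the end of this commit:

    system_message = {
        'role': 'system',
        'text': "today's weather is {{#abc.output#}}",    # basic-editor form
        'jinja2_text': "today's weather is {{output}}.",  # used because edition_type is 'jinja2'
        'edition_type': 'jinja2',
    }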

api/core/workflow/nodes/llm/llm_node.py (+125 -13)

+import json
 from collections.abc import Generator
+from copy import deepcopy
 from typing import Optional, cast
 ...
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
+from core.workflow.nodes.llm.entities import (
+    LLMNodeChatModelMessage,
+    LLMNodeCompletionModelPromptTemplate,
+    LLMNodeData,
+    ModelConfig,
+)
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from models.model import Conversation
 ...
         :param variable_pool: variable pool
         :return:
         """
-        node_data = self.node_data
-        node_data = cast(self._node_data_cls, node_data)
+        node_data = cast(LLMNodeData, deepcopy(self.node_data))

         node_inputs = None
         process_data = None

         try:
+            # init messages template
+            node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)
+
             # fetch variables and fetch values from variable pool
             inputs = self._fetch_inputs(node_data, variable_pool)

+            # fetch jinja2 inputs
+            jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)
+
+            # merge inputs
+            inputs.update(jinja_inputs)
+
             node_inputs = {}

             # fetch files
 ...
             usage = LLMUsage.empty_usage()

         return full_text, usage

+    def _transform_chat_messages(self,
+            messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
+    ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
+        """
+        Transform chat messages
+
+        :param messages: chat messages
+        :return:
+        """
+        if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
+            if messages.edition_type == 'jinja2':
+                messages.text = messages.jinja2_text
+
+            return messages
+
+        for message in messages:
+            if message.edition_type == 'jinja2':
+                message.text = message.jinja2_text
+
+        return messages
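_transform_chat_messages normalizes templates up front so downstream code only ever reads text; an equivalent of the per-message in-place swap, done by hand:

    from core.model_runtime.entities.message_entities import PromptMessageRole
    from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage

    msg = LLMNodeChatModelMessage(
        role=PromptMessageRole.SYSTEM,
        text='weather is {{#abc.output#}}',   # stored basic form
        jinja2_text='weather is {{output}}',  # jinja2 source
        edition_type='jinja2',
    )
    if msg.edition_type == 'jinja2':  # what the method does for each message
        msg.text = msg.jinja2_text
    assert msg.text == 'weather is {{output}}'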

+    def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
+        """
+        Fetch jinja inputs
+        :param node_data: node data
+        :param variable_pool: variable pool
+        :return:
+        """
+        variables = {}
+
+        if not node_data.prompt_config:
+            return variables
+
+        for variable_selector in node_data.prompt_config.jinja2_variables or []:
+            variable = variable_selector.variable
+            value = variable_pool.get_variable_value(
+                variable_selector=variable_selector.value_selector
+            )
+
+            def parse_dict(d: dict) -> str:
+                """
+                Parse dict into string
+                """
+                # check if it's a context structure
+                if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
+                    return d['content']
+
+                # else, parse the dict
+                try:
+                    return json.dumps(d, ensure_ascii=False)
+                except Exception:
+                    return str(d)
+
+            if isinstance(value, str):
+                value = value
+            elif isinstance(value, list):
+                result = ''
+                for item in value:
+                    if isinstance(item, dict):
+                        result += parse_dict(item)
+                    elif isinstance(item, str):
+                        result += item
+                    elif isinstance(item, int | float):
+                        result += str(item)
+                    else:
+                        result += str(item)
+                    result += '\n'
+                value = result.strip()
+            elif isinstance(value, dict):
+                value = parse_dict(value)
+            elif isinstance(value, int | float):
+                value = str(value)
+            else:
+                value = str(value)
+
+            variables[variable] = value
+
+        return variables
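Whatever the variable pool returns, _fetch_jinja_inputs coerces it to a string: strings pass through, numbers via str(), dicts become their content field when they look like retrieved-context chunks (metadata._source present) and json.dumps otherwise, and lists are flattened item by item and joined with newlines. For a mixed list (values illustrative):

    value = [
        {'content': 'doc one', 'metadata': {'_source': 'knowledge'}},  # context chunk -> its content
        {'answer': 42},                                                # plain dict -> JSON
        'plain text',
        3.5,
    ]
    # coerces to: 'doc one\n{"answer": 42}\nplain text\n3.5'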


     def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
         """
 ...
         db.session.commit()

     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
         """
         Extract variable selector to variable mapping
         :param node_data: node data
         :return:
         """
-        node_data = node_data
-        node_data = cast(cls._node_data_cls, node_data)
-
         prompt_template = node_data.prompt_template

         variable_selectors = []
         if isinstance(prompt_template, list):
             for prompt in prompt_template:
-                variable_template_parser = VariableTemplateParser(template=prompt.text)
-                variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+                if prompt.edition_type != 'jinja2':
+                    variable_template_parser = VariableTemplateParser(template=prompt.text)
+                    variable_selectors.extend(variable_template_parser.extract_variable_selectors())
         else:
-            variable_template_parser = VariableTemplateParser(template=prompt_template.text)
-            variable_selectors = variable_template_parser.extract_variable_selectors()
+            if prompt_template.edition_type != 'jinja2':
+                variable_template_parser = VariableTemplateParser(template=prompt_template.text)
+                variable_selectors = variable_template_parser.extract_variable_selectors()

         variable_mapping = {}
         for variable_selector in variable_selectors:
 ...
         if node_data.memory:
             variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]

+        if node_data.prompt_config:
+            enable_jinja = False
+
+            if isinstance(prompt_template, list):
+                for prompt in prompt_template:
+                    if prompt.edition_type == 'jinja2':
+                        enable_jinja = True
+                        break
+            else:
+                if prompt_template.edition_type == 'jinja2':
+                    enable_jinja = True
+
+            if enable_jinja:
+                for variable_selector in node_data.prompt_config.jinja2_variables or []:
+                    variable_mapping[variable_selector.variable] = variable_selector.value_selector
+
         return variable_mapping
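For a node mixing editions, the mapping combines the {{#...#}} selectors extracted from basic prompts with the declared jinja2_variables (added only when at least one prompt is jinja2-edition). For the test configuration below, the result would be roughly:

    {
        '#sys.query#': ['sys', 'query'],  # from the basic-edition user message
        'sys_query': ['sys', 'query'],    # declared jinja2 variable
        'output': ['abc', 'output'],      # declared jinja2 variable
    }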


     @classmethod
 ...
                 "prompts": [
                     {
                         "role": "system",
-                        "text": "You are a helpful AI assistant."
+                        "text": "You are a helpful AI assistant.",
+                        "edition_type": "basic"
                     }
                 ]
             },
 ...
             "prompt": {
                 "text": "Here is the chat histories between human and assistant, inside "
                         "<histories></histories> XML tags.\n\n<histories>\n{{"
-                        "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:"
+                        "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
+                "edition_type": "basic"
             },
             "stop": ["Human:"]
         }

api/requirements-dev.txt (+1 -0)

 pytest-benchmark~=4.0.0
 pytest-env~=1.1.3
 pytest-mock~=3.14.0
+jinja2~=3.1.2

api/tests/integration_tests/workflow/nodes/__mock/code_executor.py (+2 -1)



 import pytest
 from _pytest.monkeypatch import MonkeyPatch
+from jinja2 import Template

 from core.helper.code_executor.code_executor import CodeExecutor
 ...
             }
         elif language == 'jinja2':
             return {
-                "result": "3"
+                "result": Template(code).render(inputs)
             }

 @pytest.fixture

api/tests/integration_tests/workflow/nodes/test_llm.py (+117 -0)

+import json
 import os
 from unittest.mock import MagicMock
 ...
 """FOR MOCK FIXTURES, DO NOT REMOVE"""
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
+from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
 ...
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 ...
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
     assert result.outputs['text'] is not None
     assert result.outputs['usage']['total_tokens'] > 0

+@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True)
+@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
+    """
+    Test execute LLM node with jinja2
+    """
+    node = LLMNode(
+        tenant_id='1',
+        app_id='1',
+        workflow_id='1',
+        user_id='1',
+        user_from=UserFrom.ACCOUNT,
+        config={
+            'id': 'llm',
+            'data': {
+                'title': '123',
+                'type': 'llm',
+                'model': {
+                    'provider': 'openai',
+                    'name': 'gpt-3.5-turbo',
+                    'mode': 'chat',
+                    'completion_params': {}
+                },
+                'prompt_config': {
+                    'jinja2_variables': [{
+                        'variable': 'sys_query',
+                        'value_selector': ['sys', 'query']
+                    }, {
+                        'variable': 'output',
+                        'value_selector': ['abc', 'output']
+                    }]
+                },
+                'prompt_template': [
+                    {
+                        'role': 'system',
+                        'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}',
+                        'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.',
+                        'edition_type': 'jinja2'
+                    },
+                    {
+                        'role': 'user',
+                        'text': '{{#sys.query#}}',
+                        'jinja2_text': '{{sys_query}}',
+                        'edition_type': 'basic'
+                    }
+                ],
+                'memory': None,
+                'context': {
+                    'enabled': False
+                },
+                'vision': {
+                    'enabled': False
+                }
+            }
+        }
+    )
+
+    # construct variable pool
+    pool = VariablePool(system_variables={
+        SystemVariable.QUERY: 'what\'s the weather today?',
+        SystemVariable.FILES: [],
+        SystemVariable.CONVERSATION_ID: 'abababa',
+        SystemVariable.USER_ID: 'aaa'
+    }, user_inputs={})
+    pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
+
+    credentials = {
+        'openai_api_key': os.environ.get('OPENAI_API_KEY')
+    }
+
+    provider_instance = ModelProviderFactory().get_provider_instance('openai')
+    model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+    provider_model_bundle = ProviderModelBundle(
+        configuration=ProviderConfiguration(
+            tenant_id='1',
+            provider=provider_instance.get_provider_schema(),
+            preferred_provider_type=ProviderType.CUSTOM,
+            using_provider_type=ProviderType.CUSTOM,
+            system_configuration=SystemConfiguration(
+                enabled=False
+            ),
+            custom_configuration=CustomConfiguration(
+                provider=CustomProviderConfiguration(
+                    credentials=credentials
+                )
+            )
+        ),
+        provider_instance=provider_instance,
+        model_type_instance=model_type_instance
+    )
+
+    model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
+
+    model_config = ModelConfigWithCredentialsEntity(
+        model='gpt-3.5-turbo',
+        provider='openai',
+        mode='chat',
+        credentials=credentials,
+        parameters={},
+        model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
+        provider_model_bundle=provider_model_bundle
+    )
+
+    # Mock db.session.close()
+    db.session.close = MagicMock()
+
+    node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
+
+    # execute node
+    result = node.run(pool)
+
+    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert 'sunny' in json.dumps(result.process_data)
+    assert 'what\'s the weather today?' in json.dumps(result.process_data)
