Co-authored-by: Joel <iamjoel007@gmail.com>tags/0.6.5
| @@ -31,7 +31,8 @@ class AdvancedPromptTransform(PromptTransform): | |||
| context: Optional[str], | |||
| memory_config: Optional[MemoryConfig], | |||
| memory: Optional[TokenBufferMemory], | |||
| model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| query_prompt_template: Optional[str] = None) -> list[PromptMessage]: | |||
| inputs = {key: str(value) for key, value in inputs.items()} | |||
| prompt_messages = [] | |||
| @@ -53,6 +54,7 @@ class AdvancedPromptTransform(PromptTransform): | |||
| prompt_template=prompt_template, | |||
| inputs=inputs, | |||
| query=query, | |||
| query_prompt_template=query_prompt_template, | |||
| files=files, | |||
| context=context, | |||
| memory_config=memory_config, | |||
| @@ -121,7 +123,8 @@ class AdvancedPromptTransform(PromptTransform): | |||
| context: Optional[str], | |||
| memory_config: Optional[MemoryConfig], | |||
| memory: Optional[TokenBufferMemory], | |||
| model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| query_prompt_template: Optional[str] = None) -> list[PromptMessage]: | |||
| """ | |||
| Get chat model prompt messages. | |||
| """ | |||
| @@ -148,6 +151,20 @@ class AdvancedPromptTransform(PromptTransform): | |||
| elif prompt_item.role == PromptMessageRole.ASSISTANT: | |||
| prompt_messages.append(AssistantPromptMessage(content=prompt)) | |||
| if query and query_prompt_template: | |||
| prompt_template = PromptTemplateParser( | |||
| template=query_prompt_template, | |||
| with_variable_tmpl=self.with_variable_tmpl | |||
| ) | |||
| prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} | |||
| prompt_inputs['#sys.query#'] = query | |||
| prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) | |||
| query = prompt_template.format( | |||
| prompt_inputs | |||
| ) | |||
| if memory and memory_config: | |||
| prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) | |||
| @@ -40,3 +40,4 @@ class MemoryConfig(BaseModel): | |||
| role_prefix: Optional[RolePrefix] = None | |||
| window: WindowConfig | |||
| query_prompt_template: Optional[str] = None | |||
| @@ -74,6 +74,7 @@ class LLMNode(BaseNode): | |||
| node_data=node_data, | |||
| query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value]) | |||
| if node_data.memory else None, | |||
| query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, | |||
| inputs=inputs, | |||
| files=files, | |||
| context=context, | |||
| @@ -209,6 +210,17 @@ class LLMNode(BaseNode): | |||
| inputs[variable_selector.variable] = variable_value | |||
| memory = node_data.memory | |||
| if memory and memory.query_prompt_template: | |||
| query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template) | |||
| .extract_variable_selectors()) | |||
| for variable_selector in query_variable_selectors: | |||
| variable_value = variable_pool.get_variable_value(variable_selector.value_selector) | |||
| if variable_value is None: | |||
| raise ValueError(f'Variable {variable_selector.variable} not found') | |||
| inputs[variable_selector.variable] = variable_value | |||
| return inputs | |||
| def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]: | |||
| @@ -302,7 +314,8 @@ class LLMNode(BaseNode): | |||
| return None | |||
| def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: | |||
| def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ | |||
| ModelInstance, ModelConfigWithCredentialsEntity]: | |||
| """ | |||
| Fetch model config | |||
| :param node_data_model: node data model | |||
| @@ -407,6 +420,7 @@ class LLMNode(BaseNode): | |||
| def _fetch_prompt_messages(self, node_data: LLMNodeData, | |||
| query: Optional[str], | |||
| query_prompt_template: Optional[str], | |||
| inputs: dict[str, str], | |||
| files: list[FileVar], | |||
| context: Optional[str], | |||
| @@ -417,6 +431,7 @@ class LLMNode(BaseNode): | |||
| Fetch prompt messages | |||
| :param node_data: node data | |||
| :param query: query | |||
| :param query_prompt_template: query prompt template | |||
| :param inputs: inputs | |||
| :param files: files | |||
| :param context: context | |||
| @@ -433,7 +448,8 @@ class LLMNode(BaseNode): | |||
| context=context, | |||
| memory_config=node_data.memory, | |||
| memory=memory, | |||
| model_config=model_config | |||
| model_config=model_config, | |||
| query_prompt_template=query_prompt_template, | |||
| ) | |||
| stop = model_config.stop | |||
| @@ -539,6 +555,13 @@ class LLMNode(BaseNode): | |||
| for variable_selector in variable_selectors: | |||
| variable_mapping[variable_selector.variable] = variable_selector.value_selector | |||
| memory = node_data.memory | |||
| if memory and memory.query_prompt_template: | |||
| query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template) | |||
| .extract_variable_selectors()) | |||
| for variable_selector in query_variable_selectors: | |||
| variable_mapping[variable_selector.variable] = variable_selector.value_selector | |||
| if node_data.context.enabled: | |||
| variable_mapping['#context#'] = node_data.context.variable_selector | |||
| @@ -30,6 +30,9 @@ export const checkHasQueryBlock = (text: string) => { | |||
| * {{#1711617514996.sys.query#}} => [sys, query] | |||
| */ | |||
| export const getInputVars = (text: string): ValueSelector[] => { | |||
| if (!text) | |||
| return [] | |||
| const allVars = text.match(/{{#([^#]*)#}}/g) | |||
| if (allVars && allVars?.length > 0) { | |||
| // {{#context#}}, {{#query#}} is not input vars | |||
| @@ -146,6 +146,7 @@ const Editor: FC<Props> = ({ | |||
| <PromptEditor | |||
| instanceId={instanceId} | |||
| compact | |||
| className='min-h-[56px]' | |||
| style={isExpand ? { height: editorExpandHeight - 5 } : {}} | |||
| value={value} | |||
| contextBlock={{ | |||
| @@ -272,10 +272,12 @@ export const getNodeUsedVars = (node: Node): ValueSelector[] => { | |||
| const payload = (data as LLMNodeType) | |||
| const isChatModel = payload.model?.mode === 'chat' | |||
| let prompts: string[] = [] | |||
| if (isChatModel) | |||
| if (isChatModel) { | |||
| prompts = (payload.prompt_template as PromptItem[])?.map(p => p.text) || [] | |||
| else | |||
| prompts = [(payload.prompt_template as PromptItem).text] | |||
| if (payload.memory?.query_prompt_template) | |||
| prompts.push(payload.memory.query_prompt_template) | |||
| } | |||
| else { prompts = [(payload.prompt_template as PromptItem).text] } | |||
| const inputVars: ValueSelector[] = matchNotSystemVars(prompts) | |||
| const contextVar = (data as LLMNodeType).context?.variable_selector ? [(data as LLMNodeType).context?.variable_selector] : [] | |||
| @@ -375,6 +377,8 @@ export const updateNodeVars = (oldNode: Node, oldVarSelector: ValueSelector, new | |||
| text: replaceOldVarInText(prompt.text, oldVarSelector, newVarSelector), | |||
| } | |||
| }) | |||
| if (payload.memory?.query_prompt_template) | |||
| payload.memory.query_prompt_template = replaceOldVarInText(payload.memory.query_prompt_template, oldVarSelector, newVarSelector) | |||
| } | |||
| else { | |||
| payload.prompt_template = { | |||
| @@ -50,6 +50,13 @@ const nodeDefault: NodeDefault<LLMNodeType> = { | |||
| if (isPromptyEmpty) | |||
| errorMessages = t(`${i18nPrefix}.fieldRequired`, { field: t('workflow.nodes.llm.prompt') }) | |||
| } | |||
| if (!errorMessages && !!payload.memory) { | |||
| const isChatModel = payload.model.mode === 'chat' | |||
| // payload.memory.query_prompt_template not pass is default: {{#sys.query#}} | |||
| if (isChatModel && !!payload.memory.query_prompt_template && !payload.memory.query_prompt_template.includes('{{#sys.query#}}')) | |||
| errorMessages = t('workflow.nodes.llm.sysQueryInUser') | |||
| } | |||
| return { | |||
| isValid: !errorMessages, | |||
| errorMessage: errorMessages, | |||
| @@ -50,7 +50,10 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({ | |||
| handleContextVarChange, | |||
| filterInputVar, | |||
| filterVar, | |||
| availableVars, | |||
| availableNodes, | |||
| handlePromptChange, | |||
| handleSyeQueryChange, | |||
| handleMemoryChange, | |||
| handleVisionResolutionEnabledChange, | |||
| handleVisionResolutionChange, | |||
| @@ -204,19 +207,20 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({ | |||
| <HelpCircle className='w-3.5 h-3.5 text-gray-400' /> | |||
| </TooltipPlus> | |||
| </div>} | |||
| value={'{{#sys.query#}}'} | |||
| onChange={() => { }} | |||
| readOnly | |||
| value={inputs.memory.query_prompt_template || '{{#sys.query#}}'} | |||
| onChange={handleSyeQueryChange} | |||
| readOnly={readOnly} | |||
| isShowContext={false} | |||
| isChatApp | |||
| isChatModel={false} | |||
| hasSetBlockStatus={{ | |||
| query: false, | |||
| history: true, | |||
| context: true, | |||
| }} | |||
| availableNodes={[startNode!]} | |||
| isChatModel | |||
| hasSetBlockStatus={hasSetBlockStatus} | |||
| nodesOutputVars={availableVars} | |||
| availableNodes={availableNodes} | |||
| /> | |||
| {inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && ( | |||
| <div className='leading-[18px] text-xs font-normal text-[#DC6803]'>{t(`${i18nPrefix}.sysQueryInUser`)}</div> | |||
| )} | |||
| </div> | |||
| </div> | |||
| )} | |||
| @@ -8,6 +8,7 @@ import { | |||
| useIsChatMode, | |||
| useNodesReadOnly, | |||
| } from '../../hooks' | |||
| import useAvailableVarList from '../_base/hooks/use-available-var-list' | |||
| import type { LLMNodeType } from './types' | |||
| import { Resolution } from '@/types/app' | |||
| import { useModelListAndDefaultModelAndCurrentProviderAndModel, useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| @@ -206,6 +207,24 @@ const useConfig = (id: string, payload: LLMNodeType) => { | |||
| setInputs(newInputs) | |||
| }, [inputs, setInputs]) | |||
| const handleSyeQueryChange = useCallback((newQuery: string) => { | |||
| const newInputs = produce(inputs, (draft) => { | |||
| if (!draft.memory) { | |||
| draft.memory = { | |||
| window: { | |||
| enabled: false, | |||
| size: 10, | |||
| }, | |||
| query_prompt_template: newQuery, | |||
| } | |||
| } | |||
| else { | |||
| draft.memory.query_prompt_template = newQuery | |||
| } | |||
| }) | |||
| setInputs(newInputs) | |||
| }, [inputs, setInputs]) | |||
| const handleVisionResolutionEnabledChange = useCallback((enabled: boolean) => { | |||
| const newInputs = produce(inputs, (draft) => { | |||
| if (!draft.vision) { | |||
| @@ -248,6 +267,14 @@ const useConfig = (id: string, payload: LLMNodeType) => { | |||
| return [VarType.arrayObject, VarType.array, VarType.string].includes(varPayload.type) | |||
| }, []) | |||
| const { | |||
| availableVars, | |||
| availableNodes, | |||
| } = useAvailableVarList(id, { | |||
| onlyLeafNodeVar: false, | |||
| filterVar, | |||
| }) | |||
| // single run | |||
| const { | |||
| isShowSingleRun, | |||
| @@ -322,8 +349,10 @@ const useConfig = (id: string, payload: LLMNodeType) => { | |||
| const allVarStrArr = (() => { | |||
| const arr = isChatModel ? (inputs.prompt_template as PromptItem[]).map(item => item.text) : [(inputs.prompt_template as PromptItem).text] | |||
| if (isChatMode && isChatModel && !!inputs.memory) | |||
| if (isChatMode && isChatModel && !!inputs.memory) { | |||
| arr.push('{{#sys.query#}}') | |||
| arr.push(inputs.memory.query_prompt_template) | |||
| } | |||
| return arr | |||
| })() | |||
| @@ -346,8 +375,11 @@ const useConfig = (id: string, payload: LLMNodeType) => { | |||
| handleContextVarChange, | |||
| filterInputVar, | |||
| filterVar, | |||
| availableVars, | |||
| availableNodes, | |||
| handlePromptChange, | |||
| handleMemoryChange, | |||
| handleSyeQueryChange, | |||
| handleVisionResolutionEnabledChange, | |||
| handleVisionResolutionChange, | |||
| isShowSingleRun, | |||
| @@ -143,6 +143,7 @@ export type Memory = { | |||
| enabled: boolean | |||
| size: number | string | null | |||
| } | |||
| query_prompt_template: string | |||
| } | |||
| export enum VarType { | |||
| @@ -204,6 +204,7 @@ const translation = { | |||
| singleRun: { | |||
| variable: 'Variable', | |||
| }, | |||
| sysQueryInUser: 'sys.query in user message is required', | |||
| }, | |||
| knowledgeRetrieval: { | |||
| queryVariable: 'Query Variable', | |||
| @@ -204,6 +204,7 @@ const translation = { | |||
| singleRun: { | |||
| variable: '变量', | |||
| }, | |||
| sysQueryInUser: 'user message 中必须包含 sys.query', | |||
| }, | |||
| knowledgeRetrieval: { | |||
| queryVariable: '查询变量', | |||