浏览代码

Feat: Support passing knowledge base id as variable in retrieval component (#7088)

### What problem does this PR solve?

Fix #6600

Hello, I have the same business requirement as #6600. My use case is: 

We have many departments (more than 20 now, and increasing), and each
department has its own knowledge base. Because the agent workflow is the same
for every department, I want to switch the knowledge base on the fly instead
of creating a separate agent for each department.

It now looks like this:


![屏幕截图_20250416_212622](https://github.com/user-attachments/assets/5cb3dade-d4fb-4591-ade3-4b9c54387911)

Knowledge bases can be selected from the dropdown, passed in through the
variables in the table, or both. The knowledge bases selected in the dropdown
and those supplied by variables are combined and all used for retrieval.

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
tags/v0.19.0
Song Fuchang 6 个月前
父节点
当前提交
6e7dd54a50
没有帐户链接到提交者的电子邮件

+ 46
- 32
agent/component/base.py 查看文件

import os import os
import logging import logging
from functools import partial from functools import partial
from typing import Tuple, Union
from typing import Any, Tuple, Union


import pandas as pd import pandas as pd


def set_output(self, v): def set_output(self, v):
setattr(self._param, self._param.output_var_name, v) setattr(self._param, self._param.output_var_name, v)


def _fetch_outputs_from(self, sources: list[dict[str, Any]]) -> list[pd.DataFrame]:
outs = []
for q in sources:
if q.get("component_id"):
if q["component_id"].split("@")[0].lower().find("begin") >= 0:
cpn_id, key = q["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key:
outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
break
else:
assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue

if q["component_id"].lower().find("answer") == 0:
txt = []
for r, c in self._canvas.history[::-1][:self._param.message_history_window_size][::-1]:
txt.append(f"{r.upper()}:{c}")
txt = "\n".join(txt)
outs.append(pd.DataFrame([{"content": txt}]))
continue

outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
elif q.get("value"):
outs.append(pd.DataFrame([{"content": q["value"]}]))
return outs

def get_input(self): def get_input(self):
if self._param.debug_inputs: if self._param.debug_inputs:
return pd.DataFrame([{"content": v["value"]} for v in self._param.debug_inputs if v.get("value")]) return pd.DataFrame([{"content": v["value"]} for v in self._param.debug_inputs if v.get("value")])


if self._param.query: if self._param.query:
self._param.inputs = [] self._param.inputs = []
outs = []
for q in self._param.query:
if q.get("component_id"):
if q["component_id"].split("@")[0].lower().find("begin") >= 0:
cpn_id, key = q["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key:
outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
self._param.inputs.append({"component_id": q["component_id"],
"content": p.get("value", "")})
break
else:
assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue

if q["component_id"].lower().find("answer") == 0:
txt = []
for r, c in self._canvas.history[::-1][:self._param.message_history_window_size][::-1]:
txt.append(f"{r.upper()}:{c}")
txt = "\n".join(txt)
self._param.inputs.append({"content": txt, "component_id": q["component_id"]})
outs.append(pd.DataFrame([{"content": txt}]))
continue

outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
self._param.inputs.append({"component_id": q["component_id"],
"content": "\n".join(
[str(d["content"]) for d in outs[-1].to_dict('records')])})
elif q.get("value"):
self._param.inputs.append({"component_id": None, "content": q["value"]})
outs.append(pd.DataFrame([{"content": q["value"]}]))
outs = self._fetch_outputs_from(self._param.query)

for out in outs:
records = out.to_dict("records")
content: str

if len(records) > 1:
content = "\n".join(
[str(d["content"]) for d in records]
)
else:
content = records[0]["content"]

self._param.inputs.append({
"component_id": records[0].get("component_id"),
"content": content
})

if outs: if outs:
df = pd.concat(outs, ignore_index=True) df = pd.concat(outs, ignore_index=True)
if "content" in df: if "content" in df:

+ 19
- 3
agent/component/retrieval.py 查看文件

self.top_n = 8 self.top_n = 8
self.top_k = 1024 self.top_k = 1024
self.kb_ids = [] self.kb_ids = []
self.kb_vars = []
self.rerank_id = "" self.rerank_id = ""
self.empty_response = "" self.empty_response = ""
self.tavily_api_key = "" self.tavily_api_key = ""
def _run(self, history, **kwargs): def _run(self, history, **kwargs):
query = self.get_input() query = self.get_input()
query = str(query["content"][0]) if "content" in query else "" query = str(query["content"][0]) if "content" in query else ""
kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids)

kb_ids: list[str] = self._param.kb_ids or []

kb_vars = self._fetch_outputs_from(self._param.kb_vars)

if len(kb_vars) > 0:
for kb_var in kb_vars:
if len(kb_var) == 1:
kb_ids.append(str(kb_var["content"][0]))
else:
for v in kb_var.to_dict("records"):
kb_ids.append(v["content"])

filtered_kb_ids: list[str] = [kb_id for kb_id in kb_ids if kb_id]

kbs = KnowledgebaseService.get_by_ids(filtered_kb_ids)
if not kbs: if not kbs:
return Retrieval.be_output("") return Retrieval.be_output("")


rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id) rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)


if kbs: if kbs:
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, filtered_kb_ids,
1, self._param.top_n, 1, self._param.top_n,
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight, self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
aggs=False, rerank_mdl=rerank_mdl, aggs=False, rerank_mdl=rerank_mdl,
if self._param.use_kg and kbs: if self._param.use_kg and kbs:
ck = settings.kg_retrievaler.retrieval(query, ck = settings.kg_retrievaler.retrieval(query,
[kbs[0].tenant_id], [kbs[0].tenant_id],
self._param.kb_ids,
filtered_kb_ids,
embd_mdl, embd_mdl,
LLMBundle(kbs[0].tenant_id, LLMType.CHAT)) LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
if ck["content_with_weight"]: if ck["content_with_weight"]:

+ 3
- 1
web/src/components/knowledge-base-item.tsx 查看文件

import { MultiSelect } from './ui/multi-select'; import { MultiSelect } from './ui/multi-select';


interface KnowledgeBaseItemProps { interface KnowledgeBaseItemProps {
tooltipText?: string;
required?: boolean; required?: boolean;
onChange?(): void; onChange?(): void;
} }


const KnowledgeBaseItem = ({ const KnowledgeBaseItem = ({
tooltipText,
required = true, required = true,
onChange, onChange,
}: KnowledgeBaseItemProps) => { }: KnowledgeBaseItemProps) => {
<Form.Item <Form.Item
label={t('knowledgeBases')} label={t('knowledgeBases')}
name="kb_ids" name="kb_ids"
tooltip={t('knowledgeBasesTip')}
tooltip={tooltipText || t('knowledgeBasesTip')}
rules={[ rules={[
{ {
required, required,

+ 3
- 0
web/src/locales/en.ts 查看文件

promptTip: promptTip:
'Use the system prompt to describe the task for the LLM, specify how it should respond, and outline other miscellaneous requirements. The system prompt is often used in conjunction with keys (variables), which serve as various data inputs for the LLM. Use a forward slash `/` or the (x) button to show the keys to use.', 'Use the system prompt to describe the task for the LLM, specify how it should respond, and outline other miscellaneous requirements. The system prompt is often used in conjunction with keys (variables), which serve as various data inputs for the LLM. Use a forward slash `/` or the (x) button to show the keys to use.',
promptMessage: 'Prompt is required', promptMessage: 'Prompt is required',
knowledgeBasesTip:
'Select the knowledge bases to associate with this chat assistant, or choose variables containing knowledge base IDs below.',
knowledgeBaseVars: 'Knowledge base variables',
}, },
}, },
}; };

+ 2
- 0
web/src/locales/zh.ts 查看文件

promptMessage: '提示词是必填项', promptMessage: '提示词是必填项',
promptTip: promptTip:
'系统提示为大模型提供任务描述、规定回复方式,以及设置其他各种要求。系统提示通常与 key (变量)合用,通过变量设置大模型的输入数据。你可以通过斜杠或者 (x) 按钮显示可用的 key。', '系统提示为大模型提供任务描述、规定回复方式,以及设置其他各种要求。系统提示通常与 key (变量)合用,通过变量设置大模型的输入数据。你可以通过斜杠或者 (x) 按钮显示可用的 key。',
knowledgeBasesTip: '选择关联的知识库,或者在下方选择包含知识库ID的变量。',
knowledgeBaseVars: '知识库变量',
}, },
footer: { footer: {
profile: 'All rights reserved @ React', profile: 'All rights reserved @ React',

+ 11
- 8
web/src/pages/flow/form/components/dynamic-input-variable.tsx 查看文件

import styles from './index.less'; import styles from './index.less';


interface IProps { interface IProps {
name?: string;
node?: RAGFlowNodeType; node?: RAGFlowNodeType;
title?: string;
} }


enum VariableType { enum VariableType {
const getVariableName = (type: string) => const getVariableName = (type: string) =>
type === VariableType.Reference ? 'component_id' : 'value'; type === VariableType.Reference ? 'component_id' : 'value';


const DynamicVariableForm = ({ node }: IProps) => {
const DynamicVariableForm = ({ name: formName, node }: IProps) => {
formName = formName || 'query';
const { t } = useTranslation(); const { t } = useTranslation();
const valueOptions = useBuildComponentIdSelectOptions( const valueOptions = useBuildComponentIdSelectOptions(
node?.id, node?.id,
const handleTypeChange = useCallback( const handleTypeChange = useCallback(
(name: number) => () => { (name: number) => () => {
setTimeout(() => { setTimeout(() => {
form.setFieldValue(['query', name, 'component_id'], undefined);
form.setFieldValue(['query', name, 'value'], undefined);
form.setFieldValue([formName, name, 'component_id'], undefined);
form.setFieldValue([formName, name, 'value'], undefined);
}, 0); }, 0);
}, },
[form], [form],
); );


return ( return (
<Form.List name="query">
<Form.List name={formName}>
{(fields, { add, remove }) => ( {(fields, { add, remove }) => (
<> <>
{fields.map(({ key, name, ...restField }) => ( {fields.map(({ key, name, ...restField }) => (
</Form.Item> </Form.Item>
<Form.Item noStyle dependencies={[name, 'type']}> <Form.Item noStyle dependencies={[name, 'type']}>
{({ getFieldValue }) => { {({ getFieldValue }) => {
const type = getFieldValue(['query', name, 'type']);
const type = getFieldValue([formName, name, 'type']);
return ( return (
<Form.Item <Form.Item
{...restField} {...restField}
); );
} }


const DynamicInputVariable = ({ node }: IProps) => {
const DynamicInputVariable = ({ name, node, title }: IProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
return ( return (
<FormCollapse title={t('flow.input')}>
<DynamicVariableForm node={node}></DynamicVariableForm>
<FormCollapse title={title || t('flow.input')}>
<DynamicVariableForm name={name} node={node}></DynamicVariableForm>
</FormCollapse> </FormCollapse>
); );
}; };

+ 8
- 1
web/src/pages/flow/form/retrieval-form/index.tsx 查看文件

<Rerank></Rerank> <Rerank></Rerank>
<TavilyItem name={'tavily_api_key'}></TavilyItem> <TavilyItem name={'tavily_api_key'}></TavilyItem>
<UseKnowledgeGraphItem filedName={'use_kg'}></UseKnowledgeGraphItem> <UseKnowledgeGraphItem filedName={'use_kg'}></UseKnowledgeGraphItem>
<KnowledgeBaseItem></KnowledgeBaseItem>
<KnowledgeBaseItem
tooltipText={t('knowledgeBasesTip')}
></KnowledgeBaseItem>
<DynamicInputVariable
name={'kb_vars'}
node={node}
title={t('knowledgeBaseVars')}
></DynamicInputVariable>
<Form.Item <Form.Item
name={'empty_response'} name={'empty_response'}
label={t('emptyResponse', { keyPrefix: 'chat' })} label={t('emptyResponse', { keyPrefix: 'chat' })}

正在加载...
取消
保存