浏览代码

Feat: Support passing knowledge base id as variable in retrieval component (#7088)

### What problem does this PR solve?

Fix #6600

Hello, I have the same business requirement as #6600. My use case is: 

We have many departments (> 20 now and increasing), and each department
has its own knowledge base. Because the agent workflow is the same, so I
want to change the knowledge base on the fly, instead of creating agents
for every department.

It now looks like this:


![屏幕截图_20250416_212622](https://github.com/user-attachments/assets/5cb3dade-d4fb-4591-ade3-4b9c54387911)

Knowledge bases can be selected from the dropdown, and passed through
the variables in the table. All selected knowledge bases are used for
retrieval.

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
tags/v0.19.0
Song Fuchang 6 个月前
父节点
当前提交
6e7dd54a50
没有帐户链接到提交者的电子邮件

+ 46
- 32
agent/component/base.py 查看文件

@@ -19,7 +19,7 @@ import json
import os
import logging
from functools import partial
from typing import Tuple, Union
from typing import Any, Tuple, Union

import pandas as pd

@@ -462,6 +462,33 @@ class ComponentBase(ABC):
def set_output(self, v):
setattr(self._param, self._param.output_var_name, v)

def _fetch_outputs_from(self, sources: list[dict[str, Any]]) -> list[pd.DataFrame]:
outs = []
for q in sources:
if q.get("component_id"):
if q["component_id"].split("@")[0].lower().find("begin") >= 0:
cpn_id, key = q["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key:
outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
break
else:
assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue

if q["component_id"].lower().find("answer") == 0:
txt = []
for r, c in self._canvas.history[::-1][:self._param.message_history_window_size][::-1]:
txt.append(f"{r.upper()}:{c}")
txt = "\n".join(txt)
outs.append(pd.DataFrame([{"content": txt}]))
continue

outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
elif q.get("value"):
outs.append(pd.DataFrame([{"content": q["value"]}]))
return outs

def get_input(self):
if self._param.debug_inputs:
return pd.DataFrame([{"content": v["value"]} for v in self._param.debug_inputs if v.get("value")])
@@ -475,37 +502,24 @@ class ComponentBase(ABC):

if self._param.query:
self._param.inputs = []
outs = []
for q in self._param.query:
if q.get("component_id"):
if q["component_id"].split("@")[0].lower().find("begin") >= 0:
cpn_id, key = q["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key:
outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
self._param.inputs.append({"component_id": q["component_id"],
"content": p.get("value", "")})
break
else:
assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue

if q["component_id"].lower().find("answer") == 0:
txt = []
for r, c in self._canvas.history[::-1][:self._param.message_history_window_size][::-1]:
txt.append(f"{r.upper()}:{c}")
txt = "\n".join(txt)
self._param.inputs.append({"content": txt, "component_id": q["component_id"]})
outs.append(pd.DataFrame([{"content": txt}]))
continue

outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
self._param.inputs.append({"component_id": q["component_id"],
"content": "\n".join(
[str(d["content"]) for d in outs[-1].to_dict('records')])})
elif q.get("value"):
self._param.inputs.append({"component_id": None, "content": q["value"]})
outs.append(pd.DataFrame([{"content": q["value"]}]))
outs = self._fetch_outputs_from(self._param.query)

for out in outs:
records = out.to_dict("records")
content: str

if len(records) > 1:
content = "\n".join(
[str(d["content"]) for d in records]
)
else:
content = records[0]["content"]

self._param.inputs.append({
"component_id": records[0].get("component_id"),
"content": content
})

if outs:
df = pd.concat(outs, ignore_index=True)
if "content" in df:

+ 19
- 3
agent/component/retrieval.py 查看文件

@@ -41,6 +41,7 @@ class RetrievalParam(ComponentParamBase):
self.top_n = 8
self.top_k = 1024
self.kb_ids = []
self.kb_vars = []
self.rerank_id = ""
self.empty_response = ""
self.tavily_api_key = ""
@@ -58,7 +59,22 @@ class Retrieval(ComponentBase, ABC):
def _run(self, history, **kwargs):
query = self.get_input()
query = str(query["content"][0]) if "content" in query else ""
kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids)

kb_ids: list[str] = self._param.kb_ids or []

kb_vars = self._fetch_outputs_from(self._param.kb_vars)

if len(kb_vars) > 0:
for kb_var in kb_vars:
if len(kb_var) == 1:
kb_ids.append(str(kb_var["content"][0]))
else:
for v in kb_var.to_dict("records"):
kb_ids.append(v["content"])

filtered_kb_ids: list[str] = [kb_id for kb_id in kb_ids if kb_id]

kbs = KnowledgebaseService.get_by_ids(filtered_kb_ids)
if not kbs:
return Retrieval.be_output("")

@@ -75,7 +91,7 @@ class Retrieval(ComponentBase, ABC):
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)

if kbs:
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, filtered_kb_ids,
1, self._param.top_n,
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
aggs=False, rerank_mdl=rerank_mdl,
@@ -86,7 +102,7 @@ class Retrieval(ComponentBase, ABC):
if self._param.use_kg and kbs:
ck = settings.kg_retrievaler.retrieval(query,
[kbs[0].tenant_id],
self._param.kb_ids,
filtered_kb_ids,
embd_mdl,
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:

+ 3
- 1
web/src/components/knowledge-base-item.tsx 查看文件

@@ -10,11 +10,13 @@ import { FormControl, FormField, FormItem, FormLabel } from './ui/form';
import { MultiSelect } from './ui/multi-select';

interface KnowledgeBaseItemProps {
tooltipText?: string;
required?: boolean;
onChange?(): void;
}

const KnowledgeBaseItem = ({
tooltipText,
required = true,
onChange,
}: KnowledgeBaseItemProps) => {
@@ -40,7 +42,7 @@ const KnowledgeBaseItem = ({
<Form.Item
label={t('knowledgeBases')}
name="kb_ids"
tooltip={t('knowledgeBasesTip')}
tooltip={tooltipText || t('knowledgeBasesTip')}
rules={[
{
required,

+ 3
- 0
web/src/locales/en.ts 查看文件

@@ -1255,6 +1255,9 @@ This delimiter is used to split the input text into several text pieces echo of
promptTip:
'Use the system prompt to describe the task for the LLM, specify how it should respond, and outline other miscellaneous requirements. The system prompt is often used in conjunction with keys (variables), which serve as various data inputs for the LLM. Use a forward slash `/` or the (x) button to show the keys to use.',
promptMessage: 'Prompt is required',
knowledgeBasesTip:
'Select the knowledge bases to associate with this chat assistant, or choose variables containing knowledge base IDs below.',
knowledgeBaseVars: 'Knowledge base variables',
},
},
};

+ 2
- 0
web/src/locales/zh.ts 查看文件

@@ -1210,6 +1210,8 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
promptMessage: '提示词是必填项',
promptTip:
'系统提示为大模型提供任务描述、规定回复方式,以及设置其他各种要求。系统提示通常与 key (变量)合用,通过变量设置大模型的输入数据。你可以通过斜杠或者 (x) 按钮显示可用的 key。',
knowledgeBasesTip: '选择关联的知识库,或者在下方选择包含知识库ID的变量。',
knowledgeBaseVars: '知识库变量',
},
footer: {
profile: 'All rights reserved @ React',

+ 11
- 8
web/src/pages/flow/form/components/dynamic-input-variable.tsx 查看文件

@@ -8,7 +8,9 @@ import { useBuildComponentIdSelectOptions } from '../../hooks/use-get-begin-quer
import styles from './index.less';

interface IProps {
name?: string;
node?: RAGFlowNodeType;
title?: string;
}

enum VariableType {
@@ -19,7 +21,8 @@ enum VariableType {
const getVariableName = (type: string) =>
type === VariableType.Reference ? 'component_id' : 'value';

const DynamicVariableForm = ({ node }: IProps) => {
const DynamicVariableForm = ({ name: formName, node }: IProps) => {
formName = formName || 'query';
const { t } = useTranslation();
const valueOptions = useBuildComponentIdSelectOptions(
node?.id,
@@ -35,15 +38,15 @@ const DynamicVariableForm = ({ node }: IProps) => {
const handleTypeChange = useCallback(
(name: number) => () => {
setTimeout(() => {
form.setFieldValue(['query', name, 'component_id'], undefined);
form.setFieldValue(['query', name, 'value'], undefined);
form.setFieldValue([formName, name, 'component_id'], undefined);
form.setFieldValue([formName, name, 'value'], undefined);
}, 0);
},
[form],
);

return (
<Form.List name="query">
<Form.List name={formName}>
{(fields, { add, remove }) => (
<>
{fields.map(({ key, name, ...restField }) => (
@@ -60,7 +63,7 @@ const DynamicVariableForm = ({ node }: IProps) => {
</Form.Item>
<Form.Item noStyle dependencies={[name, 'type']}>
{({ getFieldValue }) => {
const type = getFieldValue(['query', name, 'type']);
const type = getFieldValue([formName, name, 'type']);
return (
<Form.Item
{...restField}
@@ -118,11 +121,11 @@ export function FormCollapse({
);
}

const DynamicInputVariable = ({ node }: IProps) => {
const DynamicInputVariable = ({ name, node, title }: IProps) => {
const { t } = useTranslation();
return (
<FormCollapse title={t('flow.input')}>
<DynamicVariableForm node={node}></DynamicVariableForm>
<FormCollapse title={title || t('flow.input')}>
<DynamicVariableForm name={name} node={node}></DynamicVariableForm>
</FormCollapse>
);
};

+ 8
- 1
web/src/pages/flow/form/retrieval-form/index.tsx 查看文件

@@ -43,7 +43,14 @@ const RetrievalForm = ({ onValuesChange, form, node }: IOperatorForm) => {
<Rerank></Rerank>
<TavilyItem name={'tavily_api_key'}></TavilyItem>
<UseKnowledgeGraphItem filedName={'use_kg'}></UseKnowledgeGraphItem>
<KnowledgeBaseItem></KnowledgeBaseItem>
<KnowledgeBaseItem
tooltipText={t('knowledgeBasesTip')}
></KnowledgeBaseItem>
<DynamicInputVariable
name={'kb_vars'}
node={node}
title={t('knowledgeBaseVars')}
></DynamicInputVariable>
<Form.Item
name={'empty_response'}
label={t('emptyResponse', { keyPrefix: 'chat' })}

正在加载...
取消
保存