瀏覽代碼

add keyword extraction in graph (#1373)

### What problem does this PR solve?
#918 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.8.0
KevinHuSh 1 年之前
父節點
當前提交
258c9ea644
沒有連結到貢獻者的電子郵件帳戶。

+ 7
- 4
api/apps/canvas_app.py 查看文件

@@ -142,16 +142,19 @@ def run():


@manager.route('/reset', methods=['POST'])
@validate_request("canvas_id")
@validate_request("id")
@login_required
def reset():
req = request.json
try:
user_canvas = UserCanvasService.get_by_id(req["canvas_id"])
canvas = Canvas(user_canvas.dsl, current_user.id)
e, user_canvas = UserCanvasService.get_by_id(req["id"])
if not e:
return server_error_response("canvas not found.")

canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
canvas.reset()
req["dsl"] = json.loads(str(canvas))
UserCanvasService.update_by_id(req["canvas_id"], dsl=req["dsl"])
UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
return get_json_result(data=req["dsl"])
except Exception as e:
return server_error_response(e)

+ 1
- 1
api/db/init_data.py 查看文件

@@ -156,7 +156,7 @@ factory_infos = [{
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
"status": "1",
},{
"name": "Minimax",
"name": "MiniMax",
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",

+ 13
- 5
graph/canvas.py 查看文件

@@ -102,19 +102,26 @@ class Canvas(ABC):
self.load()

def load(self):
assert self.dsl.get("components", {}).get("begin"), "There have to be a 'Begin' component."

self.components = self.dsl["components"]
cpn_nms = set([])
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])

assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
assert "Answer" in cpn_nms, "There have to be an 'Answer' component."

for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
param = component_class(cpn["obj"]["component_name"] + "Param")()
param.update(cpn["obj"]["params"])
param.check()
cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
if cpn["obj"].component_name == "Categorize":
for _,desc in param.category_description.items():
for _, desc in param.category_description.items():
if desc["to"] not in cpn["downstream"]:
cpn["downstream"].append(desc["to"])


self.path = self.dsl["path"]
self.history = self.dsl["history"]
self.messages = self.dsl["messages"]
@@ -140,7 +147,8 @@ class Canvas(ABC):
self.messages = []
self.answer = []
self.reference = []
self.components = {}
for k, cpn in self.components.items():
self.components[k]["obj"].reset()
self._embed_id = ""

def run(self, **kwargs):
@@ -176,7 +184,7 @@ class Canvas(ABC):
ran += 1

prepare2run(self.components[self.path[-2][-1]]["downstream"])
while ran < len(self.path[-1]):
while 0 <= ran < len(self.path[-1]):
if DEBUG: print(ran, self.path)
cpn_id = self.path[-1][ran]
cpn = self.get_component(cpn_id)

+ 3
- 0
graph/component/base.py 查看文件

@@ -418,6 +418,9 @@ class ComponentBase(ABC):
o = pd.DataFrame(o)
return self._param.output_var_name, o

def reset(self):
setattr(self._param, self._param.output_var_name, None)

def set_output(self, v: pd.DataFrame):
setattr(self._param, self._param.output_var_name, v)


+ 68
- 0
graph/component/keyword.py 查看文件

@@ -0,0 +1,68 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from abc import ABC
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from graph.component import GenerateParam, Generate
from graph.settings import DEBUG


class KeywordExtractParam(GenerateParam):

"""
Define the KeywordExtract component parameters.
"""
def __init__(self):
super().__init__()
self.temperature = 0.5
self.prompt = ""
self.topn = 1

def check(self):
super().check()

def get_prompt(self):
self.prompt = """
- Role: You're a question analyzer.
- Requirements:
- Summarize user's question, and give top %s important keyword/phrase.
- Use comma as a delimiter to separate keywords/phrases.
- Answer format: (in language of user's question)
- keyword:
"""%self.topn
return self.prompt


class KeywordExtract(Generate, ABC):
component_name = "RewriteQuestion"

def _run(self, history, **kwargs):
q = ""
for r, c in self._canvas.history[::-1]:
if r == "user":
q += c
break

chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": q}],
self._param.gen_conf())

ans = re.sub(r".*keyword:", "", ans).strip()
if DEBUG: print(ans, ":::::::::::::::::::::::::::::::::")
return KeywordExtract.be_output(ans)



+ 1
- 1
graph/templates/HR_callout_zh.json 查看文件

@@ -1,5 +1,5 @@
{
"id": 0,
"id": 1,
"title": "HR call-out assistant(Chinese)",
"description": "A HR call-out assistant. It will introduce the given job, answer the candidates' question about this job. And the most important thing is that it will try to obtain the contact information of the candidates. What you need to do is to link a knowledgebase which contains job description in 'Retrieval' component.",
"canvas_type": "chatbot",

+ 5
- 5
graph/templates/customer_service.json 查看文件

@@ -1,5 +1,5 @@
{
"id": 1,
"id": 2,
"title": "Customer service",
"description": "A call-in customer service chat bot. It will provide useful information about the products, answer customers' questions and soothe the customers' bad emotions.",
"canvas_type": "chatbot",
@@ -106,7 +106,7 @@
"upstream": ["categorize:0"]
},
"generate:complain": {
"downstream": [],
"downstream": ["answer:0"],
"obj": {
"component_name": "Generate",
"params": {
@@ -116,7 +116,7 @@
"prompt": "You are a customer support. the Customers complain even curse about the products but not specific enough. You need to ask him/her what's the specific problem with the product. Be nice, patient and concern to soothe your customers’ emotions at first place."
}
},
"upstream": ["categorize:0", "answer:0"]
"upstream": ["categorize:0"]
},
"message:get_contact": {
"downstream": ["answer:0"],
@@ -286,13 +286,13 @@
{
"id": "reactflow__edge-answer:0a-generate:complaind",
"markerEnd": "logo",
"source": "answer:0",
"source": "generate:complain",
"sourceHandle": "a",
"style": {
"stroke": "rgb(202 197 245)",
"strokeWidth": 2
},
"target": "generate:complain",
"target": "answer:0",
"targetHandle": "d",
"type": "buttonEdge"
},

+ 335
- 0
graph/templates/general_chat_bot.json
文件差異過大導致無法顯示
查看文件


+ 1
- 1
graph/templates/interpreter.json 查看文件

@@ -1,5 +1,5 @@
{
"id": 2,
"id": 3,
"title": "Interpreter",
"description": "An interpreter. Type the content you want to translate and the object language like: Hi there => Spanish. Hava a try!",
"canvas_type": "chatbot",

+ 71
- 71
graph/test/dsl_examples/retrieval_relevant_rewrite_and_generate.json 查看文件

@@ -1,79 +1,79 @@
{
"components": {
"begin": {
"obj": {
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["retrieval:0"],
"upstream": ["begin", "generate:0"]
},
"retrieval:0": {
"obj": {
"component_name": "Retrieval",
"params": {
"similarity_threshold": 0.2,
"keywords_similarity_weight": 0.3,
"top_n": 6,
"top_k": 1024,
"rerank_id": "BAAI/bge-reranker-v2-m3",
"kb_ids": ["869a236818b811ef91dffa163e197198"],
"empty_response": "Sorry, knowledge base has noting related information."
}
},
"downstream": ["relevant:0"],
"upstream": ["answer:0", "rewrite:0"]
},
"relevant:0": {
"obj": {
"component_name": "Relevant",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.02,
"yes": "generate:0",
"no": "rewrite:0"
}
},
"downstream": ["generate:0", "rewrite:0"],
"upstream": ["retrieval:0"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please answer the question based on content of knowledge base. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\". Answers need to consider chat history.\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.",
"temperature": 0.02
}
},
"downstream": ["answer:0"],
"upstream": ["relevant:0"]
},
"rewrite:0": {
"obj": {
"component_name": "RewriteQuestion",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.8
}
},
"downstream": ["retrieval:0"],
"upstream": ["relevant:0"]
}
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["retrieval:0"],
"upstream": ["begin", "generate:0", "switch:0"]
},
"retrieval:0": {
"obj": {
"component_name": "Retrieval",
"params": {
"similarity_threshold": 0.2,
"keywords_similarity_weight": 0.3,
"top_n": 6,
"top_k": 1024,
"rerank_id": "BAAI/bge-reranker-v2-m3",
"kb_ids": ["869a236818b811ef91dffa163e197198"],
"empty_response": "Sorry, knowledge base has noting related information."
}
},
"downstream": ["relevant:0"],
"upstream": ["answer:0", "rewrite:0"]
},
"relevant:0": {
"obj": {
"component_name": "Relevant",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.02,
"yes": "generate:0",
"no": "rewrite:0"
}
},
"downstream": ["generate:0", "rewrite:0"],
"upstream": ["retrieval:0"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please answer the question based on content of knowledge base. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\". Answers need to consider chat history.\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.",
"temperature": 0.02
}
},
"downstream": ["answer:0"],
"upstream": ["relevant:0"]
},
"rewrite:0": {
"obj":{
"component_name": "RewriteQuestion",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.8
}
},
"downstream": ["retrieval:0"],
"upstream": ["relevant:0"]
}
},
"history": [],
"messages": [],
"path": [],
"reference": [],
"answer": []
}
}

+ 1
- 0
rag/llm/chat_model.py 查看文件

@@ -95,6 +95,7 @@ class DeepSeekChat(Base):
if not base_url: base_url="https://api.deepseek.com/v1"
super().__init__(key, model_name, base_url)


class AzureChat(Base):
def __init__(self, key, model_name, **kwargs):
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")

Loading…
取消
儲存