
add keyword extraction in graph (#1373)

### What problem does this PR solve?
#918 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.8.0
KevinHuSh, 1 year ago
Parent commit: 258c9ea644

api/apps/canvas_app.py (+7 -4)

@@ -142,16 +142,19 @@ def run():


@manager.route('/reset', methods=['POST'])
-@validate_request("canvas_id")
+@validate_request("id")
@login_required
def reset():
    req = request.json
    try:
-        user_canvas = UserCanvasService.get_by_id(req["canvas_id"])
-        canvas = Canvas(user_canvas.dsl, current_user.id)
+        e, user_canvas = UserCanvasService.get_by_id(req["id"])
+        if not e:
+            return server_error_response("canvas not found.")
+
+        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
        canvas.reset()
        req["dsl"] = json.loads(str(canvas))
-        UserCanvasService.update_by_id(req["canvas_id"], dsl=req["dsl"])
+        UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
        return get_json_result(data=req["dsl"])
    except Exception as e:
        return server_error_response(e)
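For callers, the request body now carries the canvas primary key as `id` rather than `canvas_id`, and a missing canvas is reported explicitly. A minimal client-side sketch, assuming the blueprint is served under `/v1/canvas` on the default port and that an authenticated session cookie satisfies `@login_required` (both assumptions, not shown in this diff):

```python
import requests

BASE_URL = "http://localhost:9380/v1/canvas"  # assumed host and route prefix
CANVAS_ID = "<your-canvas-id>"                # placeholder, not a real id

session = requests.Session()
# The session is expected to carry a valid login cookie for @login_required.

resp = session.post(f"{BASE_URL}/reset", json={"id": CANVAS_ID})
resp.raise_for_status()
print(resp.json()["data"])  # the reset DSL, returned via get_json_result(data=req["dsl"])
```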

api/db/init_data.py (+1 -1)

@@ -156,7 +156,7 @@ factory_infos = [{
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
"status": "1",
},{
"name": "Minimax",
"name": "MiniMax",
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",

graph/canvas.py (+13 -5)

@@ -102,19 +102,26 @@ class Canvas(ABC):
        self.load()

    def load(self):
-        assert self.dsl.get("components", {}).get("begin"), "There have to be a 'Begin' component."
-
        self.components = self.dsl["components"]
+        cpn_nms = set([])
+        for k, cpn in self.components.items():
+            cpn_nms.add(cpn["obj"]["component_name"])
+
+        assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
+        assert "Answer" in cpn_nms, "There have to be an 'Answer' component."
+
        for k, cpn in self.components.items():
+            cpn_nms.add(cpn["obj"]["component_name"])
            param = component_class(cpn["obj"]["component_name"] + "Param")()
            param.update(cpn["obj"]["params"])
            param.check()
            cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
            if cpn["obj"].component_name == "Categorize":
-                for _,desc in param.category_description.items():
+                for _, desc in param.category_description.items():
                    if desc["to"] not in cpn["downstream"]:
                        cpn["downstream"].append(desc["to"])
+

        self.path = self.dsl["path"]
        self.history = self.dsl["history"]
        self.messages = self.dsl["messages"]
@@ -140,7 +147,8 @@ class Canvas(ABC):
        self.messages = []
        self.answer = []
        self.reference = []
-        self.components = {}
+        for k, cpn in self.components.items():
+            self.components[k]["obj"].reset()
        self._embed_id = ""

    def run(self, **kwargs):
@@ -176,7 +184,7 @@ class Canvas(ABC):
            ran += 1

        prepare2run(self.components[self.path[-2][-1]]["downstream"])
-        while ran < len(self.path[-1]):
+        while 0 <= ran < len(self.path[-1]):
            if DEBUG: print(ran, self.path)
            cpn_id = self.path[-1][ran]
            cpn = self.get_component(cpn_id)
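With the stricter `load()` checks, a DSL must declare both a `Begin` and an `Answer` component before any component object is constructed, and `reset()` now reuses those objects instead of dropping them. A minimal sketch of a DSL skeleton that passes the new assertions; the params are illustrative and do not form a useful graph:

```python
import json

# Smallest component map that satisfies the new assertions in load():
# both "Begin" and "Answer" component names must be present.
minimal_dsl = {
    "components": {
        "begin": {
            "obj": {"component_name": "Begin", "params": {"prologue": "Hi there!"}},
            "downstream": ["answer:0"],
            "upstream": []
        },
        "answer:0": {
            "obj": {"component_name": "Answer", "params": {}},
            "downstream": [],
            "upstream": ["begin"]
        }
    },
    "history": [],
    "messages": [],
    "path": [],
    "reference": [],
    "answer": []
}

# As in canvas_app.py above, Canvas is constructed from the serialized DSL string:
dsl_str = json.dumps(minimal_dsl)
# canvas = Canvas(dsl_str, tenant_id)  # requires a valid tenant/user id
```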

graph/component/base.py (+3 -0)

@@ -418,6 +418,9 @@ class ComponentBase(ABC):
            o = pd.DataFrame(o)
        return self._param.output_var_name, o

+    def reset(self):
+        setattr(self._param, self._param.output_var_name, None)
+
    def set_output(self, v: pd.DataFrame):
        setattr(self._param, self._param.output_var_name, v)
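`reset()` on a component just clears its cached output variable, which is what lets `Canvas.reset()` above keep the already constructed component objects. A tiny self-contained sketch of the same `setattr` pattern, using hypothetical stand-in classes rather than the real `ComponentBase`:

```python
import pandas as pd

class DemoParam:
    # stand-in for a component Param; the real classes define output_var_name similarly
    output_var_name = "output"

class DemoComponent:
    def __init__(self):
        self._param = DemoParam()

    def set_output(self, v: pd.DataFrame):
        setattr(self._param, self._param.output_var_name, v)

    def reset(self):
        # same idea as ComponentBase.reset(): drop the cached output
        setattr(self._param, self._param.output_var_name, None)

cpn = DemoComponent()
cpn.set_output(pd.DataFrame([{"content": "hello"}]))
cpn.reset()
assert getattr(cpn._param, DemoParam.output_var_name) is None
```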


graph/component/keyword.py (+68 -0)

@@ -0,0 +1,68 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from abc import ABC
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from graph.component import GenerateParam, Generate
from graph.settings import DEBUG


class KeywordExtractParam(GenerateParam):

    """
    Define the KeywordExtract component parameters.
    """
    def __init__(self):
        super().__init__()
        self.temperature = 0.5
        self.prompt = ""
        self.topn = 1

    def check(self):
        super().check()

    def get_prompt(self):
        self.prompt = """
- Role: You're a question analyzer.
- Requirements:
  - Summarize user's question, and give top %s important keyword/phrase.
  - Use comma as a delimiter to separate keywords/phrases.
- Answer format: (in language of user's question)
  - keyword:
"""%self.topn
        return self.prompt


class KeywordExtract(Generate, ABC):
    component_name = "RewriteQuestion"

    def _run(self, history, **kwargs):
        q = ""
        for r, c in self._canvas.history[::-1]:
            if r == "user":
                q += c
                break

        chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
        ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": q}],
                            self._param.gen_conf())

        ans = re.sub(r".*keyword:", "", ans).strip()
        if DEBUG: print(ans, ":::::::::::::::::::::::::::::::::")
        return KeywordExtract.be_output(ans)
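The component asks the chat model for the top `topn` keywords and then strips everything up to the trailing `keyword:` label. A sketch of that post-processing step with a stubbed model reply (the reply text is invented; no LLM is called):

```python
import re

topn = 3
prompt = """
- Role: You're a question analyzer.
- Requirements:
  - Summarize user's question, and give top %s important keyword/phrase.
  - Use comma as a delimiter to separate keywords/phrases.
- Answer format: (in language of user's question)
  - keyword:
""" % topn

# A made-up reply in the requested format, standing in for chat_mdl.chat(...).
fake_llm_reply = "  - keyword: vector search, keyword extraction, RAG"

keywords = re.sub(r".*keyword:", "", fake_llm_reply).strip()
print(keywords)  # -> "vector search, keyword extraction, RAG"
```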



graph/templates/HR_callout_zh.json (+1 -1)

@@ -1,5 +1,5 @@
{
-    "id": 0,
+    "id": 1,
    "title": "HR call-out assistant(Chinese)",
    "description": "A HR call-out assistant. It will introduce the given job, answer the candidates' question about this job. And the most important thing is that it will try to obtain the contact information of the candidates. What you need to do is to link a knowledgebase which contains job description in 'Retrieval' component.",
    "canvas_type": "chatbot",

graph/templates/customer_service.json (+5 -5)

@@ -1,5 +1,5 @@
{
-    "id": 1,
+    "id": 2,
    "title": "Customer service",
    "description": "A call-in customer service chat bot. It will provide useful information about the products, answer customers' questions and soothe the customers' bad emotions.",
    "canvas_type": "chatbot",
@@ -106,7 +106,7 @@
            "upstream": ["categorize:0"]
        },
        "generate:complain": {
-            "downstream": [],
+            "downstream": ["answer:0"],
            "obj": {
                "component_name": "Generate",
                "params": {
@@ -116,7 +116,7 @@
                    "prompt": "You are a customer support. the Customers complain even curse about the products but not specific enough. You need to ask him/her what's the specific problem with the product. Be nice, patient and concern to soothe your customers’ emotions at first place."
                }
            },
-            "upstream": ["categorize:0", "answer:0"]
+            "upstream": ["categorize:0"]
        },
        "message:get_contact": {
            "downstream": ["answer:0"],
@@ -286,13 +286,13 @@
        {
            "id": "reactflow__edge-answer:0a-generate:complaind",
            "markerEnd": "logo",
-            "source": "answer:0",
+            "source": "generate:complain",
            "sourceHandle": "a",
            "style": {
                "stroke": "rgb(202 197 245)",
                "strokeWidth": 2
            },
-            "target": "generate:complain",
+            "target": "answer:0",
            "targetHandle": "d",
            "type": "buttonEdge"
        },
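This template fix reverses the direction of the `answer:0` / `generate:complain` edge and updates the component's `downstream`/`upstream` lists to match. A small consistency check in the same spirit, verifying that every `downstream` link is mirrored by an `upstream` entry; the helper name is hypothetical and `components` stands for a template's component map:

```python
def check_links(components: dict) -> list:
    """Report downstream links that are not mirrored by an upstream entry."""
    problems = []
    for name, cpn in components.items():
        for tgt in cpn.get("downstream", []):
            if name not in components.get(tgt, {}).get("upstream", []):
                problems.append(f"{name} -> {tgt} is not listed in {tgt}.upstream")
    return problems

# Usage sketch: load a template JSON and inspect its component map.
# import json
# tpl = json.load(open("graph/templates/customer_service.json"))
# print(check_links(tpl["dsl"]["components"]))  # the key path is an assumption
```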

graph/templates/general_chat_bot.json (+335 -0) (diff not shown: the file is too large to display)


graph/templates/interpreter.json (+1 -1)

@@ -1,5 +1,5 @@
{
-    "id": 2,
+    "id": 3,
    "title": "Interpreter",
    "description": "An interpreter. Type the content you want to translate and the object language like: Hi there => Spanish. Hava a try!",
    "canvas_type": "chatbot",

graph/test/dsl_examples/retrieval_relevant_rewrite_and_generate.json (+71 -71)

@@ -1,79 +1,79 @@
{
"components": {
"begin": {
"obj": {
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["retrieval:0"],
"upstream": ["begin", "generate:0"]
},
"retrieval:0": {
"obj": {
"component_name": "Retrieval",
"params": {
"similarity_threshold": 0.2,
"keywords_similarity_weight": 0.3,
"top_n": 6,
"top_k": 1024,
"rerank_id": "BAAI/bge-reranker-v2-m3",
"kb_ids": ["869a236818b811ef91dffa163e197198"],
"empty_response": "Sorry, knowledge base has noting related information."
}
},
"downstream": ["relevant:0"],
"upstream": ["answer:0", "rewrite:0"]
},
"relevant:0": {
"obj": {
"component_name": "Relevant",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.02,
"yes": "generate:0",
"no": "rewrite:0"
}
},
"downstream": ["generate:0", "rewrite:0"],
"upstream": ["retrieval:0"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please answer the question based on content of knowledge base. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\". Answers need to consider chat history.\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.",
"temperature": 0.02
}
},
"downstream": ["answer:0"],
"upstream": ["relevant:0"]
},
"rewrite:0": {
"obj": {
"component_name": "RewriteQuestion",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.8
}
},
"downstream": ["retrieval:0"],
"upstream": ["relevant:0"]
}
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["retrieval:0"],
"upstream": ["begin", "generate:0", "switch:0"]
},
"retrieval:0": {
"obj": {
"component_name": "Retrieval",
"params": {
"similarity_threshold": 0.2,
"keywords_similarity_weight": 0.3,
"top_n": 6,
"top_k": 1024,
"rerank_id": "BAAI/bge-reranker-v2-m3",
"kb_ids": ["869a236818b811ef91dffa163e197198"],
"empty_response": "Sorry, knowledge base has noting related information."
}
},
"downstream": ["relevant:0"],
"upstream": ["answer:0", "rewrite:0"]
},
"relevant:0": {
"obj": {
"component_name": "Relevant",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.02,
"yes": "generate:0",
"no": "rewrite:0"
}
},
"downstream": ["generate:0", "rewrite:0"],
"upstream": ["retrieval:0"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please answer the question based on content of knowledge base. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\". Answers need to consider chat history.\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.",
"temperature": 0.02
}
},
"downstream": ["answer:0"],
"upstream": ["relevant:0"]
},
"rewrite:0": {
"obj":{
"component_name": "RewriteQuestion",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.8
}
},
"downstream": ["retrieval:0"],
"upstream": ["relevant:0"]
}
},
"history": [],
"messages": [],
"path": [],
"reference": [],
"answer": []
}
}
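For completeness, a DSL file like the one above can be driven directly: build a `Canvas` from the serialized JSON and a tenant/user id (the same constructor call as in canvas_app.py), reset it, and run. A minimal sketch, assuming a valid tenant id is available and leaving aside the exact shape of `run()`'s return value, which this diff does not show:

```python
from graph.canvas import Canvas

DSL_PATH = "graph/test/dsl_examples/retrieval_relevant_rewrite_and_generate.json"
TENANT_ID = "<tenant_id>"  # placeholder; a real tenant/user id is required

with open(DSL_PATH) as f:
    canvas = Canvas(f.read(), TENANT_ID)  # Canvas takes the serialized DSL string

canvas.reset()
answer = canvas.run()
print(answer)
```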

rag/llm/chat_model.py (+1 -0)

@@ -95,6 +95,7 @@ class DeepSeekChat(Base):
        if not base_url: base_url="https://api.deepseek.com/v1"
        super().__init__(key, model_name, base_url)

+
class AzureChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
