Selaa lähdekoodia

add keyword extraction in graph (#1373)

### What problem does this PR solve?
#918 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.8.0
KevinHuSh 1 vuosi sitten
vanhempi
commit
258c9ea644
No account linked to committer's email address

+ 7
- 4
api/apps/canvas_app.py Näytä tiedosto

@@ -142,16 +142,19 @@ def run():


@manager.route('/reset', methods=['POST'])
@validate_request("canvas_id")
@validate_request("id")
@login_required
def reset():
req = request.json
try:
user_canvas = UserCanvasService.get_by_id(req["canvas_id"])
canvas = Canvas(user_canvas.dsl, current_user.id)
e, user_canvas = UserCanvasService.get_by_id(req["id"])
if not e:
return server_error_response("canvas not found.")

canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
canvas.reset()
req["dsl"] = json.loads(str(canvas))
UserCanvasService.update_by_id(req["canvas_id"], dsl=req["dsl"])
UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
return get_json_result(data=req["dsl"])
except Exception as e:
return server_error_response(e)

+ 1
- 1
api/db/init_data.py Näytä tiedosto

@@ -156,7 +156,7 @@ factory_infos = [{
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
"status": "1",
},{
"name": "Minimax",
"name": "MiniMax",
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",

+ 13
- 5
graph/canvas.py Näytä tiedosto

@@ -102,19 +102,26 @@ class Canvas(ABC):
self.load()

def load(self):
assert self.dsl.get("components", {}).get("begin"), "There have to be a 'Begin' component."

self.components = self.dsl["components"]
cpn_nms = set([])
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])

assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
assert "Answer" in cpn_nms, "There have to be an 'Answer' component."

for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
param = component_class(cpn["obj"]["component_name"] + "Param")()
param.update(cpn["obj"]["params"])
param.check()
cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
if cpn["obj"].component_name == "Categorize":
for _,desc in param.category_description.items():
for _, desc in param.category_description.items():
if desc["to"] not in cpn["downstream"]:
cpn["downstream"].append(desc["to"])


self.path = self.dsl["path"]
self.history = self.dsl["history"]
self.messages = self.dsl["messages"]
@@ -140,7 +147,8 @@ class Canvas(ABC):
self.messages = []
self.answer = []
self.reference = []
self.components = {}
for k, cpn in self.components.items():
self.components[k]["obj"].reset()
self._embed_id = ""

def run(self, **kwargs):
@@ -176,7 +184,7 @@ class Canvas(ABC):
ran += 1

prepare2run(self.components[self.path[-2][-1]]["downstream"])
while ran < len(self.path[-1]):
while 0 <= ran < len(self.path[-1]):
if DEBUG: print(ran, self.path)
cpn_id = self.path[-1][ran]
cpn = self.get_component(cpn_id)

+ 3
- 0
graph/component/base.py Näytä tiedosto

@@ -418,6 +418,9 @@ class ComponentBase(ABC):
o = pd.DataFrame(o)
return self._param.output_var_name, o

def reset(self):
setattr(self._param, self._param.output_var_name, None)

def set_output(self, v: pd.DataFrame):
setattr(self._param, self._param.output_var_name, v)


+ 68
- 0
graph/component/keyword.py Näytä tiedosto

@@ -0,0 +1,68 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from abc import ABC
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from graph.component import GenerateParam, Generate
from graph.settings import DEBUG


class KeywordExtractParam(GenerateParam):

"""
Define the KeywordExtract component parameters.
"""
def __init__(self):
super().__init__()
self.temperature = 0.5
self.prompt = ""
self.topn = 1

def check(self):
super().check()

def get_prompt(self):
self.prompt = """
- Role: You're a question analyzer.
- Requirements:
- Summarize user's question, and give top %s important keyword/phrase.
- Use comma as a delimiter to separate keywords/phrases.
- Answer format: (in language of user's question)
- keyword:
"""%self.topn
return self.prompt


class KeywordExtract(Generate, ABC):
component_name = "RewriteQuestion"

def _run(self, history, **kwargs):
q = ""
for r, c in self._canvas.history[::-1]:
if r == "user":
q += c
break

chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": q}],
self._param.gen_conf())

ans = re.sub(r".*keyword:", "", ans).strip()
if DEBUG: print(ans, ":::::::::::::::::::::::::::::::::")
return KeywordExtract.be_output(ans)



+ 1
- 1
graph/templates/HR_callout_zh.json Näytä tiedosto

@@ -1,5 +1,5 @@
{
"id": 0,
"id": 1,
"title": "HR call-out assistant(Chinese)",
"description": "A HR call-out assistant. It will introduce the given job, answer the candidates' question about this job. And the most important thing is that it will try to obtain the contact information of the candidates. What you need to do is to link a knowledgebase which contains job description in 'Retrieval' component.",
"canvas_type": "chatbot",

+ 5
- 5
graph/templates/customer_service.json Näytä tiedosto

@@ -1,5 +1,5 @@
{
"id": 1,
"id": 2,
"title": "Customer service",
"description": "A call-in customer service chat bot. It will provide useful information about the products, answer customers' questions and soothe the customers' bad emotions.",
"canvas_type": "chatbot",
@@ -106,7 +106,7 @@
"upstream": ["categorize:0"]
},
"generate:complain": {
"downstream": [],
"downstream": ["answer:0"],
"obj": {
"component_name": "Generate",
"params": {
@@ -116,7 +116,7 @@
"prompt": "You are a customer support. the Customers complain even curse about the products but not specific enough. You need to ask him/her what's the specific problem with the product. Be nice, patient and concern to soothe your customers’ emotions at first place."
}
},
"upstream": ["categorize:0", "answer:0"]
"upstream": ["categorize:0"]
},
"message:get_contact": {
"downstream": ["answer:0"],
@@ -286,13 +286,13 @@
{
"id": "reactflow__edge-answer:0a-generate:complaind",
"markerEnd": "logo",
"source": "answer:0",
"source": "generate:complain",
"sourceHandle": "a",
"style": {
"stroke": "rgb(202 197 245)",
"strokeWidth": 2
},
"target": "generate:complain",
"target": "answer:0",
"targetHandle": "d",
"type": "buttonEdge"
},

+ 335
- 0
graph/templates/general_chat_bot.json
File diff suppressed because it is too large
Näytä tiedosto


+ 1
- 1
graph/templates/interpreter.json Näytä tiedosto

@@ -1,5 +1,5 @@
{
"id": 2,
"id": 3,
"title": "Interpreter",
"description": "An interpreter. Type the content you want to translate and the object language like: Hi there => Spanish. Hava a try!",
"canvas_type": "chatbot",

+ 71
- 71
graph/test/dsl_examples/retrieval_relevant_rewrite_and_generate.json Näytä tiedosto

@@ -1,79 +1,79 @@
{
"components": {
"begin": {
"obj": {
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["retrieval:0"],
"upstream": ["begin", "generate:0"]
},
"retrieval:0": {
"obj": {
"component_name": "Retrieval",
"params": {
"similarity_threshold": 0.2,
"keywords_similarity_weight": 0.3,
"top_n": 6,
"top_k": 1024,
"rerank_id": "BAAI/bge-reranker-v2-m3",
"kb_ids": ["869a236818b811ef91dffa163e197198"],
"empty_response": "Sorry, knowledge base has noting related information."
}
},
"downstream": ["relevant:0"],
"upstream": ["answer:0", "rewrite:0"]
},
"relevant:0": {
"obj": {
"component_name": "Relevant",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.02,
"yes": "generate:0",
"no": "rewrite:0"
}
},
"downstream": ["generate:0", "rewrite:0"],
"upstream": ["retrieval:0"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please answer the question based on content of knowledge base. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\". Answers need to consider chat history.\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.",
"temperature": 0.02
}
},
"downstream": ["answer:0"],
"upstream": ["relevant:0"]
},
"rewrite:0": {
"obj": {
"component_name": "RewriteQuestion",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.8
}
},
"downstream": ["retrieval:0"],
"upstream": ["relevant:0"]
}
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["retrieval:0"],
"upstream": ["begin", "generate:0", "switch:0"]
},
"retrieval:0": {
"obj": {
"component_name": "Retrieval",
"params": {
"similarity_threshold": 0.2,
"keywords_similarity_weight": 0.3,
"top_n": 6,
"top_k": 1024,
"rerank_id": "BAAI/bge-reranker-v2-m3",
"kb_ids": ["869a236818b811ef91dffa163e197198"],
"empty_response": "Sorry, knowledge base has noting related information."
}
},
"downstream": ["relevant:0"],
"upstream": ["answer:0", "rewrite:0"]
},
"relevant:0": {
"obj": {
"component_name": "Relevant",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.02,
"yes": "generate:0",
"no": "rewrite:0"
}
},
"downstream": ["generate:0", "rewrite:0"],
"upstream": ["retrieval:0"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please answer the question based on content of knowledge base. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\". Answers need to consider chat history.\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.",
"temperature": 0.02
}
},
"downstream": ["answer:0"],
"upstream": ["relevant:0"]
},
"rewrite:0": {
"obj":{
"component_name": "RewriteQuestion",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.8
}
},
"downstream": ["retrieval:0"],
"upstream": ["relevant:0"]
}
},
"history": [],
"messages": [],
"path": [],
"reference": [],
"answer": []
}
}

+ 1
- 0
rag/llm/chat_model.py Näytä tiedosto

@@ -95,6 +95,7 @@ class DeepSeekChat(Base):
if not base_url: base_url="https://api.deepseek.com/v1"
super().__init__(key, model_name, base_url)


class AzureChat(Base):
def __init__(self, key, model_name, **kwargs):
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")

Loading…
Peruuta
Tallenna