
search between multiple indices for team function (#3079)

### What problem does this PR solve?

#2834 
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.13.0
Kevin Hu 1 year ago
commit 2d1fbefdb5

agent/component/__init__.py (+1, -0)

from .tushare import TuShare, TuShareParam
from .akshare import AkShare, AkShareParam
from .crawler import Crawler, CrawlerParam
from .invoke import Invoke, InvokeParam




def component_class(class_name):
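Registering `Invoke` in `agent/component/__init__.py` is what makes the new component discoverable to the canvas. A minimal sketch of the lookup, assuming `component_class` simply resolves names from this module's namespace (its actual body is not shown in the hunk):

```python
# Hypothetical simplification of component_class: resolve a component class
# (e.g. "Invoke") from the names imported into agent/component/__init__.py.
def component_class(class_name):
    return globals()[class_name]

# component_class("Invoke") would then return the Invoke class imported above.
```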

agent/component/generate.py (+9, -5)

from functools import partial
import pandas as pd
from api.db import LLMType
from api.db.services.dialog_service import message_fit_in
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from agent.component.base import ComponentBase, ComponentParamBase


kwargs["input"] = input
for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), str(v), prompt)
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)


downstreams = self._canvas.get_component(self._id)["downstream"]
if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
retrieval_res["empty_response"]) else "Nothing found in knowledgebase!", "reference": []}
return pd.DataFrame([res])


ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
self._param.gen_conf())
msg = self._canvas.get_history(self._param.message_history_window_size)
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())

if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
res = self.set_cite(retrieval_res, ans)
return pd.DataFrame([res])
self.set_output(res)
return


msg = self._canvas.get_history(self._param.message_history_window_size)
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
answer = ""
for ans in chat_mdl.chat_streamly(prompt, self._canvas.get_history(self._param.message_history_window_size),
self._param.gen_conf()):
for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
res = {"content": ans, "reference": []}
answer = ans
yield res
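The generate.py change stops passing the raw history straight to `chat_mdl.chat`: the system prompt plus the windowed history is first squeezed into roughly 97% of the model's context window with `message_fit_in`. A rough sketch of the contract that call relies on (the real implementation lives in `api/db/services/dialog_service.py`; the token counting below is a placeholder):

```python
# Illustrative stand-in for message_fit_in, not the real implementation:
# keep the system message, drop the oldest turns until the conversation fits
# the given token budget, and return (token_count, trimmed_messages).
def message_fit_in_sketch(msgs, max_length, count_tokens=lambda m: len(m["content"]) // 4):
    system, rest = msgs[0], list(msgs[1:])
    total = lambda: count_tokens(system) + sum(count_tokens(m) for m in rest)
    while rest and total() > max_length:
        rest.pop(0)  # the oldest non-system message goes first
    return total(), [system, *rest]
```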

agent/component/invoke.py (+22, -3)

# limitations under the License.
#
import json
import re
from abc import ABC

import requests
from deepdoc.parser import HtmlParser
from agent.component.base import ComponentBase, ComponentParamBase




self.variables = []
self.url = ""
self.timeout = 60
self.clean_html = False


def check(self):
self.check_valid_value(self.method.lower(), "Type of content from the crawler", ['get', 'post', 'put'])
self.check_empty(self.url, "End point URL")
self.check_positive_integer(self.timeout, "Timeout time in second")
self.check_boolean(self.clean_html, "Clean HTML")




class Invoke(ComponentBase, ABC):
if self._param.headers:
headers = json.loads(self._param.headers)
proxies = None
if self._param.proxy:
if re.sub(r"https?:?/?/?", "", self._param.proxy):
proxies = {"http": self._param.proxy, "https": self._param.proxy}


if method == 'get':
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
if self._param.clean_html:
sections = HtmlParser()(None, response.content)
return Invoke.be_output("\n".join(sections))

return Invoke.be_output(response.text)


if method == 'put':
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
if self._param.clean_html:
sections = HtmlParser()(None, response.content)
return Invoke.be_output("\n".join(sections))
return Invoke.be_output(response.text)


if method == 'post':
response = requests.post(url=url,
json=args,
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
if self._param.clean_html:
sections = HtmlParser()(None, response.content)
return Invoke.be_output("\n".join(sections))
return Invoke.be_output(response.text)
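`clean_html` gives the Invoke component a plain-text view of HTML endpoints: instead of returning `response.text`, the response body is run through deepdoc's `HtmlParser` and the resulting sections are joined with newlines. A hedged stand-alone sketch of that path (the URL and defaults here are illustrative, not from the PR):

```python
import requests
from deepdoc.parser import HtmlParser

def fetch_as_text(url, clean_html=True, timeout=60):
    # Mirrors the branch added for each HTTP method in invoke.py.
    response = requests.get(url=url, timeout=timeout)
    if clean_html:
        # HtmlParser is called as (filename, binary); the diff passes None
        # and the raw response bytes.
        sections = HtmlParser()(None, response.content)
        return "\n".join(sections)
    return response.text
```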

api/db/services/dialog_service.py (+3, -1)

else:
if prompt_config.get("keyword", False):
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,

tenant_ids = list(set([kb.tenant_id for kb in kbs]))
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=attachments,
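This is the dialog-side half of the team feature: a chat can now reference knowledge bases owned by different tenants, so retrieval receives every distinct owner rather than only `dialog.tenant_id`. Put together, the new call looks roughly like this (names are taken from the hunk; `kbs` is assumed to be the already-loaded knowledge-base records, and trailing arguments are omitted):

```python
# Collect the distinct owners of the selected knowledge bases and hand them all
# to retrieval, which fans out to one index per tenant (see rag/nlp/search.py).
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids,
                         1, dialog.top_n,
                         dialog.similarity_threshold,
                         dialog.vector_similarity_weight,
                         doc_ids=attachments)
```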

deepdoc/parser/html_parser.py (+2, -0)

import html_text
import chardet



def get_encoding(file):
with open(file,'rb') as f:
tmp = chardet.detect(f.read())
return tmp['encoding']



class RAGFlowHtmlParser:
def __call__(self, fnm, binary=None):
txt = ""

rag/nlp/search.py (+12, -6)

Q("bool", must_not=Q("range", available_int={"lt": 1})))
return bqry


def search(self, req, idxnm, emb_mdl=None, highlight=False):
def search(self, req, idxnms, emb_mdl=None, highlight=False):
qst = req.get("question", "")
bqry, keywords = self.qryr.question(qst, min_match="30%")
bqry = self._add_filters(bqry, req)
del s["highlight"]
q_vec = s["knn"]["query_vector"]
es_logger.info("【Q】: {}".format(json.dumps(s)))
res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
res = self.es.search(deepcopy(s), idxnms=idxnms, timeout="600s", src=src)
es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
if self.es.getTotal(res) == 0 and "knn" in s:
bqry, _ = self.qryr.question(qst, min_match="10%")
s["query"] = bqry.to_dict()
s["knn"]["filter"] = bqry.to_dict()
s["knn"]["similarity"] = 0.17
res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
res = self.es.search(s, idxnms=idxnms, timeout="600s", src=src)
es_logger.info("【Q】: {}".format(json.dumps(s)))


kwds = set([])
rag_tokenizer.tokenize(ans).split(" "),
rag_tokenizer.tokenize(inst).split(" "))


def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False):
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
if not question:
return ranks

RERANK_PAGE_LIMIT = 3
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size*RERANK_PAGE_LIMIT, 128),
"question": question, "vector": True, "topk": top,
"similarity": similarity_threshold,
"available_int": 1}

if page > RERANK_PAGE_LIMIT:
req["page"] = page
req["size"] = page_size
sres = self.search(req, index_name(tenant_id), embd_mdl, highlight)

if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",")

sres = self.search(req, [index_name(tid) for tid in tenant_ids], embd_mdl, highlight)
ranks["total"] = sres.total


if page <= RERANK_PAGE_LIMIT:
s = Search()
s = s.query(Q("match", doc_id=doc_id))[0:max_count]
s = s.to_dict()
es_res = self.es.search(s, idxnm=index_name(tenant_id), timeout="600s", src=fields)
es_res = self.es.search(s, idxnms=index_name(tenant_id), timeout="600s", src=fields)
res = []
for index, chunk in enumerate(es_res['hits']['hits']):
res.append({fld: chunk['_source'].get(fld) for fld in fields})
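The core of the multi-index change: `retrieval` now accepts either a list of tenant IDs or a comma-separated string, and maps each tenant to its own Elasticsearch index before calling `search`. A small sketch of that fan-out; the `ragflow_<tenant_id>` naming in `index_name` is an assumption for illustration, since its real definition is not part of this hunk:

```python
def index_name(tenant_id):
    # Assumed per-tenant index naming scheme, for illustration only.
    return f"ragflow_{tenant_id}"

def to_index_names(tenant_ids):
    # Accept "t1,t2" or ["t1", "t2"], as the new retrieval() code does.
    if isinstance(tenant_ids, str):
        tenant_ids = tenant_ids.split(",")
    return [index_name(tid) for tid in tenant_ids]

# to_index_names("t1,t2") -> ["ragflow_t1", "ragflow_t2"]
```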

rag/utils/es_conn.py (+4, -2)



return False


def search(self, q, idxnm=None, src=False, timeout="2s"):
def search(self, q, idxnms=None, src=False, timeout="2s"):
if not isinstance(q, dict):
q = Search().query(q).to_dict()
if isinstance(idxnms, str):
idxnms = idxnms.split(",")
for i in range(3):
try:
res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
res = self.es.search(index=(self.idxnm if not idxnms else idxnms),
body=q,
timeout=timeout,
# search_type="dfs_query_then_fetch",
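Renaming `idxnm` to `idxnms` works because the Elasticsearch Python client accepts a list of index names for `index=`, so one query can span every tenant's index. A hedged usage sketch (host and index names are illustrative; the repo's client setup is not shown in this hunk):

```python
from elasticsearch import Elasticsearch

es = Elasticsearch(hosts=["http://localhost:9200"])
q = {"query": {"match_all": {}}}
# Passing a list searches both tenants' indices in a single request,
# which is what the new idxnms parameter enables.
res = es.search(index=["ragflow_t1", "ragflow_t2"], body=q, timeout="2s")
print(res["hits"]["hits"][:3])
```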
