| if not llm: | if not llm: | ||||
| raise LookupError("LLM(%s) not found" % dialog.llm_id) | raise LookupError("LLM(%s) not found" % dialog.llm_id) | ||||
| llm = llm[0] | llm = llm[0] | ||||
| question = messages[-1]["content"] | |||||
| questions = [m["content"] for m in messages if m["role"] == "user"] | |||||
| embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING) | embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING) | ||||
| chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) | chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) | ||||
| ## try to use sql if field mapping is good to go | ## try to use sql if field mapping is good to go | ||||
| if field_map: | if field_map: | ||||
| stat_logger.info("Use SQL to retrieval.") | stat_logger.info("Use SQL to retrieval.") | ||||
| markdown_tbl, chunks = use_sql(question, field_map, dialog.tenant_id, chat_mdl) | |||||
| markdown_tbl, chunks = use_sql("\n".join(questions), field_map, dialog.tenant_id, chat_mdl) | |||||
| if markdown_tbl: | if markdown_tbl: | ||||
| return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}} | return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}} | ||||
| if p["key"] not in kwargs: | if p["key"] not in kwargs: | ||||
| prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ") | prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ") | ||||
| kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, | |||||
| for _ in range(len(questions)//2): | |||||
| questions.append(questions[-1]) | |||||
| kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, | |||||
| dialog.similarity_threshold, | dialog.similarity_threshold, | ||||
| dialog.vector_similarity_weight, top=1024, aggs=False) | dialog.vector_similarity_weight, top=1024, aggs=False) | ||||
| knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] | knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] | ||||
| def use_sql(question, field_map, tenant_id, chat_mdl): | def use_sql(question, field_map, tenant_id, chat_mdl): | ||||
| sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据我的问题写出sql。" | |||||
| sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。" | |||||
| user_promt = """ | user_promt = """ | ||||
| 表名:{}; | 表名:{}; | ||||
| 数据库表字段说明如下: | 数据库表字段说明如下: | ||||
| {} | {} | ||||
| 问题:{} | |||||
| 问题如下: | |||||
| {} | |||||
| 请写出SQL,且只要SQL,不要有其他说明及文字。 | 请写出SQL,且只要SQL,不要有其他说明及文字。 | ||||
| """.format( | """.format( | ||||
| index_name(tenant_id), | index_name(tenant_id), |
| if len(users) > 1: raise Exception('Same E-mail exist!') | if len(users) > 1: raise Exception('Same E-mail exist!') | ||||
| user = users[0] | user = users[0] | ||||
| login_user(user) | login_user(user) | ||||
| return redirect("/?auth=%s"%user.get_id()) | |||||
| except Exception as e: | except Exception as e: | ||||
| rollback_user_registration(user_id) | rollback_user_registration(user_id) | ||||
| stat_logger.exception(e) | stat_logger.exception(e) | ||||
| return redirect("/?error=%s"%str(e)) | return redirect("/?error=%s"%str(e)) | ||||
| return redirect("/?auth=%s"%user_id) | |||||
| user = users[0] | |||||
| login_user(user) | |||||
| return redirect("/?auth=%s" % user.get_id()) | |||||
| def user_info_from_github(access_token): | def user_info_from_github(access_token): |
| images, outputs = init_in_out(args) | images, outputs = init_in_out(args) | ||||
| if args.mode.lower() == "layout": | if args.mode.lower() == "layout": | ||||
| labels = LayoutRecognizer.labels | labels = LayoutRecognizer.labels | ||||
| detr = Recognizer(labels, "layout.paper", os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) | |||||
| detr = Recognizer(labels, "layout", os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) | |||||
| if args.mode.lower() == "tsr": | if args.mode.lower() == "tsr": | ||||
| labels = TableStructureRecognizer.labels | labels = TableStructureRecognizer.labels | ||||
| detr = TableStructureRecognizer() | detr = TableStructureRecognizer() |
| return res | return res | ||||
| def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs): | |||||
| def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): | |||||
| """ | """ | ||||
| The supported file formats are pdf, pptx. | The supported file formats are pdf, pptx. | ||||
| Every page will be treated as a chunk. And the thumbnail of every page will be stored. | Every page will be treated as a chunk. And the thumbnail of every page will be stored. | ||||
| PPT file will be parsed by using this method automatically, setting-up for every PPT file is not necessary. | PPT file will be parsed by using this method automatically, setting-up for every PPT file is not necessary. | ||||
| """ | """ | ||||
| eng = lang.lower() == "english" | |||||
| doc = { | doc = { | ||||
| "docnm_kwd": filename, | "docnm_kwd": filename, | ||||
| "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename)) | "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename)) | ||||
| for pn, (txt,img) in enumerate(pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)): | for pn, (txt,img) in enumerate(pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)): | ||||
| d = copy.deepcopy(doc) | d = copy.deepcopy(doc) | ||||
| d["image"] = img | d["image"] = img | ||||
| d["page_num_obj"] = [pn+1] | |||||
| tokenize(d, txt, pdf_parser.is_english) | |||||
| d["page_num_int"] = [pn+1] | |||||
| d["top_int"] = [0] | |||||
| d["position_int"].append((pn + 1, 0, img.size[0], 0, img.size[1])) | |||||
| tokenize(d, txt, eng) | |||||
| res.append(d) | res.append(d) | ||||
| return res | return res | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # | # | ||||
| from abc import ABC | from abc import ABC | ||||
| from copy import deepcopy | |||||
| from openai import OpenAI | from openai import OpenAI | ||||
| import openai | import openai | ||||
| from rag.nlp import is_english | |||||
| class Base(ABC): | class Base(ABC): | ||||
| def __init__(self, key, model_name): | def __init__(self, key, model_name): | ||||
| def chat(self, system, history, gen_conf): | def chat(self, system, history, gen_conf): | ||||
| if system: history.insert(0, {"role": "system", "content": system}) | if system: history.insert(0, {"role": "system", "content": system}) | ||||
| try: | try: | ||||
| res = self.client.chat.completions.create( | |||||
| response = self.client.chat.completions.create( | |||||
| model=self.model_name, | model=self.model_name, | ||||
| messages=history, | messages=history, | ||||
| **gen_conf) | **gen_conf) | ||||
| return res.choices[0].message.content.strip(), res.usage.completion_tokens | |||||
| ans = response.output.choices[0]['message']['content'].strip() | |||||
| if response.output.choices[0].get("finish_reason", "") == "length": | |||||
| ans += "...\nFor the content length reason, it stopped, continue?" if is_english( | |||||
| [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" | |||||
| return ans, response.usage.completion_tokens | |||||
| except openai.APIError as e: | except openai.APIError as e: | ||||
| return "ERROR: "+str(e), 0 | |||||
| return "**ERROR**: "+str(e), 0 | |||||
| from dashscope import Generation | from dashscope import Generation | ||||
| result_format='message', | result_format='message', | ||||
| **gen_conf | **gen_conf | ||||
| ) | ) | ||||
| ans = "" | |||||
| tk_count = 0 | |||||
| if response.status_code == HTTPStatus.OK: | if response.status_code == HTTPStatus.OK: | ||||
| return response.output.choices[0]['message']['content'], response.usage.output_tokens | |||||
| return "ERROR: " + response.message, 0 | |||||
| ans += response.output.choices[0]['message']['content'] | |||||
| tk_count += response.usage.output_tokens | |||||
| if response.output.choices[0].get("finish_reason", "") == "length": | |||||
| ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" | |||||
| return ans, tk_count | |||||
| return "**ERROR**: " + response.message, tk_count | |||||
| from zhipuai import ZhipuAI | from zhipuai import ZhipuAI | ||||
| def chat(self, system, history, gen_conf): | def chat(self, system, history, gen_conf): | ||||
| from http import HTTPStatus | from http import HTTPStatus | ||||
| if system: history.insert(0, {"role": "system", "content": system}) | if system: history.insert(0, {"role": "system", "content": system}) | ||||
| response = self.client.chat.completions.create( | |||||
| self.model_name, | |||||
| messages=history, | |||||
| **gen_conf | |||||
| ) | |||||
| if response.status_code == HTTPStatus.OK: | |||||
| return response.output.choices[0]['message']['content'], response.usage.completion_tokens | |||||
| return "ERROR: " + response.message, 0 | |||||
| try: | |||||
| response = self.client.chat.completions.create( | |||||
| self.model_name, | |||||
| messages=history, | |||||
| **gen_conf | |||||
| ) | |||||
| ans = response.output.choices[0]['message']['content'].strip() | |||||
| if response.output.choices[0].get("finish_reason", "") == "length": | |||||
| ans += "...\nFor the content length reason, it stopped, continue?" if is_english( | |||||
| [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" | |||||
| return ans, response.usage.completion_tokens | |||||
| except Exception as e: | |||||
| return "**ERROR**: " + str(e), 0 |
| chunks_tks, | chunks_tks, | ||||
| tkweight, vtweight) | tkweight, vtweight) | ||||
| mx = np.max(sim) * 0.99 | mx = np.max(sim) * 0.99 | ||||
| if mx < 0.35: | |||||
| if mx < 0.66: | |||||
| continue | continue | ||||
| cites[idx[i]] = list( | cites[idx[i]] = list( | ||||
| set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4] | set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4] | ||||
| res = "" | res = "" | ||||
| seted = set([]) | |||||
| for i, p in enumerate(pieces): | for i, p in enumerate(pieces): | ||||
| res += p | res += p | ||||
| if i not in idx: | if i not in idx: | ||||
| if i not in cites: | if i not in cites: | ||||
| continue | continue | ||||
| for c in cites[i]: assert int(c) < len(chunk_v) | for c in cites[i]: assert int(c) < len(chunk_v) | ||||
| for c in cites[i]: res += f" ##{c}$$" | |||||
| for c in cites[i]: | |||||
| if c in seted:continue | |||||
| res += f" ##{c}$$" | |||||
| seted.add(c) | |||||
| return res | return res | ||||
| if dnm not in ranks["doc_aggs"]: | if dnm not in ranks["doc_aggs"]: | ||||
| ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0} | ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0} | ||||
| ranks["doc_aggs"][dnm]["count"] += 1 | ranks["doc_aggs"][dnm]["count"] += 1 | ||||
| ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k,v in sorted(ranks["doc_aggs"].items(), key=lambda x:x[1]["count"]*-1)] | |||||
| ranks["doc_aggs"] = []#[{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k,v in sorted(ranks["doc_aggs"].items(), key=lambda x:x[1]["count"]*-1)] | |||||
| return ranks | return ranks | ||||