### What problem does this PR solve?

#1069

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
```python
        else: conv.reference[-1] = ans["reference"]
        conv.message[-1] = {"role": "assistant", "content": ans["answer"]}

    def rename_field(ans):
        # expose 'doc_name' to API consumers instead of the internal 'docnm_kwd' key
        for chunk_i in ans['reference'].get('chunks', []):
            chunk_i['doc_name'] = chunk_i['docnm_kwd']
            chunk_i.pop('docnm_kwd')

    def stream():
        nonlocal dia, msg, req, conv
        try:
            for ans in chat(dia, msg, True, **req):
                fillin_conv(ans)
                rename_field(ans)
                yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
            API4ConversationService.append_message(conv.id, conv.to_dict())
        except Exception as e:
            yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
                                        "content": ""},
                                       ensure_ascii=False) + "\n\n"

    # ... non-streaming branch; `data` is a pre-built list whose first item holds the answer text ...
    ans = ""
    for a in chat(dia, msg, stream=False, **req):
        ans = a
        break
    data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
    fillin_conv(ans)
    API4ConversationService.append_message(conv.id, conv.to_dict())
    # citation markers look like ##3$$; the digit is the referenced chunk index
    chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
    for chunk_idx in chunk_idxs[:1]:
        if ans["reference"]["chunks"][chunk_idx]["img_id"]:
            try:
                bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
                response = MINIO.get(bkt, nm)
                data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
                data.append(data_type_picture)
            except Exception as e:
                return server_error_response(e)

    response = {"code": 200, "msg": "success", "data": data}
    return response
```
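The streaming branch emits server-sent-event frames of the form `data:{json}\n\n`, with chunk references already renamed by `rename_field`. A minimal client sketch for consuming it; the URL path and bearer-auth header are assumptions, not part of this diff:

```python
# Client-side sketch. `base_url`, the /api/completion path, and bearer auth are
# hypothetical -- adjust to however this endpoint is actually mounted.
import json
import requests

def stream_completion(base_url, token, conversation_id, messages):
    resp = requests.post(
        f"{base_url}/api/completion",
        headers={"Authorization": f"Bearer {token}"},
        json={"conversation_id": conversation_id, "messages": messages, "stream": True},
        stream=True,
    )
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data:"):
            continue  # skip blank keep-alive lines between SSE frames
        payload = json.loads(raw[len("data:"):])
        if payload["retcode"] != 0:
            raise RuntimeError(payload["retmsg"])
        ans = payload["data"]
        # reference chunks now carry 'doc_name' (renamed from 'docnm_kwd')
        yield ans["answer"], ans.get("reference", {})
```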
```python
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json

from flask import request
from flask_login import login_required, current_user

from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request
from graph.canvas import Canvas


@manager.route('/templates', methods=['GET'])
@login_required
def templates():
    return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()])


@manager.route('/list', methods=['GET'])
@login_required
def canvas_list():
    return get_json_result(data=[c.to_dict() for c in UserCanvasService.query(user_id=current_user.id)])


@manager.route('/rm', methods=['POST'])
@validate_request("canvas_ids")
@login_required
def rm():
    for i in request.json["canvas_ids"]:
        UserCanvasService.delete_by_id(i)
    return get_json_result(data=True)


@manager.route('/set', methods=['POST'])
@validate_request("dsl", "title")
@login_required
def save():
    req = request.json
    req["user_id"] = current_user.id
    if not isinstance(req["dsl"], str):
        req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
    # validate the DSL by constructing a Canvas from it before persisting
    try:
        Canvas(req["dsl"])
    except Exception as e:
        return server_error_response(e)
    req["dsl"] = json.loads(req["dsl"])
    if "id" not in req:
        req["id"] = get_uuid()
        if not UserCanvasService.save(**req):
            return server_error_response("Fail to save canvas.")
    else:
        UserCanvasService.update_by_id(req["id"], req)
    return get_json_result(data=req)


@manager.route('/get/<canvas_id>', methods=['GET'])
@login_required
def get(canvas_id):
    e, c = UserCanvasService.get_by_id(canvas_id)
    if not e:
        return server_error_response("canvas not found.")
    return get_json_result(data=c.to_dict())


@manager.route('/run', methods=['POST'])
@validate_request("id", "dsl")
@login_required
def run():
    req = request.json
    if not isinstance(req["dsl"], str):
        req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
    try:
        canvas = Canvas(req["dsl"], current_user.id)
        ans = canvas.run()
        # the canvas serializes back to DSL with its updated run state
        req["dsl"] = json.loads(str(canvas))
        UserCanvasService.update_by_id(req["id"], dsl=req["dsl"])
        return get_json_result(data=req["dsl"])
    except Exception as e:
        return server_error_response(e)


@manager.route('/reset', methods=['POST'])
@validate_request("canvas_id")
@login_required
def reset():
    req = request.json
    try:
        # get_by_id returns (found, record); rebuild the canvas from the stored DSL
        e, user_canvas = UserCanvasService.get_by_id(req["canvas_id"])
        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
        canvas.reset()
        req["dsl"] = json.loads(str(canvas))
        UserCanvasService.update_by_id(req["canvas_id"], dsl=req["dsl"])
        return get_json_result(data=req["dsl"])
    except Exception as e:
        return server_error_response(e)
```
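Together these routes cover the canvas lifecycle: list templates, save a canvas, run it, reset it. A rough driver sketch follows; the `/v1/canvas` mount prefix, the port, and the `dsl` field on a template record are assumptions, not guarantees from this excerpt:

```python
# All paths/ports below are assumptions for illustration.
import requests

s = requests.Session()
# ... authenticate first; the login flow is outside this excerpt ...
BASE = "http://localhost:9380/v1/canvas"

tpl = s.get(f"{BASE}/templates").json()["data"][0]          # pick any template
saved = s.post(f"{BASE}/set", json={"title": "demo",
                                    "dsl": tpl["dsl"]}).json()["data"]

out_dsl = s.post(f"{BASE}/run", json={"id": saved["id"],
                                      "dsl": saved["dsl"]}).json()["data"]
s.post(f"{BASE}/reset", json={"canvas_id": saved["id"]})    # clear run state
```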
```python
# See the License for the specific language governing permissions and
# limitations under the License.
#
from copy import deepcopy

from flask import request, Response
from flask_login import login_required

from api.db.services.dialog_service import DialogService, ConversationService, chat
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request

    # ...
    e, conv = ConversationService.get_by_id(req["conversation_id"])
    if not e:
        return get_data_error_result(retmsg="Conversation not found!")
    # deep-copy so later in-place edits to the stored message don't alias `msg`
    conv.message.append(deepcopy(msg[-1]))
    e, dia = DialogService.get_by_id(conv.dialog_id)
    if not e:
        return get_data_error_result(retmsg="Dialog not found!")
```
```python
    req = request.json
    dialog_id = req.get("dialog_id")
    name = req.get("name", "New Dialog")
    description = req.get("description", "A helpful Dialog")
    icon = req.get("icon", "")
    top_n = req.get("top_n", 6)
    top_k = req.get("top_k", 1024)
    rerank_id = req.get("rerank_id", "")
    # ...
        "rerank_id": rerank_id,
        "similarity_threshold": similarity_threshold,
        "vector_similarity_weight": vector_similarity_weight,
        "icon": icon
    }
    if not DialogService.save(**dia):
        return get_data_error_result(retmsg="Fail to new a dialog!")
```
```python
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime

import peewee

from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas
from api.db.services.common_service import CommonService


class CanvasTemplateService(CommonService):
    model = CanvasTemplate


class UserCanvasService(CommonService):
    model = UserCanvas
```
```python
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.settings import chat_logger, retrievaler
from rag.app.resume import forbidden_select_fields4resume
from rag.nlp.rag_tokenizer import is_chinese
from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder

    # ...
    if not llm:
        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
        if not llm:
            raise LookupError("LLM(%s) not found" % dialog.llm_id)
        # tenant-registered model without a known spec: assume a small context window
        max_tokens = 1024
    else:
        max_tokens = llm[0].max_tokens
    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embd_nms = list(set([kb.embd_id for kb in kbs]))
    if len(embd_nms) != 1:
        # ...

    kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
                                    dialog.similarity_threshold,
                                    dialog.vector_similarity_weight,
                                    doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
                                    top=1024, aggs=False, rerank_mdl=rerank_mdl)
    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
```
    # self-RAG: if the retrieved chunks look irrelevant, rewrite the question and retrieve again
    if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
        questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
        kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
                                        dialog.similarity_threshold,
                                        dialog.vector_similarity_weight,
                                        doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
                                        top=1024, aggs=False, rerank_mdl=rerank_mdl)
        knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
    chat_logger.info(
        "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
    msg.extend([{"role": m["role"], "content": m["content"]}
                for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    # ...

    if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
        answer, idx = retrievaler.insert_citations(answer,
                                                   [ck["content_ltks"]
                                                    for ck in kbinfos["chunks"]],
                                                   [ck["vector"]
                                                    for ck in kbinfos["chunks"]],
                                                   embd_mdl,
                                                   tkweight=1 - dialog.vector_similarity_weight,
                                                   vtweight=dialog.vector_similarity_weight)
    # ...
    for c in refs["chunks"]:
        if c.get("vector"):
            del c["vector"]
    if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
        answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
    return {"answer": answer, "reference": refs}
    def get_table():
        nonlocal sys_prompt, user_promt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
            "temperature": 0.06})
        print(user_promt, sql)
        chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        # ...

    # compose markdown table
    clmns = "|" + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"],
                                                                        tbl["columns"][i]["name"])) for i in
                            clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
    line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
           ("|------|" if docid_idx and docid_idx else "")
    rows = ["|" +
            "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
            "|" for r in tbl["rows"]]
    if quota:
        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
    else:
        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

    if not docid_idx or not docnm_idx:
        # ...
    return {
        "answer": "\n".join([clmns, line, rows]),
        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
                                   doc_aggs.items()]}
    }
def relevant(tenant_id, llm_id, question, contents: list):
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
    prompt = """
You are a grader assessing the relevance of a retrieved document to a user question.
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
Give a binary 'yes' or 'no' score to indicate whether the document is relevant to the question.
No other words are needed except 'yes' or 'no'.
"""
    if not contents:
        return False
    contents = "Documents: \n" + " - ".join(contents)
    contents = f"Question: {question}\n" + contents
    # truncate so the grading request fits in the model's context window
    if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
        contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
    ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
    return ans.lower().find("yes") >= 0
def rewrite(tenant_id, llm_id, question):
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
    prompt = """
You are an expert at query expansion, generating paraphrases of a question.
Relevant information could not be retrieved from the knowledge base with the user's question as-is,
so you need to expand or paraphrase it in several ways: use synonymous words/phrases,
write abbreviations out in full, add extra descriptions or explanations,
change the way it is expressed, translate the original question into another language (English/Chinese), etc.
Return 5 versions of the question, one of them a translation.
Just list the questions. No other words are needed.
"""
    ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
    return ans
```
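These two helpers are exactly what the new self-RAG branch in `chat` above calls. A compressed sketch of that control flow, where the variables are hypothetical stand-ins for values `chat` already has in scope:

```python
# Hypothetical stand-ins: tenant_id, llm_id, question, kbinfos come from chat().
knowledges = [c["content_with_weight"] for c in kbinfos["chunks"]]
if not relevant(tenant_id, llm_id, question, knowledges):
    # rewrite() returns all five paraphrases as one block of text; the whole
    # block becomes the new query, so re-retrieval matches any of the variants
    question = rewrite(tenant_id, llm_id, question)
    # ... re-run retrievaler.retrieval(...) with the rewritten question ...
```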
```python
        self.page_cum_height = np.cumsum(self.page_cum_height)
        assert len(self.page_cum_height) == len(self.page_images) + 1
        # nothing was extracted: re-render at triple the zoom and retry (capped at zoomin < 9)
        if len(self.boxes) == 0 and zoomin < 9:
            self.__images__(fnm, zoomin * 3, page_from, page_to, callback)

    def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
        self.__images__(fnm, zoomin)
```
```python
        return np.array(res), token_count

    @staticmethod
    def rmWWW(txt):
        patts = [
            (r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
            (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
            (r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down) ", " ")
        ]
        # ...

        if not self.isChinese(txt):
            tks = rag_tokenizer.tokenize(txt).split(" ")
            tks_w = self.tw.weights(tks)
            # scrub characters that would break the query syntax, then drop
            # single alphanumerics and leading +/- signs
            tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
            tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
            tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
            q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
            for i in range(1, len(tks_w)):
                q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1]) * 2))
            # ...

            # per-term loop (similar-term expansion):
            if sm:
                tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
                    " ".join(sm), " ".join(sm))
            if tk.strip():
                tms.append((tk, w))
        tms = " ".join([f"({t})^{w}" for t, w in tms])
```