### What problem does this PR solve?

Fix errors detected by Ruff.

### Type of change

- [x] Refactoring
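Although the diff touches many files, almost every change follows a small set of recurring lint patterns. The snippet below is only an illustrative sketch of those patterns (the function names and values are invented for the example, not taken from the repository):

```python
# Illustrative examples of the Ruff-driven fixes applied throughout this PR.

# E701: multiple statements on one line -- split `if cond: continue` onto two lines.
def copy_without_components(dsl: dict) -> dict:  # hypothetical helper, for illustration only
    out = {}
    for k in dsl:
        if k == "components":
            continue  # was: `if k == "components": continue`
        out[k] = dsl[k]
    return out

# E721: use isinstance() instead of comparing types with == / !=.
def ensure_list(value):
    return value if isinstance(value, list) else [value]  # was: `if type(value) != list: ...`

# E722 / F841: no bare `except:` and no unused `except Exception as e:` bindings.
def to_float(text: str, default: float = 0.0) -> float:
    try:
        return float(text)
    except Exception:  # was: `except Exception as e:` with `e` never used
        return default
```

Unused imports (F401), ambiguous single-letter names such as `l` (E741), and dead assignments are removed or renamed in the same spirit, as the hunks below show.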
| @@ -133,7 +133,8 @@ class Canvas(ABC): | |||
| "components": {} | |||
| } | |||
| for k in self.dsl.keys(): | |||
| if k in ["components"]:continue | |||
| if k in ["components"]: | |||
| continue | |||
| dsl[k] = deepcopy(self.dsl[k]) | |||
| for k, cpn in self.components.items(): | |||
| @@ -158,7 +159,8 @@ class Canvas(ABC): | |||
| def get_compnent_name(self, cid): | |||
| for n in self.dsl["graph"]["nodes"]: | |||
| if cid == n["id"]: return n["data"]["name"] | |||
| if cid == n["id"]: | |||
| return n["data"]["name"] | |||
| return "" | |||
| def run(self, **kwargs): | |||
| @@ -173,7 +175,8 @@ class Canvas(ABC): | |||
| if kwargs.get("stream"): | |||
| for an in ans(): | |||
| yield an | |||
| else: yield ans | |||
| else: | |||
| yield ans | |||
| return | |||
| if not self.path: | |||
| @@ -188,7 +191,8 @@ class Canvas(ABC): | |||
| def prepare2run(cpns): | |||
| nonlocal ran, ans | |||
| for c in cpns: | |||
| if self.path[-1] and c == self.path[-1][-1]: continue | |||
| if self.path[-1] and c == self.path[-1][-1]: | |||
| continue | |||
| cpn = self.components[c]["obj"] | |||
| if cpn.component_name == "Answer": | |||
| self.answer.append(c) | |||
| @@ -197,7 +201,8 @@ class Canvas(ABC): | |||
| if c not in without_dependent_checking: | |||
| cpids = cpn.get_dependent_components() | |||
| if any([cc not in self.path[-1] for cc in cpids]): | |||
| if c not in waiting: waiting.append(c) | |||
| if c not in waiting: | |||
| waiting.append(c) | |||
| continue | |||
| yield "*'{}'* is running...🕞".format(self.get_compnent_name(c)) | |||
| ans = cpn.run(self.history, **kwargs) | |||
| @@ -211,10 +216,12 @@ class Canvas(ABC): | |||
| logging.debug(f"Canvas.run: {ran} {self.path}") | |||
| cpn_id = self.path[-1][ran] | |||
| cpn = self.get_component(cpn_id) | |||
| if not cpn["downstream"]: break | |||
| if not cpn["downstream"]: | |||
| break | |||
| loop = self._find_loop() | |||
| if loop: raise OverflowError(f"Too much loops: {loop}") | |||
| if loop: | |||
| raise OverflowError(f"Too much loops: {loop}") | |||
| if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]: | |||
| switch_out = cpn["obj"].output()[1].iloc[0, 0] | |||
| @@ -283,19 +290,22 @@ class Canvas(ABC): | |||
| def _find_loop(self, max_loops=6): | |||
| path = self.path[-1][::-1] | |||
| if len(path) < 2: return False | |||
| if len(path) < 2: | |||
| return False | |||
| for i in range(len(path)): | |||
| if path[i].lower().find("answer") >= 0: | |||
| path = path[:i] | |||
| break | |||
| if len(path) < 2: return False | |||
| if len(path) < 2: | |||
| return False | |||
| for l in range(2, len(path) // 2): | |||
| pat = ",".join(path[0:l]) | |||
| for loc in range(2, len(path) // 2): | |||
| pat = ",".join(path[0:loc]) | |||
| path_str = ",".join(path) | |||
| if len(pat) >= len(path_str): return False | |||
| if len(pat) >= len(path_str): | |||
| return False | |||
| loop = max_loops | |||
| while path_str.find(pat) == 0 and loop >= 0: | |||
| loop -= 1 | |||
| @@ -303,7 +313,7 @@ class Canvas(ABC): | |||
| return False | |||
| path_str = path_str[len(pat)+1:] | |||
| if loop < 0: | |||
| pat = " => ".join([p.split(":")[0] for p in path[0:l]]) | |||
| pat = " => ".join([p.split(":")[0] for p in path[0:loc]]) | |||
| return pat + " => " + pat | |||
| return False | |||
| @@ -39,3 +39,73 @@ def component_class(class_name): | |||
| m = importlib.import_module("agent.component") | |||
| c = getattr(m, class_name) | |||
| return c | |||
| __all__ = [ | |||
| "Begin", | |||
| "BeginParam", | |||
| "Generate", | |||
| "GenerateParam", | |||
| "Retrieval", | |||
| "RetrievalParam", | |||
| "Answer", | |||
| "AnswerParam", | |||
| "Categorize", | |||
| "CategorizeParam", | |||
| "Switch", | |||
| "SwitchParam", | |||
| "Relevant", | |||
| "RelevantParam", | |||
| "Message", | |||
| "MessageParam", | |||
| "RewriteQuestion", | |||
| "RewriteQuestionParam", | |||
| "KeywordExtract", | |||
| "KeywordExtractParam", | |||
| "Concentrator", | |||
| "ConcentratorParam", | |||
| "Baidu", | |||
| "BaiduParam", | |||
| "DuckDuckGo", | |||
| "DuckDuckGoParam", | |||
| "Wikipedia", | |||
| "WikipediaParam", | |||
| "PubMed", | |||
| "PubMedParam", | |||
| "ArXiv", | |||
| "ArXivParam", | |||
| "Google", | |||
| "GoogleParam", | |||
| "Bing", | |||
| "BingParam", | |||
| "GoogleScholar", | |||
| "GoogleScholarParam", | |||
| "DeepL", | |||
| "DeepLParam", | |||
| "GitHub", | |||
| "GitHubParam", | |||
| "BaiduFanyi", | |||
| "BaiduFanyiParam", | |||
| "QWeather", | |||
| "QWeatherParam", | |||
| "ExeSQL", | |||
| "ExeSQLParam", | |||
| "YahooFinance", | |||
| "YahooFinanceParam", | |||
| "WenCai", | |||
| "WenCaiParam", | |||
| "Jin10", | |||
| "Jin10Param", | |||
| "TuShare", | |||
| "TuShareParam", | |||
| "AkShare", | |||
| "AkShareParam", | |||
| "Crawler", | |||
| "CrawlerParam", | |||
| "Invoke", | |||
| "InvokeParam", | |||
| "Template", | |||
| "TemplateParam", | |||
| "Email", | |||
| "EmailParam", | |||
| "component_class" | |||
| ] | |||
| @@ -428,7 +428,8 @@ class ComponentBase(ABC): | |||
| def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]: | |||
| o = getattr(self._param, self._param.output_var_name) | |||
| if not isinstance(o, partial) and not isinstance(o, pd.DataFrame): | |||
| if not isinstance(o, list): o = [o] | |||
| if not isinstance(o, list): | |||
| o = [o] | |||
| o = pd.DataFrame(o) | |||
| if allow_partial or not isinstance(o, partial): | |||
| @@ -440,7 +441,8 @@ class ComponentBase(ABC): | |||
| for oo in o(): | |||
| if not isinstance(oo, pd.DataFrame): | |||
| outs = pd.DataFrame(oo if isinstance(oo, list) else [oo]) | |||
| else: outs = oo | |||
| else: | |||
| outs = oo | |||
| return self._param.output_var_name, outs | |||
| def reset(self): | |||
| @@ -482,13 +484,15 @@ class ComponentBase(ABC): | |||
| outs.append(pd.DataFrame([{"content": q["value"]}])) | |||
| if outs: | |||
| df = pd.concat(outs, ignore_index=True) | |||
| if "content" in df: df = df.drop_duplicates(subset=['content']).reset_index(drop=True) | |||
| if "content" in df: | |||
| df = df.drop_duplicates(subset=['content']).reset_index(drop=True) | |||
| return df | |||
| upstream_outs = [] | |||
| for u in reversed_cpnts[::-1]: | |||
| if self.get_component_name(u) in ["switch", "concentrator"]: continue | |||
| if self.get_component_name(u) in ["switch", "concentrator"]: | |||
| continue | |||
| if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval": | |||
| o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1] | |||
| if o is not None: | |||
| @@ -532,7 +536,8 @@ class ComponentBase(ABC): | |||
| reversed_cpnts.extend(self._canvas.path[-1]) | |||
| for u in reversed_cpnts[::-1]: | |||
| if self.get_component_name(u) in ["switch", "answer"]: continue | |||
| if self.get_component_name(u) in ["switch", "answer"]: | |||
| continue | |||
| return self._canvas.get_component(u)["obj"].output()[1] | |||
| @staticmethod | |||
| @@ -34,15 +34,18 @@ class CategorizeParam(GenerateParam): | |||
| super().check() | |||
| self.check_empty(self.category_description, "[Categorize] Category examples") | |||
| for k, v in self.category_description.items(): | |||
| if not k: raise ValueError("[Categorize] Category name can not be empty!") | |||
| if not v.get("to"): raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!") | |||
| if not k: | |||
| raise ValueError("[Categorize] Category name can not be empty!") | |||
| if not v.get("to"): | |||
| raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!") | |||
| def get_prompt(self): | |||
| cate_lines = [] | |||
| for c, desc in self.category_description.items(): | |||
| for l in desc.get("examples", "").split("\n"): | |||
| if not l: continue | |||
| cate_lines.append("Question: {}\tCategory: {}".format(l, c)) | |||
| for line in desc.get("examples", "").split("\n"): | |||
| if not line: | |||
| continue | |||
| cate_lines.append("Question: {}\tCategory: {}".format(line, c)) | |||
| descriptions = [] | |||
| for c, desc in self.category_description.items(): | |||
| if desc.get("description"): | |||
| @@ -14,7 +14,6 @@ | |||
| # limitations under the License. | |||
| # | |||
| from abc import ABC | |||
| import re | |||
| from agent.component.base import ComponentBase, ComponentParamBase | |||
| import deepl | |||
| @@ -46,8 +46,10 @@ class ExeSQLParam(ComponentParamBase): | |||
| self.check_empty(self.password, "Database password") | |||
| self.check_positive_integer(self.top_n, "Number of records") | |||
| if self.database == "rag_flow": | |||
| if self.host == "ragflow-mysql": raise ValueError("The host is not accessible.") | |||
| if self.password == "infini_rag_flow": raise ValueError("The host is not accessible.") | |||
| if self.host == "ragflow-mysql": | |||
| raise ValueError("The host is not accessible.") | |||
| if self.password == "infini_rag_flow": | |||
| raise ValueError("The host is not accessible.") | |||
| class ExeSQL(ComponentBase, ABC): | |||
| @@ -51,11 +51,16 @@ class GenerateParam(ComponentParamBase): | |||
| def gen_conf(self): | |||
| conf = {} | |||
| if self.max_tokens > 0: conf["max_tokens"] = self.max_tokens | |||
| if self.temperature > 0: conf["temperature"] = self.temperature | |||
| if self.top_p > 0: conf["top_p"] = self.top_p | |||
| if self.presence_penalty > 0: conf["presence_penalty"] = self.presence_penalty | |||
| if self.frequency_penalty > 0: conf["frequency_penalty"] = self.frequency_penalty | |||
| if self.max_tokens > 0: | |||
| conf["max_tokens"] = self.max_tokens | |||
| if self.temperature > 0: | |||
| conf["temperature"] = self.temperature | |||
| if self.top_p > 0: | |||
| conf["top_p"] = self.top_p | |||
| if self.presence_penalty > 0: | |||
| conf["presence_penalty"] = self.presence_penalty | |||
| if self.frequency_penalty > 0: | |||
| conf["frequency_penalty"] = self.frequency_penalty | |||
| return conf | |||
| @@ -83,7 +88,8 @@ class Generate(ComponentBase): | |||
| recall_docs = [] | |||
| for i in idx: | |||
| did = retrieval_res.loc[int(i), "doc_id"] | |||
| if did in doc_ids: continue | |||
| if did in doc_ids: | |||
| continue | |||
| doc_ids.add(did) | |||
| recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]}) | |||
| @@ -108,7 +114,8 @@ class Generate(ComponentBase): | |||
| retrieval_res = [] | |||
| self._param.inputs = [] | |||
| for para in self._param.parameters: | |||
| if not para.get("component_id"): continue | |||
| if not para.get("component_id"): | |||
| continue | |||
| component_id = para["component_id"].split("@")[0] | |||
| if para["component_id"].lower().find("@") >= 0: | |||
| cpn_id, key = para["component_id"].split("@") | |||
| @@ -142,7 +149,8 @@ class Generate(ComponentBase): | |||
| if retrieval_res: | |||
| retrieval_res = pd.concat(retrieval_res, ignore_index=True) | |||
| else: retrieval_res = pd.DataFrame([]) | |||
| else: | |||
| retrieval_res = pd.DataFrame([]) | |||
| for n, v in kwargs.items(): | |||
| prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt) | |||
| @@ -164,9 +172,11 @@ class Generate(ComponentBase): | |||
| return pd.DataFrame([res]) | |||
| msg = self._canvas.get_history(self._param.message_history_window_size) | |||
| if len(msg) < 1: msg.append({"role": "user", "content": ""}) | |||
| if len(msg) < 1: | |||
| msg.append({"role": "user", "content": ""}) | |||
| _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97)) | |||
| if len(msg) < 2: msg.append({"role": "user", "content": ""}) | |||
| if len(msg) < 2: | |||
| msg.append({"role": "user", "content": ""}) | |||
| ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf()) | |||
| if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns: | |||
| @@ -185,9 +195,11 @@ class Generate(ComponentBase): | |||
| return | |||
| msg = self._canvas.get_history(self._param.message_history_window_size) | |||
| if len(msg) < 1: msg.append({"role": "user", "content": ""}) | |||
| if len(msg) < 1: | |||
| msg.append({"role": "user", "content": ""}) | |||
| _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97)) | |||
| if len(msg) < 2: msg.append({"role": "user", "content": ""}) | |||
| if len(msg) < 2: | |||
| msg.append({"role": "user", "content": ""}) | |||
| answer = "" | |||
| for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()): | |||
| res = {"content": ans, "reference": []} | |||
| @@ -95,7 +95,8 @@ class RewriteQuestion(Generate, ABC): | |||
| hist = self._canvas.get_history(4) | |||
| conv = [] | |||
| for m in hist: | |||
| if m["role"] not in ["user", "assistant"]: continue | |||
| if m["role"] not in ["user", "assistant"]: | |||
| continue | |||
| conv.append("{}: {}".format(m["role"].upper(), m["content"])) | |||
| conv = "\n".join(conv) | |||
| @@ -41,7 +41,8 @@ class SwitchParam(ComponentParamBase): | |||
| def check(self): | |||
| self.check_empty(self.conditions, "[Switch] conditions") | |||
| for cond in self.conditions: | |||
| if not cond["to"]: raise ValueError(f"[Switch] 'To' can not be empty!") | |||
| if not cond["to"]: | |||
| raise ValueError("[Switch] 'To' can not be empty!") | |||
| class Switch(ComponentBase, ABC): | |||
| @@ -51,7 +52,8 @@ class Switch(ComponentBase, ABC): | |||
| res = [] | |||
| for cond in self._param.conditions: | |||
| for item in cond["items"]: | |||
| if not item["cpn_id"]: continue | |||
| if not item["cpn_id"]: | |||
| continue | |||
| if item["cpn_id"].find("begin") >= 0: | |||
| continue | |||
| cid = item["cpn_id"].split("@")[0] | |||
| @@ -63,7 +65,8 @@ class Switch(ComponentBase, ABC): | |||
| for cond in self._param.conditions: | |||
| res = [] | |||
| for item in cond["items"]: | |||
| if not item["cpn_id"]:continue | |||
| if not item["cpn_id"]: | |||
| continue | |||
| cid = item["cpn_id"].split("@")[0] | |||
| if item["cpn_id"].find("@") > 0: | |||
| cpn_id, key = item["cpn_id"].split("@") | |||
| @@ -107,22 +110,22 @@ class Switch(ComponentBase, ABC): | |||
| elif operator == ">": | |||
| try: | |||
| return True if float(input) > float(value) else False | |||
| except Exception as e: | |||
| except Exception: | |||
| return True if input > value else False | |||
| elif operator == "<": | |||
| try: | |||
| return True if float(input) < float(value) else False | |||
| except Exception as e: | |||
| except Exception: | |||
| return True if input < value else False | |||
| elif operator == "≥": | |||
| try: | |||
| return True if float(input) >= float(value) else False | |||
| except Exception as e: | |||
| except Exception: | |||
| return True if input >= value else False | |||
| elif operator == "≤": | |||
| try: | |||
| return True if float(input) <= float(value) else False | |||
| except Exception as e: | |||
| except Exception: | |||
| return True if input <= value else False | |||
| raise ValueError('Not supported operator' + operator) | |||
| @@ -47,7 +47,8 @@ class Template(ComponentBase): | |||
| self._param.inputs = [] | |||
| for para in self._param.parameters: | |||
| if not para.get("component_id"): continue | |||
| if not para.get("component_id"): | |||
| continue | |||
| component_id = para["component_id"].split("@")[0] | |||
| if para["component_id"].lower().find("@") >= 0: | |||
| cpn_id, key = para["component_id"].split("@") | |||
| @@ -43,6 +43,7 @@ if __name__ == '__main__': | |||
| else: | |||
| print(ans["content"]) | |||
| if DEBUG: print(canvas.path) | |||
| if DEBUG: | |||
| print(canvas.path) | |||
| question = input("\n==================== User =====================\n> ") | |||
| canvas.add_user_input(question) | |||
| @@ -142,7 +142,6 @@ def set_conversation(): | |||
| if not objs: | |||
| return get_json_result( | |||
| data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR) | |||
| req = request.json | |||
| try: | |||
| if objs[0].source == "agent": | |||
| e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id) | |||
| @@ -188,7 +187,8 @@ def completion(): | |||
| e, conv = API4ConversationService.get_by_id(req["conversation_id"]) | |||
| if not e: | |||
| return get_data_error_result(message="Conversation not found!") | |||
| if "quote" not in req: req["quote"] = False | |||
| if "quote" not in req: | |||
| req["quote"] = False | |||
| msg = [] | |||
| for m in req["messages"]: | |||
| @@ -197,7 +197,8 @@ def completion(): | |||
| if m["role"] == "assistant" and not msg: | |||
| continue | |||
| msg.append(m) | |||
| if not msg[-1].get("id"): msg[-1]["id"] = get_uuid() | |||
| if not msg[-1].get("id"): | |||
| msg[-1]["id"] = get_uuid() | |||
| message_id = msg[-1]["id"] | |||
| def fillin_conv(ans): | |||
| @@ -674,11 +675,13 @@ def completion_faq(): | |||
| e, conv = API4ConversationService.get_by_id(req["conversation_id"]) | |||
| if not e: | |||
| return get_data_error_result(message="Conversation not found!") | |||
| if "quote" not in req: req["quote"] = True | |||
| if "quote" not in req: | |||
| req["quote"] = True | |||
| msg = [] | |||
| msg.append({"role": "user", "content": req["word"]}) | |||
| if not msg[-1].get("id"): msg[-1]["id"] = get_uuid() | |||
| if not msg[-1].get("id"): | |||
| msg[-1]["id"] = get_uuid() | |||
| message_id = msg[-1]["id"] | |||
| def fillin_conv(ans): | |||
| @@ -13,10 +13,8 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import logging | |||
| import json | |||
| import traceback | |||
| from functools import partial | |||
| from flask import request, Response | |||
| from flask_login import login_required, current_user | |||
| from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService | |||
| @@ -60,7 +58,8 @@ def rm(): | |||
| def save(): | |||
| req = request.json | |||
| req["user_id"] = current_user.id | |||
| if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False) | |||
| if not isinstance(req["dsl"], str): | |||
| req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False) | |||
| req["dsl"] = json.loads(req["dsl"]) | |||
| if "id" not in req: | |||
| @@ -153,7 +152,8 @@ def run(): | |||
| return resp | |||
| for answer in canvas.run(stream=False): | |||
| if answer.get("running_status"): continue | |||
| if answer.get("running_status"): | |||
| continue | |||
| final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else "" | |||
| canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) | |||
| if final_ans.get("reference"): | |||
| @@ -237,7 +237,8 @@ def create(): | |||
| e, kb = KnowledgebaseService.get_by_id(doc.kb_id) | |||
| if not e: | |||
| return get_data_error_result(message="Knowledgebase not found!") | |||
| if kb.pagerank: d["pagerank_fea"] = kb.pagerank | |||
| if kb.pagerank: | |||
| d["pagerank_fea"] = kb.pagerank | |||
| embd_id = DocumentService.get_embd_id(req["doc_id"]) | |||
| embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id) | |||
| @@ -281,10 +281,12 @@ def thumbup(): | |||
| if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant": | |||
| if up_down: | |||
| msg["thumbup"] = True | |||
| if "feedback" in msg: del msg["feedback"] | |||
| if "feedback" in msg: | |||
| del msg["feedback"] | |||
| else: | |||
| msg["thumbup"] = False | |||
| if feedback: msg["feedback"] = feedback | |||
| if feedback: | |||
| msg["feedback"] = feedback | |||
| break | |||
| ConversationService.update_by_id(conv["id"], conv) | |||
| @@ -37,10 +37,12 @@ def set_dialog(): | |||
| top_n = req.get("top_n", 6) | |||
| top_k = req.get("top_k", 1024) | |||
| rerank_id = req.get("rerank_id", "") | |||
| if not rerank_id: req["rerank_id"] = "" | |||
| if not rerank_id: | |||
| req["rerank_id"] = "" | |||
| similarity_threshold = req.get("similarity_threshold", 0.1) | |||
| vector_similarity_weight = req.get("vector_similarity_weight", 0.3) | |||
| if vector_similarity_weight is None: vector_similarity_weight = 0.3 | |||
| if vector_similarity_weight is None: | |||
| vector_similarity_weight = 0.3 | |||
| llm_setting = req.get("llm_setting", {}) | |||
| default_prompt = { | |||
| "system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。 | |||
| @@ -13,7 +13,6 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License | |||
| # | |||
| import json | |||
| import os.path | |||
| import pathlib | |||
| import re | |||
| @@ -90,7 +89,8 @@ def web_crawl(): | |||
| raise LookupError("Can't find this knowledgebase!") | |||
| blob = html2pdf(url) | |||
| if not blob: return server_error_response(ValueError("Download failure.")) | |||
| if not blob: | |||
| return server_error_response(ValueError("Download failure.")) | |||
| root_folder = FileService.get_root_folder(current_user.id) | |||
| pf_id = root_folder["id"] | |||
| @@ -290,7 +290,8 @@ def change_status(): | |||
| def rm(): | |||
| req = request.json | |||
| doc_ids = req["doc_id"] | |||
| if isinstance(doc_ids, str): doc_ids = [doc_ids] | |||
| if isinstance(doc_ids, str): | |||
| doc_ids = [doc_ids] | |||
| for doc_id in doc_ids: | |||
| if not DocumentService.accessible4deletion(doc_id, current_user.id): | |||
| @@ -351,8 +351,10 @@ def list_app(): | |||
| llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms]) | |||
| for o in objs: | |||
| if not o.api_key: continue | |||
| if o.llm_name + "@" + o.llm_factory in llm_set: continue | |||
| if not o.api_key: | |||
| continue | |||
| if o.llm_name + "@" + o.llm_factory in llm_set: | |||
| continue | |||
| llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True}) | |||
| res = {} | |||
| @@ -14,7 +14,7 @@ | |||
| # limitations under the License. | |||
| # | |||
| from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService | |||
| from api.db.services.canvas_service import UserCanvasService | |||
| from api.utils.api_utils import get_error_data_result, token_required | |||
| from api.utils.api_utils import get_result | |||
| from flask import request | |||
| @@ -41,7 +41,6 @@ from api.utils.api_utils import construct_json_result, get_parser_config | |||
| from rag.nlp import search | |||
| from rag.utils import rmSpace | |||
| from rag.utils.storage_factory import STORAGE_IMPL | |||
| import os | |||
| MAXIMUM_OF_UPLOADING_FILES = 256 | |||
| @@ -976,12 +975,12 @@ def add_chunk(tenant_id, dataset_id, document_id): | |||
| if not req.get("content"): | |||
| return get_error_data_result(message="`content` is required") | |||
| if "important_keywords" in req: | |||
| if type(req["important_keywords"]) != list: | |||
| if not isinstance(req["important_keywords"], list): | |||
| return get_error_data_result( | |||
| "`important_keywords` is required to be a list" | |||
| ) | |||
| if "questions" in req: | |||
| if type(req["questions"]) != list: | |||
| if not isinstance(req["questions"], list): | |||
| return get_error_data_result( | |||
| "`questions` is required to be a list" | |||
| ) | |||
| @@ -143,8 +143,10 @@ def completion(tenant_id, chat_id): | |||
| } | |||
| conv.message.append(question) | |||
| for m in conv.message: | |||
| if m["role"] == "system": continue | |||
| if m["role"] == "assistant" and not msg: continue | |||
| if m["role"] == "system": | |||
| continue | |||
| if m["role"] == "assistant" and not msg: | |||
| continue | |||
| msg.append(m) | |||
| message_id = msg[-1].get("id") | |||
| e, dia = DialogService.get_by_id(conv.dialog_id) | |||
| @@ -267,7 +269,8 @@ def agent_completion(tenant_id, agent_id): | |||
| if m["role"] == "assistant" and not msg: | |||
| continue | |||
| msg.append(m) | |||
| if not msg[-1].get("id"): msg[-1]["id"] = get_uuid() | |||
| if not msg[-1].get("id"): | |||
| msg[-1]["id"] = get_uuid() | |||
| message_id = msg[-1]["id"] | |||
| stream = req.get("stream", True) | |||
| @@ -361,7 +364,8 @@ def agent_completion(tenant_id, agent_id): | |||
| return resp | |||
| for answer in canvas.run(stream=False): | |||
| if answer.get("running_status"): continue | |||
| if answer.get("running_status"): | |||
| continue | |||
| final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else "" | |||
| canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) | |||
| if final_ans.get("reference"): | |||
| @@ -330,7 +330,7 @@ def user_info_from_github(access_token): | |||
| headers=headers, | |||
| ).json() | |||
| user_info["email"] = next( | |||
| (email for email in email_info if email["primary"] == True), None | |||
| (email for email in email_info if email["primary"]), None | |||
| )["email"] | |||
| return user_info | |||
| @@ -130,7 +130,7 @@ def is_continuous_field(cls: typing.Type) -> bool: | |||
| for p in cls.__bases__: | |||
| if p in CONTINUOUS_FIELD_TYPE: | |||
| return True | |||
| elif p != Field and p != object: | |||
| elif p is not Field and p is not object: | |||
| if is_continuous_field(p): | |||
| return True | |||
| else: | |||
| @@ -170,7 +170,7 @@ def add_graph_templates(): | |||
| cnvs = json.load(open(os.path.join(dir, fnm), "r")) | |||
| try: | |||
| CanvasTemplateService.save(**cnvs) | |||
| except: | |||
| except Exception: | |||
| CanvasTemplateService.update_by_id(cnvs["id"], cnvs) | |||
| except Exception: | |||
| logging.exception("Add graph templates error: ") | |||
| @@ -15,13 +15,14 @@ | |||
| # | |||
| import pathlib | |||
| import re | |||
| from .user_service import UserService | |||
| from .user_service import UserService as UserService | |||
| def duplicate_name(query_func, **kwargs): | |||
| fnm = kwargs["name"] | |||
| objs = query_func(**kwargs) | |||
| if not objs: return fnm | |||
| if not objs: | |||
| return fnm | |||
| ext = pathlib.Path(fnm).suffix #.jpg | |||
| nm = re.sub(r"%s$"%ext, "", fnm) | |||
| r = re.search(r"\(([0-9]+)\)$", nm) | |||
| @@ -31,8 +32,8 @@ def duplicate_name(query_func, **kwargs): | |||
| nm = re.sub(r"\([0-9]+\)$", "", nm) | |||
| c += 1 | |||
| nm = f"{nm}({c})" | |||
| if ext: nm += f"{ext}" | |||
| if ext: | |||
| nm += f"{ext}" | |||
| kwargs["name"] = nm | |||
| return duplicate_name(query_func, **kwargs) | |||
| @@ -64,7 +64,8 @@ class API4ConversationService(CommonService): | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def stats(cls, tenant_id, from_date, to_date, source=None): | |||
| if len(to_date) == 10: to_date += " 23:59:59" | |||
| if len(to_date) == 10: | |||
| to_date += " 23:59:59" | |||
| return cls.model.select( | |||
| cls.model.create_date.truncate("day").alias("dt"), | |||
| peewee.fn.COUNT( | |||
| @@ -13,9 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from datetime import datetime | |||
| import peewee | |||
| from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas | |||
| from api.db.db_models import DB, CanvasTemplate, UserCanvas | |||
| from api.db.services.common_service import CommonService | |||
| @@ -115,7 +115,7 @@ class CommonService: | |||
| try: | |||
| obj = cls.model.query(id=pid)[0] | |||
| return True, obj | |||
| except Exception as e: | |||
| except Exception: | |||
| return False, None | |||
| @classmethod | |||
| @@ -106,15 +106,15 @@ def message_fit_in(msg, max_length=4000): | |||
| return c, msg | |||
| ll = num_tokens_from_string(msg_[0]["content"]) | |||
| l = num_tokens_from_string(msg_[-1]["content"]) | |||
| if ll / (ll + l) > 0.8: | |||
| ll2 = num_tokens_from_string(msg_[-1]["content"]) | |||
| if ll / (ll + ll2) > 0.8: | |||
| m = msg_[0]["content"] | |||
| m = encoder.decode(encoder.encode(m)[:max_length - l]) | |||
| m = encoder.decode(encoder.encode(m)[:max_length - ll2]) | |||
| msg[0]["content"] = m | |||
| return max_length, msg | |||
| m = msg_[1]["content"] | |||
| m = encoder.decode(encoder.encode(m)[:max_length - l]) | |||
| m = encoder.decode(encoder.encode(m)[:max_length - ll2]) | |||
| msg[1]["content"] = m | |||
| return max_length, msg | |||
| @@ -257,7 +257,8 @@ def chat(dialog, messages, stream=True, **kwargs): | |||
| idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) | |||
| recall_docs = [ | |||
| d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] | |||
| if not recall_docs: recall_docs = kbinfos["doc_aggs"] | |||
| if not recall_docs: | |||
| recall_docs = kbinfos["doc_aggs"] | |||
| kbinfos["doc_aggs"] = recall_docs | |||
| refs = deepcopy(kbinfos) | |||
| @@ -433,13 +434,15 @@ def relevant(tenant_id, llm_id, question, contents: list): | |||
| Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question. | |||
| No other words needed except 'yes' or 'no'. | |||
| """ | |||
| if not contents:return False | |||
| if not contents: | |||
| return False | |||
| contents = "Documents: \n" + " - ".join(contents) | |||
| contents = f"Question: {question}\n" + contents | |||
| if num_tokens_from_string(contents) >= chat_mdl.max_length - 4: | |||
| contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4]) | |||
| ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01}) | |||
| if ans.lower().find("yes") >= 0: return True | |||
| if ans.lower().find("yes") >= 0: | |||
| return True | |||
| return False | |||
| @@ -481,8 +484,10 @@ Requirements: | |||
| ] | |||
| _, msg = message_fit_in(msg, chat_mdl.max_length) | |||
| kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2}) | |||
| if isinstance(kwd, tuple): kwd = kwd[0] | |||
| if kwd.find("**ERROR**") >=0: return "" | |||
| if isinstance(kwd, tuple): | |||
| kwd = kwd[0] | |||
| if kwd.find("**ERROR**") >=0: | |||
| return "" | |||
| return kwd | |||
| @@ -508,8 +513,10 @@ Requirements: | |||
| ] | |||
| _, msg = message_fit_in(msg, chat_mdl.max_length) | |||
| kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2}) | |||
| if isinstance(kwd, tuple): kwd = kwd[0] | |||
| if kwd.find("**ERROR**") >= 0: return "" | |||
| if isinstance(kwd, tuple): | |||
| kwd = kwd[0] | |||
| if kwd.find("**ERROR**") >= 0: | |||
| return "" | |||
| return kwd | |||
| @@ -520,7 +527,8 @@ def full_question(tenant_id, llm_id, messages): | |||
| chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id) | |||
| conv = [] | |||
| for m in messages: | |||
| if m["role"] not in ["user", "assistant"]: continue | |||
| if m["role"] not in ["user", "assistant"]: | |||
| continue | |||
| conv.append("{}: {}".format(m["role"].upper(), m["content"])) | |||
| conv = "\n".join(conv) | |||
| today = datetime.date.today().isoformat() | |||
| @@ -581,7 +589,8 @@ Output: What's the weather in Rochester on {tomorrow}? | |||
| def tts(tts_mdl, text): | |||
| if not tts_mdl or not text: return | |||
| if not tts_mdl or not text: | |||
| return | |||
| bin = b"" | |||
| for chunk in tts_mdl.tts(text): | |||
| bin += chunk | |||
| @@ -641,7 +650,8 @@ def ask(question, kb_ids, tenant_id): | |||
| idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) | |||
| recall_docs = [ | |||
| d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] | |||
| if not recall_docs: recall_docs = kbinfos["doc_aggs"] | |||
| if not recall_docs: | |||
| recall_docs = kbinfos["doc_aggs"] | |||
| kbinfos["doc_aggs"] = recall_docs | |||
| refs = deepcopy(kbinfos) | |||
| for c in refs["chunks"]: | |||
| @@ -532,7 +532,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): | |||
| try: | |||
| mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output, | |||
| ensure_ascii=False, indent=2) | |||
| if len(mind_map) < 32: raise Exception("Few content: " + mind_map) | |||
| if len(mind_map) < 32: | |||
| raise Exception("Few content: " + mind_map) | |||
| cks.append({ | |||
| "id": get_uuid(), | |||
| "doc_id": doc_id, | |||
| @@ -20,7 +20,7 @@ from api.db.db_models import DB | |||
| from api.db.db_models import File, File2Document | |||
| from api.db.services.common_service import CommonService | |||
| from api.db.services.document_service import DocumentService | |||
| from api.utils import current_timestamp, datetime_format, get_uuid | |||
| from api.utils import current_timestamp, datetime_format | |||
| class File2DocumentService(CommonService): | |||
| @@ -63,7 +63,7 @@ class File2DocumentService(CommonService): | |||
| def update_by_file_id(cls, file_id, obj): | |||
| obj["update_time"] = current_timestamp() | |||
| obj["update_date"] = datetime_format(datetime.now()) | |||
| num = cls.model.update(obj).where(cls.model.id == file_id).execute() | |||
| # num = cls.model.update(obj).where(cls.model.id == file_id).execute() | |||
| e, obj = cls.get_by_id(cls.model.id) | |||
| return obj | |||
| @@ -85,7 +85,8 @@ class FileService(CommonService): | |||
| .join(Document, on=(File2Document.document_id == Document.id)) | |||
| .join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id)) | |||
| .where(cls.model.id == file_id)) | |||
| if not kbs: return [] | |||
| if not kbs: | |||
| return [] | |||
| kbs_info_list = [] | |||
| for kb in list(kbs.dicts()): | |||
| kbs_info_list.append({"kb_id": kb['id'], "kb_name": kb['name']}) | |||
| @@ -304,7 +305,8 @@ class FileService(CommonService): | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def add_file_from_kb(cls, doc, kb_folder_id, tenant_id): | |||
| for _ in File2DocumentService.get_by_document_id(doc["id"]): return | |||
| for _ in File2DocumentService.get_by_document_id(doc["id"]): | |||
| return | |||
| file = { | |||
| "id": get_uuid(), | |||
| "parent_id": kb_folder_id, | |||
| @@ -107,7 +107,8 @@ class TenantLLMService(CommonService): | |||
| model_config = cls.get_api_key(tenant_id, mdlnm) | |||
| mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm) | |||
| if model_config: model_config = model_config.to_dict() | |||
| if model_config: | |||
| model_config = model_config.to_dict() | |||
| if not model_config: | |||
| if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]: | |||
| llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid) | |||
| @@ -57,28 +57,33 @@ class TaskService(CommonService): | |||
| Tenant.img2txt_id, | |||
| Tenant.asr_id, | |||
| Tenant.llm_id, | |||
| cls.model.update_time] | |||
| docs = cls.model.select(*fields) \ | |||
| .join(Document, on=(cls.model.doc_id == Document.id)) \ | |||
| .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \ | |||
| .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \ | |||
| cls.model.update_time, | |||
| ] | |||
| docs = ( | |||
| cls.model.select(*fields) | |||
| .join(Document, on=(cls.model.doc_id == Document.id)) | |||
| .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) | |||
| .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) | |||
| .where(cls.model.id == task_id) | |||
| ) | |||
| docs = list(docs.dicts()) | |||
| if not docs: return None | |||
| if not docs: | |||
| return None | |||
| msg = "\nTask has been received." | |||
| prog = random.random() / 10. | |||
| prog = random.random() / 10.0 | |||
| if docs[0]["retry_count"] >= 3: | |||
| msg = "\nERROR: Task is abandoned after 3 times attempts." | |||
| prog = -1 | |||
| cls.model.update(progress_msg=cls.model.progress_msg + msg, | |||
| progress=prog, | |||
| retry_count=docs[0]["retry_count"]+1 | |||
| ).where( | |||
| cls.model.id == docs[0]["id"]).execute() | |||
| cls.model.update( | |||
| progress_msg=cls.model.progress_msg + msg, | |||
| progress=prog, | |||
| retry_count=docs[0]["retry_count"] + 1, | |||
| ).where(cls.model.id == docs[0]["id"]).execute() | |||
| if docs[0]["retry_count"] >= 3: return None | |||
| if docs[0]["retry_count"] >= 3: | |||
| return None | |||
| return docs[0] | |||
| @@ -86,21 +91,44 @@ class TaskService(CommonService): | |||
| @DB.connection_context() | |||
| def get_ongoing_doc_name(cls): | |||
| with DB.lock("get_task", -1): | |||
| docs = cls.model.select(*[Document.id, Document.kb_id, Document.location, File.parent_id]) \ | |||
| .join(Document, on=(cls.model.doc_id == Document.id)) \ | |||
| .join(File2Document, on=(File2Document.document_id == Document.id), join_type=JOIN.LEFT_OUTER) \ | |||
| .join(File, on=(File2Document.file_id == File.id), join_type=JOIN.LEFT_OUTER) \ | |||
| docs = ( | |||
| cls.model.select( | |||
| *[Document.id, Document.kb_id, Document.location, File.parent_id] | |||
| ) | |||
| .join(Document, on=(cls.model.doc_id == Document.id)) | |||
| .join( | |||
| File2Document, | |||
| on=(File2Document.document_id == Document.id), | |||
| join_type=JOIN.LEFT_OUTER, | |||
| ) | |||
| .join( | |||
| File, | |||
| on=(File2Document.file_id == File.id), | |||
| join_type=JOIN.LEFT_OUTER, | |||
| ) | |||
| .where( | |||
| Document.status == StatusEnum.VALID.value, | |||
| Document.run == TaskStatus.RUNNING.value, | |||
| ~(Document.type == FileType.VIRTUAL.value), | |||
| cls.model.progress < 1, | |||
| cls.model.create_time >= current_timestamp() - 1000 * 600 | |||
| cls.model.create_time >= current_timestamp() - 1000 * 600, | |||
| ) | |||
| ) | |||
| docs = list(docs.dicts()) | |||
| if not docs: return [] | |||
| return list(set([(d["parent_id"] if d["parent_id"] else d["kb_id"], d["location"]) for d in docs])) | |||
| if not docs: | |||
| return [] | |||
| return list( | |||
| set( | |||
| [ | |||
| ( | |||
| d["parent_id"] if d["parent_id"] else d["kb_id"], | |||
| d["location"], | |||
| ) | |||
| for d in docs | |||
| ] | |||
| ) | |||
| ) | |||
| @classmethod | |||
| @DB.connection_context() | |||
| @@ -118,28 +146,30 @@ class TaskService(CommonService): | |||
| def update_progress(cls, id, info): | |||
| if os.environ.get("MACOS"): | |||
| if info["progress_msg"]: | |||
| cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where( | |||
| cls.model.id == id).execute() | |||
| cls.model.update( | |||
| progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"] | |||
| ).where(cls.model.id == id).execute() | |||
| if "progress" in info: | |||
| cls.model.update(progress=info["progress"]).where( | |||
| cls.model.id == id).execute() | |||
| cls.model.id == id | |||
| ).execute() | |||
| return | |||
| with DB.lock("update_progress", -1): | |||
| if info["progress_msg"]: | |||
| cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where( | |||
| cls.model.id == id).execute() | |||
| cls.model.update( | |||
| progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"] | |||
| ).where(cls.model.id == id).execute() | |||
| if "progress" in info: | |||
| cls.model.update(progress=info["progress"]).where( | |||
| cls.model.id == id).execute() | |||
| cls.model.id == id | |||
| ).execute() | |||
| def queue_tasks(doc: dict, bucket: str, name: str): | |||
| def new_task(): | |||
| return { | |||
| "id": get_uuid(), | |||
| "doc_id": doc["id"] | |||
| } | |||
| return {"id": get_uuid(), "doc_id": doc["id"]} | |||
| tsks = [] | |||
| if doc["type"] == FileType.PDF.value: | |||
| @@ -150,8 +180,8 @@ def queue_tasks(doc: dict, bucket: str, name: str): | |||
| if doc["parser_id"] == "paper": | |||
| page_size = doc["parser_config"].get("task_page_size", 22) | |||
| if doc["parser_id"] in ["one", "knowledge_graph"] or not do_layout: | |||
| page_size = 10 ** 9 | |||
| page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)] | |||
| page_size = 10**9 | |||
| page_ranges = doc["parser_config"].get("pages") or [(1, 10**5)] | |||
| for s, e in page_ranges: | |||
| s -= 1 | |||
| s = max(0, s) | |||
| @@ -177,4 +207,6 @@ def queue_tasks(doc: dict, bucket: str, name: str): | |||
| DocumentService.begin2parse(doc["id"]) | |||
| for t in tsks: | |||
| assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status." | |||
| assert REDIS_CONN.queue_product( | |||
| SVR_QUEUE_NAME, message=t | |||
| ), "Can't access Redis. Please check the Redis' status." | |||
| @@ -22,7 +22,7 @@ from api.db import UserTenantRole | |||
| from api.db.db_models import DB, UserTenant | |||
| from api.db.db_models import User, Tenant | |||
| from api.db.services.common_service import CommonService | |||
| from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format | |||
| from api.utils import get_uuid, current_timestamp, datetime_format | |||
| from api.db import StatusEnum | |||
| @@ -21,10 +21,7 @@ | |||
| import logging | |||
| import os | |||
| from api.utils.log_utils import initRootLogger | |||
| LOG_LEVELS = os.environ.get("LOG_LEVELS", "") | |||
| initRootLogger("ragflow_server", LOG_LEVELS) | |||
| import os | |||
| import signal | |||
| import sys | |||
| import time | |||
| @@ -44,6 +41,9 @@ from api.versions import get_ragflow_version | |||
| from api.utils import show_configs | |||
| from rag.settings import print_rag_settings | |||
| LOG_LEVELS = os.environ.get("LOG_LEVELS", "") | |||
| initRootLogger("ragflow_server", LOG_LEVELS) | |||
| def update_progress(): | |||
| while True: | |||
| @@ -36,7 +36,6 @@ from werkzeug.http import HTTP_STATUS_CODES | |||
| from api.db.db_models import APIToken | |||
| from api import settings | |||
| from api import settings | |||
| from api.utils import CustomJSONEncoder, get_uuid | |||
| from api.utils import json_dumps | |||
| from api.constants import REQUEST_WAIT_SEC, REQUEST_MAX_WAIT_SEC | |||
| @@ -45,5 +45,5 @@ try: | |||
| pool = Pool(processes=1) | |||
| thread = pool.apply_async(download_nltk_data) | |||
| binary = thread.get(timeout=60) | |||
| except Exception as e: | |||
| except Exception: | |||
| print('\x1b[6;37;41m WARNING \x1b[0m' + "Downloading NLTK data failure.", flush=True) | |||
| @@ -18,4 +18,16 @@ from .ppt_parser import RAGFlowPptParser as PptParser | |||
| from .html_parser import RAGFlowHtmlParser as HtmlParser | |||
| from .json_parser import RAGFlowJsonParser as JsonParser | |||
| from .markdown_parser import RAGFlowMarkdownParser as MarkdownParser | |||
| from .txt_parser import RAGFlowTxtParser as TxtParser | |||
| from .txt_parser import RAGFlowTxtParser as TxtParser | |||
| __all__ = [ | |||
| "PdfParser", | |||
| "PlainParser", | |||
| "DocxParser", | |||
| "ExcelParser", | |||
| "PptParser", | |||
| "HtmlParser", | |||
| "JsonParser", | |||
| "MarkdownParser", | |||
| "TxtParser", | |||
| ] | |||
| @@ -29,7 +29,8 @@ class RAGFlowExcelParser: | |||
| for sheetname in wb.sheetnames: | |||
| ws = wb[sheetname] | |||
| rows = list(ws.rows) | |||
| if not rows: continue | |||
| if not rows: | |||
| continue | |||
| tb_rows_0 = "<tr>" | |||
| for t in list(rows[0]): | |||
| @@ -40,7 +41,9 @@ class RAGFlowExcelParser: | |||
| tb = "" | |||
| tb += f"<table><caption>{sheetname}</caption>" | |||
| tb += tb_rows_0 | |||
| for r in list(rows[1 + chunk_i * chunk_rows:1 + (chunk_i + 1) * chunk_rows]): | |||
| for r in list( | |||
| rows[1 + chunk_i * chunk_rows : 1 + (chunk_i + 1) * chunk_rows] | |||
| ): | |||
| tb += "<tr>" | |||
| for i, c in enumerate(r): | |||
| if c.value is None: | |||
| @@ -62,20 +65,21 @@ class RAGFlowExcelParser: | |||
| for sheetname in wb.sheetnames: | |||
| ws = wb[sheetname] | |||
| rows = list(ws.rows) | |||
| if not rows:continue | |||
| if not rows: | |||
| continue | |||
| ti = list(rows[0]) | |||
| for r in list(rows[1:]): | |||
| l = [] | |||
| fields = [] | |||
| for i, c in enumerate(r): | |||
| if not c.value: | |||
| continue | |||
| t = str(ti[i].value) if i < len(ti) else "" | |||
| t += (":" if t else "") + str(c.value) | |||
| l.append(t) | |||
| l = "; ".join(l) | |||
| fields.append(t) | |||
| line = "; ".join(fields) | |||
| if sheetname.lower().find("sheet") < 0: | |||
| l += " ——" + sheetname | |||
| res.append(l) | |||
| line += " ——" + sheetname | |||
| res.append(line) | |||
| return res | |||
| @staticmethod | |||
| @@ -36,7 +36,7 @@ class RAGFlowHtmlParser: | |||
| @classmethod | |||
| def parser_txt(cls, txt): | |||
| if type(txt) != str: | |||
| if not isinstance(txt, str): | |||
| raise TypeError("txt type should be str!") | |||
| html_doc = readability.Document(txt) | |||
| title = html_doc.title() | |||
| @@ -22,7 +22,7 @@ class RAGFlowJsonParser: | |||
| txt = binary.decode(encoding, errors="ignore") | |||
| json_data = json.loads(txt) | |||
| chunks = self.split_json(json_data, True) | |||
| sections = [json.dumps(l, ensure_ascii=False) for l in chunks if l] | |||
| sections = [json.dumps(line, ensure_ascii=False) for line in chunks if line] | |||
| return sections | |||
| @staticmethod | |||
| @@ -752,7 +752,7 @@ class RAGFlowPdfParser: | |||
| "x1": np.max([b["x1"] for b in bxs]), | |||
| "bottom": np.max([b["bottom"] for b in bxs]) - ht | |||
| } | |||
| louts = [l for l in self.page_layout[pn] if l["type"] == ltype] | |||
| louts = [layout for layout in self.page_layout[pn] if layout["type"] == ltype] | |||
| ii = Recognizer.find_overlapped(b, louts, naive=True) | |||
| if ii is not None: | |||
| b = louts[ii] | |||
| @@ -763,7 +763,8 @@ class RAGFlowPdfParser: | |||
| "layoutno", ""))) | |||
| left, top, right, bott = b["x0"], b["top"], b["x1"], b["bottom"] | |||
| if right < left: right = left + 1 | |||
| if right < left: | |||
| right = left + 1 | |||
| poss.append((pn + self.page_from, left, right, top, bott)) | |||
| return self.page_images[pn] \ | |||
| .crop((left * ZM, top * ZM, | |||
| @@ -845,7 +846,8 @@ class RAGFlowPdfParser: | |||
| top = bx["top"] - self.page_cum_height[pn[0] - 1] | |||
| bott = bx["bottom"] - self.page_cum_height[pn[0] - 1] | |||
| page_images_cnt = len(self.page_images) | |||
| if pn[-1] - 1 >= page_images_cnt: return "" | |||
| if pn[-1] - 1 >= page_images_cnt: | |||
| return "" | |||
| while bott * ZM > self.page_images[pn[-1] - 1].size[1]: | |||
| bott -= self.page_images[pn[-1] - 1].size[1] / ZM | |||
| pn.append(pn[-1] + 1) | |||
| @@ -889,7 +891,6 @@ class RAGFlowPdfParser: | |||
| nonlocal mh, pw, lines, widths | |||
| lines.append(line) | |||
| widths.append(width(line)) | |||
| width_mean = np.mean(widths) | |||
| mmj = self.proj_match( | |||
| line["text"]) or line.get( | |||
| "layout_type", | |||
| @@ -994,7 +995,7 @@ class RAGFlowPdfParser: | |||
| else: | |||
| self.is_english = False | |||
| st = timer() | |||
| # st = timer() | |||
| for i, img in enumerate(self.page_images_x2): | |||
| chars = self.page_chars[i] if not self.is_english else [] | |||
| self.mean_height.append( | |||
| @@ -1028,8 +1029,8 @@ class RAGFlowPdfParser: | |||
| self.page_cum_height = np.cumsum(self.page_cum_height) | |||
| assert len(self.page_cum_height) == len(self.page_images) + 1 | |||
| if len(self.boxes) == 0 and zoomin < 9: self.__images__(fnm, zoomin * 3, page_from, | |||
| page_to, callback) | |||
| if len(self.boxes) == 0 and zoomin < 9: | |||
| self.__images__(fnm, zoomin * 3, page_from, page_to, callback) | |||
| def __call__(self, fnm, need_image=True, zoomin=3, return_html=False): | |||
| self.__images__(fnm, zoomin) | |||
| @@ -1168,7 +1169,7 @@ class PlainParser(object): | |||
| if not self.outlines: | |||
| logging.warning("Miss outlines") | |||
| return [(l, "") for l in lines], [] | |||
| return [(line, "") for line in lines], [] | |||
| def crop(self, ck, need_position): | |||
| raise NotImplementedError | |||
| @@ -15,21 +15,42 @@ import datetime | |||
| def refactor(cv): | |||
| for n in ["raw_txt", "parser_name", "inference", "ori_text", "use_time", "time_stat"]: | |||
| if n in cv and cv[n] is not None: del cv[n] | |||
| for n in [ | |||
| "raw_txt", | |||
| "parser_name", | |||
| "inference", | |||
| "ori_text", | |||
| "use_time", | |||
| "time_stat", | |||
| ]: | |||
| if n in cv and cv[n] is not None: | |||
| del cv[n] | |||
| cv["is_deleted"] = 0 | |||
| if "basic" not in cv: cv["basic"] = {} | |||
| if cv["basic"].get("photo2"): del cv["basic"]["photo2"] | |||
| if "basic" not in cv: | |||
| cv["basic"] = {} | |||
| if cv["basic"].get("photo2"): | |||
| del cv["basic"]["photo2"] | |||
| for n in ["education", "work", "certificate", "project", "language", "skill", "training"]: | |||
| if n not in cv or cv[n] is None: continue | |||
| if type(cv[n]) == type({}): cv[n] = [v for _, v in cv[n].items()] | |||
| if type(cv[n]) != type([]): | |||
| for n in [ | |||
| "education", | |||
| "work", | |||
| "certificate", | |||
| "project", | |||
| "language", | |||
| "skill", | |||
| "training", | |||
| ]: | |||
| if n not in cv or cv[n] is None: | |||
| continue | |||
| if isinstance(cv[n], dict): | |||
| cv[n] = [v for _, v in cv[n].items()] | |||
| if not isinstance(cv[n], list): | |||
| del cv[n] | |||
| continue | |||
| vv = [] | |||
| for v in cv[n]: | |||
| if "external" in v and v["external"] is not None: del v["external"] | |||
| if "external" in v and v["external"] is not None: | |||
| del v["external"] | |||
| vv.append(v) | |||
| cv[n] = {str(i): vv[i] for i in range(len(vv))} | |||
| @@ -42,24 +63,44 @@ def refactor(cv): | |||
| cv["basic"][t] = cv["basic"][n] | |||
| del cv["basic"][n] | |||
| work = sorted([v for _, v in cv.get("work", {}).items()], key=lambda x: x.get("start_time", "")) | |||
| edu = sorted([v for _, v in cv.get("education", {}).items()], key=lambda x: x.get("start_time", "")) | |||
| work = sorted( | |||
| [v for _, v in cv.get("work", {}).items()], | |||
| key=lambda x: x.get("start_time", ""), | |||
| ) | |||
| edu = sorted( | |||
| [v for _, v in cv.get("education", {}).items()], | |||
| key=lambda x: x.get("start_time", ""), | |||
| ) | |||
| if work: | |||
| cv["basic"]["work_start_time"] = work[0].get("start_time", "") | |||
| cv["basic"]["management_experience"] = 'Y' if any( | |||
| [w.get("management_experience", '') == 'Y' for w in work]) else 'N' | |||
| cv["basic"]["management_experience"] = ( | |||
| "Y" | |||
| if any([w.get("management_experience", "") == "Y" for w in work]) | |||
| else "N" | |||
| ) | |||
| cv["basic"]["annual_salary"] = work[-1].get("annual_salary_from", "0") | |||
| for n in ["annual_salary_from", "annual_salary_to", "industry_name", "position_name", "responsibilities", | |||
| "corporation_type", "scale", "corporation_name"]: | |||
| for n in [ | |||
| "annual_salary_from", | |||
| "annual_salary_to", | |||
| "industry_name", | |||
| "position_name", | |||
| "responsibilities", | |||
| "corporation_type", | |||
| "scale", | |||
| "corporation_name", | |||
| ]: | |||
| cv["basic"][n] = work[-1].get(n, "") | |||
| if edu: | |||
| for n in ["school_name", "discipline_name"]: | |||
| if n in edu[-1]: cv["basic"][n] = edu[-1][n] | |||
| if n in edu[-1]: | |||
| cv["basic"][n] = edu[-1][n] | |||
| cv["basic"]["updated_at"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") | |||
| if "contact" not in cv: cv["contact"] = {} | |||
| if not cv["contact"].get("name"): cv["contact"]["name"] = cv["basic"].get("name", "") | |||
| return cv | |||
| if "contact" not in cv: | |||
| cv["contact"] = {} | |||
| if not cv["contact"].get("name"): | |||
| cv["contact"]["name"] = cv["basic"].get("name", "") | |||
| return cv | |||
| @@ -21,13 +21,18 @@ from . import regions | |||
| current_file_path = os.path.dirname(os.path.abspath(__file__)) | |||
| GOODS = pd.read_csv(os.path.join(current_file_path, "res/corp_baike_len.csv"), sep="\t", header=0).fillna(0) | |||
| GOODS = pd.read_csv( | |||
| os.path.join(current_file_path, "res/corp_baike_len.csv"), sep="\t", header=0 | |||
| ).fillna(0) | |||
| GOODS["cid"] = GOODS["cid"].astype(str) | |||
| GOODS = GOODS.set_index(["cid"]) | |||
| CORP_TKS = json.load(open(os.path.join(current_file_path, "res/corp.tks.freq.json"), "r")) | |||
| CORP_TKS = json.load( | |||
| open(os.path.join(current_file_path, "res/corp.tks.freq.json"), "r") | |||
| ) | |||
| GOOD_CORP = json.load(open(os.path.join(current_file_path, "res/good_corp.json"), "r")) | |||
| CORP_TAG = json.load(open(os.path.join(current_file_path, "res/corp_tag.json"), "r")) | |||
| def baike(cid, default_v=0): | |||
| global GOODS | |||
| try: | |||
| @@ -39,27 +44,41 @@ def baike(cid, default_v=0): | |||
| def corpNorm(nm, add_region=True): | |||
| global CORP_TKS | |||
| if not nm or type(nm)!=type(""):return "" | |||
| if not nm or isinstance(nm, str): | |||
| return "" | |||
| nm = rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(nm)).lower() | |||
| nm = re.sub(r"&", "&", nm) | |||
| nm = re.sub(r"[\(\)()\+'\"\t \*\\【】-]+", " ", nm) | |||
| nm = re.sub(r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, 10000, re.IGNORECASE) | |||
| nm = re.sub(r"(计算机|技术|(技术|科技|网络)*有限公司|公司|有限|研发中心|中国|总部)$", "", nm, 10000, re.IGNORECASE) | |||
| if not nm or (len(nm)<5 and not regions.isName(nm[0:2])):return nm | |||
| nm = re.sub( | |||
| r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, 10000, re.IGNORECASE | |||
| ) | |||
| nm = re.sub( | |||
| r"(计算机|技术|(技术|科技|网络)*有限公司|公司|有限|研发中心|中国|总部)$", | |||
| "", | |||
| nm, | |||
| 10000, | |||
| re.IGNORECASE, | |||
| ) | |||
| if not nm or (len(nm) < 5 and not regions.isName(nm[0:2])): | |||
| return nm | |||
| tks = rag_tokenizer.tokenize(nm).split() | |||
| reg = [t for i,t in enumerate(tks) if regions.isName(t) and (t != "中国" or i > 0)] | |||
| reg = [t for i, t in enumerate(tks) if regions.isName(t) and (t != "中国" or i > 0)] | |||
| nm = "" | |||
| for t in tks: | |||
| if regions.isName(t) or t in CORP_TKS:continue | |||
| if re.match(r"[0-9a-zA-Z\\,.]+", t) and re.match(r".*[0-9a-zA-Z\,.]+$", nm):nm += " " | |||
| if regions.isName(t) or t in CORP_TKS: | |||
| continue | |||
| if re.match(r"[0-9a-zA-Z\\,.]+", t) and re.match(r".*[0-9a-zA-Z\,.]+$", nm): | |||
| nm += " " | |||
| nm += t | |||
| r = re.search(r"^([^a-z0-9 \(\)&]{2,})[a-z ]{4,}$", nm.strip()) | |||
| if r:nm = r.group(1) | |||
| if r: | |||
| nm = r.group(1) | |||
| r = re.search(r"^([a-z ]{3,})[^a-z0-9 \(\)&]{2,}$", nm.strip()) | |||
| if r:nm = r.group(1) | |||
| return nm.strip() + (("" if not reg else "(%s)"%reg[0]) if add_region else "") | |||
| if r: | |||
| nm = r.group(1) | |||
| return nm.strip() + (("" if not reg else "(%s)" % reg[0]) if add_region else "") | |||
| def rmNoise(n): | |||
| @@ -67,33 +86,40 @@ def rmNoise(n): | |||
| n = re.sub(r"[,. &()()]+", "", n) | |||
| return n | |||
| GOOD_CORP = set([corpNorm(rmNoise(c), False) for c in GOOD_CORP]) | |||
| for c,v in CORP_TAG.items(): | |||
| for c, v in CORP_TAG.items(): | |||
| cc = corpNorm(rmNoise(c), False) | |||
| if not cc: | |||
| logging.debug(c) | |||
| CORP_TAG = {corpNorm(rmNoise(c), False):v for c,v in CORP_TAG.items()} | |||
| CORP_TAG = {corpNorm(rmNoise(c), False): v for c, v in CORP_TAG.items()} | |||
| def is_good(nm): | |||
| global GOOD_CORP | |||
| if nm.find("外派")>=0:return False | |||
| if nm.find("外派") >= 0: | |||
| return False | |||
| nm = rmNoise(nm) | |||
| nm = corpNorm(nm, False) | |||
| for n in GOOD_CORP: | |||
| if re.match(r"[0-9a-zA-Z]+$", n): | |||
| if n == nm: return True | |||
| elif nm.find(n)>=0:return True | |||
| if n == nm: | |||
| return True | |||
| elif nm.find(n) >= 0: | |||
| return True | |||
| return False | |||
| def corp_tag(nm): | |||
| global CORP_TAG | |||
| nm = rmNoise(nm) | |||
| nm = corpNorm(nm, False) | |||
| for n in CORP_TAG.keys(): | |||
| if re.match(r"[0-9a-zA-Z., ]+$", n): | |||
| if n == nm: return CORP_TAG[n] | |||
| elif nm.find(n)>=0: | |||
| if len(n)<3 and len(nm)/len(n)>=2:continue | |||
| if n == nm: | |||
| return CORP_TAG[n] | |||
| elif nm.find(n) >= 0: | |||
| if len(n) < 3 and len(nm) / len(n) >= 2: | |||
| continue | |||
| return CORP_TAG[n] | |||
| return [] | |||
| @@ -11,27 +11,31 @@ | |||
| # limitations under the License. | |||
| # | |||
| TBL = {"94":"EMBA", | |||
| "6":"MBA", | |||
| "95":"MPA", | |||
| "92":"专升本", | |||
| "4":"专科", | |||
| "90":"中专", | |||
| "91":"中技", | |||
| "86":"初中", | |||
| "3":"博士", | |||
| "10":"博士后", | |||
| "1":"本科", | |||
| "2":"硕士", | |||
| "87":"职高", | |||
| "89":"高中" | |||
| TBL = { | |||
| "94": "EMBA", | |||
| "6": "MBA", | |||
| "95": "MPA", | |||
| "92": "专升本", | |||
| "4": "专科", | |||
| "90": "中专", | |||
| "91": "中技", | |||
| "86": "初中", | |||
| "3": "博士", | |||
| "10": "博士后", | |||
| "1": "本科", | |||
| "2": "硕士", | |||
| "87": "职高", | |||
| "89": "高中", | |||
| } | |||
| TBL_ = {v:k for k,v in TBL.items()} | |||
| TBL_ = {v: k for k, v in TBL.items()} | |||
| def get_name(id): | |||
| return TBL.get(str(id), "") | |||
| def get_id(nm): | |||
| if not nm:return "" | |||
| if not nm: | |||
| return "" | |||
| return TBL_.get(nm.upper().strip(), "") | |||
| @@ -16,8 +16,11 @@ import json | |||
| import re | |||
| import copy | |||
| import pandas as pd | |||
| current_file_path = os.path.dirname(os.path.abspath(__file__)) | |||
| TBL = pd.read_csv(os.path.join(current_file_path, "res/schools.csv"), sep="\t", header=0).fillna("") | |||
| TBL = pd.read_csv( | |||
| os.path.join(current_file_path, "res/schools.csv"), sep="\t", header=0 | |||
| ).fillna("") | |||
| TBL["name_en"] = TBL["name_en"].map(lambda x: x.lower().strip()) | |||
| GOOD_SCH = json.load(open(os.path.join(current_file_path, "res/good_sch.json"), "r")) | |||
| GOOD_SCH = set([re.sub(r"[,. &()()]+", "", c) for c in GOOD_SCH]) | |||
| @@ -26,14 +29,15 @@ GOOD_SCH = set([re.sub(r"[,. &()()]+", "", c) for c in GOOD_SCH]) | |||
| def loadRank(fnm): | |||
| global TBL | |||
| TBL["rank"] = 1000000 | |||
| with open(fnm, "r", encoding='utf-8') as f: | |||
| with open(fnm, "r", encoding="utf-8") as f: | |||
| while True: | |||
| l = f.readline() | |||
| if not l:break | |||
| l = l.strip("\n").split(",") | |||
| line = f.readline() | |||
| if not line: | |||
| break | |||
| line = line.strip("\n").split(",") | |||
| try: | |||
| nm,rk = l[0].strip(),int(l[1]) | |||
| #assert len(TBL[((TBL.name_cn == nm) | (TBL.name_en == nm))]),f"<{nm}>" | |||
| nm, rk = line[0].strip(), int(line[1]) | |||
| # assert len(TBL[((TBL.name_cn == nm) | (TBL.name_en == nm))]),f"<{nm}>" | |||
| TBL.loc[((TBL.name_cn == nm) | (TBL.name_en == nm)), "rank"] = rk | |||
| except Exception: | |||
| pass | |||
| @@ -44,27 +48,35 @@ loadRank(os.path.join(current_file_path, "res/school.rank.csv")) | |||
| def split(txt): | |||
| tks = [] | |||
| for t in re.sub(r"[ \t]+", " ",txt).split(): | |||
| if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \ | |||
| re.match(r"[a-zA-Z]", t) and tks: | |||
| for t in re.sub(r"[ \t]+", " ", txt).split(): | |||
| if ( | |||
| tks | |||
| and re.match(r".*[a-zA-Z]$", tks[-1]) | |||
| and re.match(r"[a-zA-Z]", t) | |||
| and tks | |||
| ): | |||
| tks[-1] = tks[-1] + " " + t | |||
| else:tks.append(t) | |||
| else: | |||
| tks.append(t) | |||
| return tks | |||
| def select(nm): | |||
| global TBL | |||
| if not nm:return | |||
| if isinstance(nm, list):nm = str(nm[0]) | |||
| if not nm: | |||
| return | |||
| if isinstance(nm, list): | |||
| nm = str(nm[0]) | |||
| nm = split(nm)[0] | |||
| nm = str(nm).lower().strip() | |||
| nm = re.sub(r"[((][^()()]+[))]", "", nm.lower()) | |||
| nm = re.sub(r"(^the |[,.&()();;·]+|^(英国|美国|瑞士))", "", nm) | |||
| nm = re.sub(r"大学.*学院", "大学", nm) | |||
| tbl = copy.deepcopy(TBL) | |||
| tbl["hit_alias"] = tbl["alias"].map(lambda x:nm in set(x.split("+"))) | |||
| res = tbl[((tbl.name_cn == nm) | (tbl.name_en == nm) | (tbl.hit_alias == True))] | |||
| if res.empty:return | |||
| tbl["hit_alias"] = tbl["alias"].map(lambda x: nm in set(x.split("+"))) | |||
| res = tbl[((tbl.name_cn == nm) | (tbl.name_en == nm) | tbl.hit_alias)] | |||
| if res.empty: | |||
| return | |||
| return json.loads(res.to_json(orient="records"))[0] | |||
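The `select` hunk also drops the `== True` comparison on the boolean `hit_alias` column (Ruff E712): a boolean Series can be used directly as a row mask. A minimal sketch with toy data:

```python
import pandas as pd

tbl = pd.DataFrame({"name_cn": ["清华大学", "某学院"], "hit_alias": [True, False]})
# tbl[tbl.hit_alias == True] is equivalent, but E712 prefers the bare mask:
print(tbl[tbl.hit_alias])  # keeps only the row(s) where hit_alias is True
```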
| @@ -74,4 +86,3 @@ def is_good(nm): | |||
| nm = re.sub(r"[((][^()()]+[))]", "", nm.lower()) | |||
| nm = re.sub(r"[''`‘’“”,. &()();;]+", "", nm) | |||
| return nm in GOOD_SCH | |||
| @@ -25,7 +25,8 @@ from xpinyin import Pinyin | |||
| from contextlib import contextmanager | |||
| class TimeoutException(Exception): pass | |||
| class TimeoutException(Exception): | |||
| pass | |||
| @contextmanager | |||
| @@ -50,8 +51,10 @@ def rmHtmlTag(line): | |||
| def highest_degree(dg): | |||
| if not dg: return "" | |||
| if type(dg) == type(""): dg = [dg] | |||
| if not dg: | |||
| return "" | |||
| if isinstance(dg, str): | |||
| dg = [dg] | |||
| m = {"初中": 0, "高中": 1, "中专": 2, "大专": 3, "专升本": 4, "本科": 5, "硕士": 6, "博士": 7, "博士后": 8} | |||
| return sorted([(d, m.get(d, -1)) for d in dg], key=lambda x: x[1] * -1)[0][0] | |||
| @@ -68,10 +71,12 @@ def forEdu(cv): | |||
| for ii, n in enumerate(sorted(cv["education_obj"], key=lambda x: x.get("start_time", "3"))): | |||
| e = {} | |||
| if n.get("end_time"): | |||
| if n["end_time"] > edu_end_dt: edu_end_dt = n["end_time"] | |||
| if n["end_time"] > edu_end_dt: | |||
| edu_end_dt = n["end_time"] | |||
| try: | |||
| dt = n["end_time"] | |||
| if re.match(r"[0-9]{9,}", dt): dt = turnTm2Dt(dt) | |||
| if re.match(r"[0-9]{9,}", dt): | |||
| dt = turnTm2Dt(dt) | |||
| y, m, d = getYMD(dt) | |||
| ed_dt.append(str(y)) | |||
| e["end_dt_kwd"] = str(y) | |||
| @@ -80,7 +85,8 @@ def forEdu(cv): | |||
| if n.get("start_time"): | |||
| try: | |||
| dt = n["start_time"] | |||
| if re.match(r"[0-9]{9,}", dt): dt = turnTm2Dt(dt) | |||
| if re.match(r"[0-9]{9,}", dt): | |||
| dt = turnTm2Dt(dt) | |||
| y, m, d = getYMD(dt) | |||
| st_dt.append(str(y)) | |||
| e["start_dt_kwd"] = str(y) | |||
| @@ -89,13 +95,20 @@ def forEdu(cv): | |||
| r = schools.select(n.get("school_name", "")) | |||
| if r: | |||
| if str(r.get("type", "")) == "1": fea.append("211") | |||
| if str(r.get("type", "")) == "2": fea.append("211") | |||
| if str(r.get("is_abroad", "")) == "1": fea.append("留学") | |||
| if str(r.get("is_double_first", "")) == "1": fea.append("双一流") | |||
| if str(r.get("is_985", "")) == "1": fea.append("985") | |||
| if str(r.get("is_world_known", "")) == "1": fea.append("海外知名") | |||
| if r.get("rank") and cv["school_rank_int"] > r["rank"]: cv["school_rank_int"] = r["rank"] | |||
| if str(r.get("type", "")) == "1": | |||
| fea.append("211") | |||
| if str(r.get("type", "")) == "2": | |||
| fea.append("211") | |||
| if str(r.get("is_abroad", "")) == "1": | |||
| fea.append("留学") | |||
| if str(r.get("is_double_first", "")) == "1": | |||
| fea.append("双一流") | |||
| if str(r.get("is_985", "")) == "1": | |||
| fea.append("985") | |||
| if str(r.get("is_world_known", "")) == "1": | |||
| fea.append("海外知名") | |||
| if r.get("rank") and cv["school_rank_int"] > r["rank"]: | |||
| cv["school_rank_int"] = r["rank"] | |||
| if n.get("school_name") and isinstance(n["school_name"], str): | |||
| sch.append(re.sub(r"(211|985|重点大学|[,&;;-])", "", n["school_name"])) | |||
| @@ -106,22 +119,25 @@ def forEdu(cv): | |||
| maj.append(n["discipline_name"]) | |||
| e["major_kwd"] = n["discipline_name"] | |||
| if not n.get("degree") and "985" in fea and not first_fea: n["degree"] = "1" | |||
| if not n.get("degree") and "985" in fea and not first_fea: | |||
| n["degree"] = "1" | |||
| if n.get("degree"): | |||
| d = degrees.get_name(n["degree"]) | |||
| if d: e["degree_kwd"] = d | |||
| if d == "本科" and ("专科" in deg or "专升本" in deg or "中专" in deg or "大专" in deg or re.search(r"(成人|自考|自学考试)", | |||
| n.get( | |||
| "school_name", | |||
| ""))): d = "专升本" | |||
| if d: deg.append(d) | |||
| if d: | |||
| e["degree_kwd"] = d | |||
| if d == "本科" and ("专科" in deg or "专升本" in deg or "中专" in deg or "大专" in deg or re.search(r"(成人|自考|自学考试)", n.get("school_name",""))): | |||
| d = "专升本" | |||
| if d: | |||
| deg.append(d) | |||
| # for first degree | |||
| if not fdeg and d in ["中专", "专升本", "专科", "本科", "大专"]: | |||
| fdeg = [d] | |||
| if n.get("school_name"): fsch = [n["school_name"]] | |||
| if n.get("discipline_name"): fmaj = [n["discipline_name"]] | |||
| if n.get("school_name"): | |||
| fsch = [n["school_name"]] | |||
| if n.get("discipline_name"): | |||
| fmaj = [n["discipline_name"]] | |||
| first_fea = copy.deepcopy(fea) | |||
| edu_nst.append(e) | |||
| @@ -140,16 +156,26 @@ def forEdu(cv): | |||
| else: | |||
| cv["sch_rank_kwd"].append("一般学校") | |||
| if edu_nst: cv["edu_nst"] = edu_nst | |||
| if fea: cv["edu_fea_kwd"] = list(set(fea)) | |||
| if first_fea: cv["edu_first_fea_kwd"] = list(set(first_fea)) | |||
| if maj: cv["major_kwd"] = maj | |||
| if fsch: cv["first_school_name_kwd"] = fsch | |||
| if fdeg: cv["first_degree_kwd"] = fdeg | |||
| if fmaj: cv["first_major_kwd"] = fmaj | |||
| if st_dt: cv["edu_start_kwd"] = st_dt | |||
| if ed_dt: cv["edu_end_kwd"] = ed_dt | |||
| if ed_dt: cv["edu_end_int"] = max([int(t) for t in ed_dt]) | |||
| if edu_nst: | |||
| cv["edu_nst"] = edu_nst | |||
| if fea: | |||
| cv["edu_fea_kwd"] = list(set(fea)) | |||
| if first_fea: | |||
| cv["edu_first_fea_kwd"] = list(set(first_fea)) | |||
| if maj: | |||
| cv["major_kwd"] = maj | |||
| if fsch: | |||
| cv["first_school_name_kwd"] = fsch | |||
| if fdeg: | |||
| cv["first_degree_kwd"] = fdeg | |||
| if fmaj: | |||
| cv["first_major_kwd"] = fmaj | |||
| if st_dt: | |||
| cv["edu_start_kwd"] = st_dt | |||
| if ed_dt: | |||
| cv["edu_end_kwd"] = ed_dt | |||
| if ed_dt: | |||
| cv["edu_end_int"] = max([int(t) for t in ed_dt]) | |||
| if deg: | |||
| if "本科" in deg and "专科" in deg: | |||
| deg.append("专升本") | |||
| @@ -158,8 +184,10 @@ def forEdu(cv): | |||
| cv["highest_degree_kwd"] = highest_degree(deg) | |||
| if edu_end_dt: | |||
| try: | |||
| if re.match(r"[0-9]{9,}", edu_end_dt): edu_end_dt = turnTm2Dt(edu_end_dt) | |||
| if edu_end_dt.strip("\n") == "至今": edu_end_dt = cv.get("updated_at_dt", str(datetime.date.today())) | |||
| if re.match(r"[0-9]{9,}", edu_end_dt): | |||
| edu_end_dt = turnTm2Dt(edu_end_dt) | |||
| if edu_end_dt.strip("\n") == "至今": | |||
| edu_end_dt = cv.get("updated_at_dt", str(datetime.date.today())) | |||
| y, m, d = getYMD(edu_end_dt) | |||
| cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000)) | |||
| except Exception as e: | |||
| @@ -171,7 +199,8 @@ def forEdu(cv): | |||
| or not cv.get("degree_kwd"): | |||
| for c in sch: | |||
| if schools.is_good(c): | |||
| if "tag_kwd" not in cv: cv["tag_kwd"] = [] | |||
| if "tag_kwd" not in cv: | |||
| cv["tag_kwd"] = [] | |||
| cv["tag_kwd"].append("好学校") | |||
| cv["tag_kwd"].append("好学历") | |||
| break | |||
| @@ -180,28 +209,39 @@ def forEdu(cv): | |||
| any([d.lower() in ["硕士", "博士", "mba", "博士"] for d in cv.get("degree_kwd", [])])) \ | |||
| or all([d.lower() in ["硕士", "博士", "mba", "博士后"] for d in cv.get("degree_kwd", [])]) \ | |||
| or any([d in ["mba", "emba", "博士后"] for d in cv.get("degree_kwd", [])]): | |||
| if "tag_kwd" not in cv: cv["tag_kwd"] = [] | |||
| if "好学历" not in cv["tag_kwd"]: cv["tag_kwd"].append("好学历") | |||
| if cv.get("major_kwd"): cv["major_tks"] = rag_tokenizer.tokenize(" ".join(maj)) | |||
| if cv.get("school_name_kwd"): cv["school_name_tks"] = rag_tokenizer.tokenize(" ".join(sch)) | |||
| if cv.get("first_school_name_kwd"): cv["first_school_name_tks"] = rag_tokenizer.tokenize(" ".join(fsch)) | |||
| if cv.get("first_major_kwd"): cv["first_major_tks"] = rag_tokenizer.tokenize(" ".join(fmaj)) | |||
| if "tag_kwd" not in cv: | |||
| cv["tag_kwd"] = [] | |||
| if "好学历" not in cv["tag_kwd"]: | |||
| cv["tag_kwd"].append("好学历") | |||
| if cv.get("major_kwd"): | |||
| cv["major_tks"] = rag_tokenizer.tokenize(" ".join(maj)) | |||
| if cv.get("school_name_kwd"): | |||
| cv["school_name_tks"] = rag_tokenizer.tokenize(" ".join(sch)) | |||
| if cv.get("first_school_name_kwd"): | |||
| cv["first_school_name_tks"] = rag_tokenizer.tokenize(" ".join(fsch)) | |||
| if cv.get("first_major_kwd"): | |||
| cv["first_major_tks"] = rag_tokenizer.tokenize(" ".join(fmaj)) | |||
| return cv | |||
| def forProj(cv): | |||
| if not cv.get("project_obj"): return cv | |||
| if not cv.get("project_obj"): | |||
| return cv | |||
| pro_nms, desc = [], [] | |||
| for i, n in enumerate( | |||
| sorted(cv.get("project_obj", []), key=lambda x: str(x.get("updated_at", "")) if type(x) == type({}) else "", | |||
| sorted(cv.get("project_obj", []), key=lambda x: str(x.get("updated_at", "")) if isinstance(x, dict) else "", | |||
| reverse=True)): | |||
| if n.get("name"): pro_nms.append(n["name"]) | |||
| if n.get("describe"): desc.append(str(n["describe"])) | |||
| if n.get("responsibilities"): desc.append(str(n["responsibilities"])) | |||
| if n.get("achivement"): desc.append(str(n["achivement"])) | |||
| if n.get("name"): | |||
| pro_nms.append(n["name"]) | |||
| if n.get("describe"): | |||
| desc.append(str(n["describe"])) | |||
| if n.get("responsibilities"): | |||
| desc.append(str(n["responsibilities"])) | |||
| if n.get("achivement"): | |||
| desc.append(str(n["achivement"])) | |||
| if pro_nms: | |||
| # cv["pro_nms_tks"] = rag_tokenizer.tokenize(" ".join(pro_nms)) | |||
| @@ -233,15 +273,16 @@ def forWork(cv): | |||
| work_st_tm = "" | |||
| corp_tags = [] | |||
| for i, n in enumerate( | |||
| sorted(cv.get("work_obj", []), key=lambda x: str(x.get("start_time", "")) if type(x) == type({}) else "", | |||
| sorted(cv.get("work_obj", []), key=lambda x: str(x.get("start_time", "")) if isinstance(x, dict) else "", | |||
| reverse=True)): | |||
| if type(n) == type(""): | |||
| if isinstance(n, str): | |||
| try: | |||
| n = json_loads(n) | |||
| except Exception: | |||
| continue | |||
| if n.get("start_time") and (not work_st_tm or n["start_time"] < work_st_tm): work_st_tm = n["start_time"] | |||
| if n.get("start_time") and (not work_st_tm or n["start_time"] < work_st_tm): | |||
| work_st_tm = n["start_time"] | |||
| for c in flds: | |||
| if not n.get(c) or str(n[c]) == '0': | |||
| fea[c].append("") | |||
| @@ -262,14 +303,18 @@ def forWork(cv): | |||
| fea[c].append(rmHtmlTag(str(n[c]).lower())) | |||
| y, m, d = getYMD(n.get("start_time")) | |||
| if not y or not m: continue | |||
| if not y or not m: | |||
| continue | |||
| st = "%s-%02d-%02d" % (y, int(m), int(d)) | |||
| latest_job_tm = st | |||
| y, m, d = getYMD(n.get("end_time")) | |||
| if (not y or not m) and i > 0: continue | |||
| if not y or not m or int(y) > 2022: y, m, d = getYMD(str(n.get("updated_at", ""))) | |||
| if not y or not m: continue | |||
| if (not y or not m) and i > 0: | |||
| continue | |||
| if not y or not m or int(y) > 2022: | |||
| y, m, d = getYMD(str(n.get("updated_at", ""))) | |||
| if not y or not m: | |||
| continue | |||
| ed = "%s-%02d-%02d" % (y, int(m), int(d)) | |||
| try: | |||
| @@ -279,22 +324,28 @@ def forWork(cv): | |||
| if n.get("scale"): | |||
| r = re.search(r"^([0-9]+)", str(n["scale"])) | |||
| if r: scales.append(int(r.group(1))) | |||
| if r: | |||
| scales.append(int(r.group(1))) | |||
| if goodcorp: | |||
| if "tag_kwd" not in cv: cv["tag_kwd"] = [] | |||
| if "tag_kwd" not in cv: | |||
| cv["tag_kwd"] = [] | |||
| cv["tag_kwd"].append("好公司") | |||
| if goodcorp_: | |||
| if "tag_kwd" not in cv: cv["tag_kwd"] = [] | |||
| if "tag_kwd" not in cv: | |||
| cv["tag_kwd"] = [] | |||
| cv["tag_kwd"].append("好公司(曾)") | |||
| if corp_tags: | |||
| if "tag_kwd" not in cv: cv["tag_kwd"] = [] | |||
| if "tag_kwd" not in cv: | |||
| cv["tag_kwd"] = [] | |||
| cv["tag_kwd"].extend(corp_tags) | |||
| cv["corp_tag_kwd"] = [c for c in corp_tags if re.match(r"(综合|行业)", c)] | |||
| if latest_job_tm: cv["latest_job_dt"] = latest_job_tm | |||
| if fea["corporation_id"]: cv["corporation_id"] = fea["corporation_id"] | |||
| if latest_job_tm: | |||
| cv["latest_job_dt"] = latest_job_tm | |||
| if fea["corporation_id"]: | |||
| cv["corporation_id"] = fea["corporation_id"] | |||
| if fea["position_name"]: | |||
| cv["position_name_tks"] = rag_tokenizer.tokenize(fea["position_name"][0]) | |||
| @@ -317,18 +368,23 @@ def forWork(cv): | |||
| cv["responsibilities_ltks"] = rag_tokenizer.tokenize(fea["responsibilities"][0]) | |||
| cv["resp_ltks"] = rag_tokenizer.tokenize(" ".join(fea["responsibilities"][1:])) | |||
| if fea["subordinates_count"]: fea["subordinates_count"] = [int(i) for i in fea["subordinates_count"] if | |||
| if fea["subordinates_count"]: | |||
| fea["subordinates_count"] = [int(i) for i in fea["subordinates_count"] if | |||
| re.match(r"[^0-9]+$", str(i))] | |||
| if fea["subordinates_count"]: cv["max_sub_cnt_int"] = np.max(fea["subordinates_count"]) | |||
| if fea["subordinates_count"]: | |||
| cv["max_sub_cnt_int"] = np.max(fea["subordinates_count"]) | |||
| if type(cv.get("corporation_id")) == type(1): cv["corporation_id"] = [str(cv["corporation_id"])] | |||
| if not cv.get("corporation_id"): cv["corporation_id"] = [] | |||
| if isinstance(cv.get("corporation_id"), int): | |||
| cv["corporation_id"] = [str(cv["corporation_id"])] | |||
| if not cv.get("corporation_id"): | |||
| cv["corporation_id"] = [] | |||
| for i in cv.get("corporation_id", []): | |||
| cv["baike_flt"] = max(corporations.baike(i), cv["baike_flt"] if "baike_flt" in cv else 0) | |||
| if work_st_tm: | |||
| try: | |||
| if re.match(r"[0-9]{9,}", work_st_tm): work_st_tm = turnTm2Dt(work_st_tm) | |||
| if re.match(r"[0-9]{9,}", work_st_tm): | |||
| work_st_tm = turnTm2Dt(work_st_tm) | |||
| y, m, d = getYMD(work_st_tm) | |||
| cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000)) | |||
| except Exception as e: | |||
| @@ -339,28 +395,37 @@ def forWork(cv): | |||
| cv["dua_flt"] = np.mean(duas) | |||
| cv["cur_dua_int"] = duas[0] | |||
| cv["job_num_int"] = len(duas) | |||
| if scales: cv["scale_flt"] = np.max(scales) | |||
| if scales: | |||
| cv["scale_flt"] = np.max(scales) | |||
| return cv | |||
| def turnTm2Dt(b): | |||
| if not b: return | |||
| if not b: | |||
| return | |||
| b = str(b).strip() | |||
| if re.match(r"[0-9]{10,}", b): b = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(b[:10]))) | |||
| if re.match(r"[0-9]{10,}", b): | |||
| b = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(b[:10]))) | |||
| return b | |||
| def getYMD(b): | |||
| y, m, d = "", "", "01" | |||
| if not b: return (y, m, d) | |||
| if not b: | |||
| return (y, m, d) | |||
| b = turnTm2Dt(b) | |||
| if re.match(r"[0-9]{4}", b): y = int(b[:4]) | |||
| if re.match(r"[0-9]{4}", b): | |||
| y = int(b[:4]) | |||
| r = re.search(r"[0-9]{4}.?([0-9]{1,2})", b) | |||
| if r: m = r.group(1) | |||
| if r: | |||
| m = r.group(1) | |||
| r = re.search(r"[0-9]{4}.?[0-9]{,2}.?([0-9]{1,2})", b) | |||
| if r: d = r.group(1) | |||
| if not d or int(d) == 0 or int(d) > 31: d = "1" | |||
| if not m or int(m) > 12 or int(m) < 1: m = "1" | |||
| if r: | |||
| d = r.group(1) | |||
| if not d or int(d) == 0 or int(d) > 31: | |||
| d = "1" | |||
| if not m or int(m) > 12 or int(m) < 1: | |||
| m = "1" | |||
| return (y, m, d) | |||
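Several hunks in this file replace `type(x) == type({})` style comparisons with `isinstance` checks (Ruff E721, type-comparison); `isinstance` also accepts subclasses, which the `==` form does not. A minimal sketch of the idiom:

```python
# Before (E721): if type(value) == type({}): ...
# After — the form used in these hunks:
def describe(value):
    if isinstance(value, dict):
        return f"dict with {len(value)} key(s)"
    if isinstance(value, str):
        return f"str of length {len(value)}"
    return type(value).__name__

print(describe({"a": 1}))  # -> "dict with 1 key(s)"
print(describe("hello"))   # -> "str of length 5"
```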
| @@ -369,7 +434,8 @@ def birth(cv): | |||
| cv["integerity_flt"] *= 0.9 | |||
| return cv | |||
| y, m, d = getYMD(cv["birth"]) | |||
| if not m or not y: return cv | |||
| if not m or not y: | |||
| return cv | |||
| b = "%s-%02d-%02d" % (y, int(m), int(d)) | |||
| cv["birth_dt"] = b | |||
| cv["birthday_kwd"] = "%02d%02d" % (int(m), int(d)) | |||
| @@ -380,7 +446,8 @@ def birth(cv): | |||
| def parse(cv): | |||
| for k in cv.keys(): | |||
| if cv[k] == '\\N': cv[k] = '' | |||
| if cv[k] == '\\N': | |||
| cv[k] = '' | |||
| # cv = cv.asDict() | |||
| tks_fld = ["address", "corporation_name", "discipline_name", "email", "expect_city_names", | |||
| "expect_industry_name", "expect_position_name", "industry_name", "industry_names", "name", | |||
| @@ -402,9 +469,12 @@ def parse(cv): | |||
| rmkeys = [] | |||
| for k in cv.keys(): | |||
| if cv[k] is None: rmkeys.append(k) | |||
| if (type(cv[k]) == type([]) or type(cv[k]) == type("")) and len(cv[k]) == 0: rmkeys.append(k) | |||
| for k in rmkeys: del cv[k] | |||
| if cv[k] is None: | |||
| rmkeys.append(k) | |||
| if (isinstance(cv[k], list) or isinstance(cv[k], str)) and len(cv[k]) == 0: | |||
| rmkeys.append(k) | |||
| for k in rmkeys: | |||
| del cv[k] | |||
| integerity = 0. | |||
| flds_num = 0. | |||
| @@ -414,7 +484,8 @@ def parse(cv): | |||
| flds_num += len(flds) | |||
| for f in flds: | |||
| v = str(cv.get(f, "")) | |||
| if len(v) > 0 and v != '0' and v != '[]': integerity += 1 | |||
| if len(v) > 0 and v != '0' and v != '[]': | |||
| integerity += 1 | |||
| hasValues(tks_fld) | |||
| hasValues(small_tks_fld) | |||
| @@ -433,7 +504,8 @@ def parse(cv): | |||
| (r"[ ()\(\)人/·0-9-]+", ""), | |||
| (r".*(元|规模|于|=|北京|上海|至今|中国|工资|州|shanghai|强|餐饮|融资|职).*", "")]: | |||
| cv["corporation_type"] = re.sub(p, r, cv["corporation_type"], 1000, re.IGNORECASE) | |||
| if len(cv["corporation_type"]) < 2: del cv["corporation_type"] | |||
| if len(cv["corporation_type"]) < 2: | |||
| del cv["corporation_type"] | |||
| if cv.get("political_status"): | |||
| for p, r in [ | |||
| @@ -441,9 +513,11 @@ def parse(cv): | |||
| (r".*(无党派|公民).*", "群众"), | |||
| (r".*团员.*", "团员")]: | |||
| cv["political_status"] = re.sub(p, r, cv["political_status"]) | |||
| if not re.search(r"[党团群]", cv["political_status"]): del cv["political_status"] | |||
| if not re.search(r"[党团群]", cv["political_status"]): | |||
| del cv["political_status"] | |||
| if cv.get("phone"): cv["phone"] = re.sub(r"^0*86([0-9]{11})", r"\1", re.sub(r"[^0-9]+", "", cv["phone"])) | |||
| if cv.get("phone"): | |||
| cv["phone"] = re.sub(r"^0*86([0-9]{11})", r"\1", re.sub(r"[^0-9]+", "", cv["phone"])) | |||
| keys = list(cv.keys()) | |||
| for k in keys: | |||
| @@ -454,9 +528,11 @@ def parse(cv): | |||
| cv[k] = [a for _, a in cv[k].items()] | |||
| nms = [] | |||
| for n in cv[k]: | |||
| if type(n) != type({}) or "name" not in n or not n.get("name"): continue | |||
| if not isinstance(n, dict) or "name" not in n or not n.get("name"): | |||
| continue | |||
| n["name"] = re.sub(r"((442)|\t )", "", n["name"]).strip().lower() | |||
| if not n["name"]: continue | |||
| if not n["name"]: | |||
| continue | |||
| nms.append(n["name"]) | |||
| if nms: | |||
| t = k[:-4] | |||
| @@ -469,15 +545,18 @@ def parse(cv): | |||
| # tokenize fields | |||
| if k in tks_fld: | |||
| cv[f"{k}_tks"] = rag_tokenizer.tokenize(cv[k]) | |||
| if k in small_tks_fld: cv[f"{k}_sm_tks"] = rag_tokenizer.tokenize(cv[f"{k}_tks"]) | |||
| if k in small_tks_fld: | |||
| cv[f"{k}_sm_tks"] = rag_tokenizer.tokenize(cv[f"{k}_tks"]) | |||
| # keyword fields | |||
| if k in kwd_fld: cv[f"{k}_kwd"] = [n.lower() | |||
| if k in kwd_fld: | |||
| cv[f"{k}_kwd"] = [n.lower() | |||
| for n in re.split(r"[\t,,;;. ]", | |||
| re.sub(r"([^a-zA-Z])[ ]+([^a-zA-Z ])", r"\1,\2", cv[k]) | |||
| ) if n] | |||
| if k in num_fld and cv.get(k): cv[f"{k}_int"] = cv[k] | |||
| if k in num_fld and cv.get(k): | |||
| cv[f"{k}_int"] = cv[k] | |||
| cv["email_kwd"] = cv.get("email_tks", "").replace(" ", "") | |||
| # for name field | |||
| @@ -501,10 +580,12 @@ def parse(cv): | |||
| cv["name_py_pref0_tks"] = "" | |||
| cv["name_py_pref_tks"] = "" | |||
| for py in PY.get_pinyins(nm[:20], ''): | |||
| for i in range(2, len(py) + 1): cv["name_py_pref_tks"] += " " + py[:i] | |||
| for i in range(2, len(py) + 1): | |||
| cv["name_py_pref_tks"] += " " + py[:i] | |||
| for py in PY.get_pinyins(nm[:20], ' '): | |||
| py = py.split() | |||
| for i in range(1, len(py) + 1): cv["name_py_pref0_tks"] += " " + "".join(py[:i]) | |||
| for i in range(1, len(py) + 1): | |||
| cv["name_py_pref0_tks"] += " " + "".join(py[:i]) | |||
| cv["name_kwd"] = name | |||
| cv["name_pinyin_kwd"] = PY.get_pinyins(nm[:20], ' ')[:3] | |||
| @@ -526,22 +607,30 @@ def parse(cv): | |||
| cv["updated_at_dt"] = cv["updated_at"].strftime('%Y-%m-%d %H:%M:%S') | |||
| else: | |||
| y, m, d = getYMD(str(cv.get("updated_at", ""))) | |||
| if not y: y = "2012" | |||
| if not m: m = "01" | |||
| if not d: d = "01" | |||
| if not y: | |||
| y = "2012" | |||
| if not m: | |||
| m = "01" | |||
| if not d: | |||
| d = "01" | |||
| cv["updated_at_dt"] = "%s-%02d-%02d 00:00:00" % (y, int(m), int(d)) | |||
| # long text tokenize | |||
| if cv.get("responsibilities"): cv["responsibilities_ltks"] = rag_tokenizer.tokenize(rmHtmlTag(cv["responsibilities"])) | |||
| if cv.get("responsibilities"): | |||
| cv["responsibilities_ltks"] = rag_tokenizer.tokenize(rmHtmlTag(cv["responsibilities"])) | |||
| # for yes or no field | |||
| fea = [] | |||
| for f, y, n in is_fld: | |||
| if f not in cv: continue | |||
| if cv[f] == '是': fea.append(y) | |||
| if cv[f] == '否': fea.append(n) | |||
| if f not in cv: | |||
| continue | |||
| if cv[f] == '是': | |||
| fea.append(y) | |||
| if cv[f] == '否': | |||
| fea.append(n) | |||
| if fea: cv["tag_kwd"] = fea | |||
| if fea: | |||
| cv["tag_kwd"] = fea | |||
| cv = forEdu(cv) | |||
| cv = forProj(cv) | |||
| @@ -550,9 +639,11 @@ def parse(cv): | |||
| cv["corp_proj_sch_deg_kwd"] = [c for c in cv.get("corp_tag_kwd", [])] | |||
| for i in range(len(cv["corp_proj_sch_deg_kwd"])): | |||
| for j in cv.get("sch_rank_kwd", []): cv["corp_proj_sch_deg_kwd"][i] += "+" + j | |||
| for j in cv.get("sch_rank_kwd", []): | |||
| cv["corp_proj_sch_deg_kwd"][i] += "+" + j | |||
| for i in range(len(cv["corp_proj_sch_deg_kwd"])): | |||
| if cv.get("highest_degree_kwd"): cv["corp_proj_sch_deg_kwd"][i] += "+" + cv["highest_degree_kwd"] | |||
| if cv.get("highest_degree_kwd"): | |||
| cv["corp_proj_sch_deg_kwd"][i] += "+" + cv["highest_degree_kwd"] | |||
| try: | |||
| if not cv.get("work_exp_flt") and cv.get("work_start_time"): | |||
| @@ -565,17 +656,21 @@ def parse(cv): | |||
| cv["work_exp_flt"] = int(str(datetime.date.today())[0:4]) - int(y) | |||
| except Exception as e: | |||
| logging.exception("parse {} ==> {}".format(e, cv.get("work_start_time"))) | |||
| if "work_exp_flt" not in cv and cv.get("work_experience", 0): cv["work_exp_flt"] = int(cv["work_experience"]) / 12. | |||
| if "work_exp_flt" not in cv and cv.get("work_experience", 0): | |||
| cv["work_exp_flt"] = int(cv["work_experience"]) / 12. | |||
| keys = list(cv.keys()) | |||
| for k in keys: | |||
| if not re.search(r"_(fea|tks|nst|dt|int|flt|ltks|kwd|id)$", k): del cv[k] | |||
| if not re.search(r"_(fea|tks|nst|dt|int|flt|ltks|kwd|id)$", k): | |||
| del cv[k] | |||
| for k in cv.keys(): | |||
| if not re.search("_(kwd|id)$", k) or type(cv[k]) != type([]): continue | |||
| if not re.search("_(kwd|id)$", k) or not isinstance(cv[k], list): | |||
| continue | |||
| cv[k] = list(set([re.sub("(市)$", "", str(n)) for n in cv[k] if n not in ['中国', '0']])) | |||
| keys = [k for k in cv.keys() if re.search(r"_feas*$", k)] | |||
| for k in keys: | |||
| if cv[k] <= 0: del cv[k] | |||
| if cv[k] <= 0: | |||
| del cv[k] | |||
| cv["tob_resume_id"] = str(cv["tob_resume_id"]) | |||
| cv["id"] = cv["tob_resume_id"] | |||
| @@ -592,5 +687,6 @@ def dealWithInt64(d): | |||
| if isinstance(d, list): | |||
| d = [dealWithInt64(t) for t in d] | |||
| if isinstance(d, np.integer): d = int(d) | |||
| if isinstance(d, np.integer): | |||
| d = int(d) | |||
| return d | |||
| @@ -51,6 +51,7 @@ class RAGFlowTxtParser: | |||
| dels = [d for d in dels if d] | |||
| dels = "|".join(dels) | |||
| secs = re.split(r"(%s)" % dels, txt) | |||
| for sec in secs: add_chunk(sec) | |||
| for sec in secs: | |||
| add_chunk(sec) | |||
| return [[c, ""] for c in cks] | |||
| @@ -18,7 +18,6 @@ from .recognizer import Recognizer | |||
| from .layout_recognizer import LayoutRecognizer | |||
| from .table_structure_recognizer import TableStructureRecognizer | |||
| def init_in_out(args): | |||
| from PIL import Image | |||
| import os | |||
| @@ -47,7 +46,7 @@ def init_in_out(args): | |||
| try: | |||
| images.append(Image.open(fnm)) | |||
| outputs.append(os.path.split(fnm)[-1]) | |||
| except Exception as e: | |||
| except Exception: | |||
| traceback.print_exc() | |||
| if os.path.isdir(args.inputs): | |||
| @@ -56,6 +55,16 @@ def init_in_out(args): | |||
| else: | |||
| images_and_outputs(args.inputs) | |||
| for i in range(len(outputs)): outputs[i] = os.path.join(args.output_dir, outputs[i]) | |||
| for i in range(len(outputs)): | |||
| outputs[i] = os.path.join(args.output_dir, outputs[i]) | |||
| return images, outputs | |||
| __all__ = [ | |||
| "OCR", | |||
| "Recognizer", | |||
| "LayoutRecognizer", | |||
| "TableStructureRecognizer", | |||
| "init_in_out", | |||
| ] | |||
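Listing the re-exported names in `__all__` is the usual way to keep Ruff's unused-import check (F401) quiet in a package `__init__` while still exposing the classes at package level. A minimal sketch of the pattern (module and class names are illustrative only):

```python
# pkg/__init__.py — sketch of the re-export pattern, not the actual file
from .ocr import OCR              # without __all__, Ruff may flag this as F401
from .recognizer import Recognizer

__all__ = ["OCR", "Recognizer"]   # declares the package's public API
```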
| @@ -42,7 +42,7 @@ class LayoutRecognizer(Recognizer): | |||
| get_project_base_directory(), | |||
| "rag/res/deepdoc") | |||
| super().__init__(self.labels, domain, model_dir) | |||
| except Exception as e: | |||
| except Exception: | |||
| model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", | |||
| local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), | |||
| local_dir_use_symlinks=False) | |||
| @@ -77,7 +77,7 @@ class LayoutRecognizer(Recognizer): | |||
| "page_number": pn, | |||
| } for b in lts if float(b["score"]) >= 0.8 or b["type"] not in self.garbage_layouts] | |||
| lts = self.sort_Y_firstly(lts, np.mean( | |||
| [l["bottom"] - l["top"] for l in lts]) / 2) | |||
| [lt["bottom"] - lt["top"] for lt in lts]) / 2) | |||
| lts = self.layouts_cleanup(bxs, lts) | |||
| page_layout.append(lts) | |||
| @@ -19,7 +19,9 @@ from huggingface_hub import snapshot_download | |||
| from api.utils.file_utils import get_project_base_directory | |||
| from .operators import * | |||
| import math | |||
| import numpy as np | |||
| import cv2 | |||
| import onnxruntime as ort | |||
| from .postprocess import build_post_process | |||
| @@ -484,7 +486,7 @@ class OCR(object): | |||
| "rag/res/deepdoc") | |||
| self.text_detector = TextDetector(model_dir) | |||
| self.text_recognizer = TextRecognizer(model_dir) | |||
| except Exception as e: | |||
| except Exception: | |||
| model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", | |||
| local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), | |||
| local_dir_use_symlinks=False) | |||
| @@ -232,7 +232,7 @@ class LinearResize(object): | |||
| """ | |||
| assert len(self.target_size) == 2 | |||
| assert self.target_size[0] > 0 and self.target_size[1] > 0 | |||
| im_channel = im.shape[2] | |||
| _im_channel = im.shape[2] | |||
| im_scale_y, im_scale_x = self.generate_scale(im) | |||
| im = cv2.resize( | |||
| im, | |||
| @@ -255,7 +255,7 @@ class LinearResize(object): | |||
| im_scale_y: the resize ratio of Y | |||
| """ | |||
| origin_shape = im.shape[:2] | |||
| im_c = im.shape[2] | |||
| _im_c = im.shape[2] | |||
| if self.keep_ratio: | |||
| im_size_min = np.min(origin_shape) | |||
| im_size_max = np.max(origin_shape) | |||
| @@ -581,7 +581,7 @@ class SRResize(object): | |||
| return data | |||
| images_HR = data["image_hr"] | |||
| label_strs = data["label"] | |||
| _label_strs = data["label"] | |||
| transform = ResizeNormalize((imgW, imgH)) | |||
| images_HR = transform(images_HR) | |||
| data["img_hr"] = images_HR | |||
| @@ -121,7 +121,7 @@ class DBPostProcess(object): | |||
| outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, | |||
| cv2.CHAIN_APPROX_SIMPLE) | |||
| if len(outs) == 3: | |||
| img, contours, _ = outs[0], outs[1], outs[2] | |||
| _img, contours, _ = outs[0], outs[1], outs[2] | |||
| elif len(outs) == 2: | |||
| contours, _ = outs[0], outs[1] | |||
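The renames in these vision hunks (`img` → `_img`, `im_channel` → `_im_channel`, `label_strs` → `_label_strs`) follow the common convention of prefixing intentionally unused bindings with an underscore, which linters such as Ruff treat as deliberately discarded values (the unused-variable/dummy-variable pattern). A minimal sketch:

```python
# An underscore prefix documents that a bound value is intentionally unused.
record = ("2024-01-01", "ok", 42)
_date, status, count = record  # _date is never read; the prefix says so on purpose
print(status, count)           # -> ok 42
```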
| @@ -13,15 +13,18 @@ | |||
| import logging | |||
| import os | |||
| import math | |||
| import numpy as np | |||
| import cv2 | |||
| from copy import deepcopy | |||
| import onnxruntime as ort | |||
| from huggingface_hub import snapshot_download | |||
| from api.utils.file_utils import get_project_base_directory | |||
| from .operators import * | |||
| class Recognizer(object): | |||
| def __init__(self, label_list, task_name, model_dir=None): | |||
| """ | |||
| @@ -277,7 +280,8 @@ class Recognizer(object): | |||
| return | |||
| min_dis, min_i = 1000000, None | |||
| for i,b in enumerate(boxes): | |||
| if box.get("layoutno", "0") != b.get("layoutno", "0"): continue | |||
| if box.get("layoutno", "0") != b.get("layoutno", "0"): | |||
| continue | |||
| dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2) | |||
| if dis < min_dis: | |||
| min_i = i | |||
| @@ -402,7 +406,8 @@ class Recognizer(object): | |||
| scores = np.max(boxes[:, 4:], axis=1) | |||
| boxes = boxes[scores > thr, :] | |||
| scores = scores[scores > thr] | |||
| if len(boxes) == 0: return [] | |||
| if len(boxes) == 0: | |||
| return [] | |||
| # Get the class with the highest confidence | |||
| class_ids = np.argmax(boxes[:, 4:], axis=1) | |||
| @@ -432,7 +437,8 @@ class Recognizer(object): | |||
| for i in range(len(image_list)): | |||
| if not isinstance(image_list[i], np.ndarray): | |||
| imgs.append(np.array(image_list[i])) | |||
| else: imgs.append(image_list[i]) | |||
| else: | |||
| imgs.append(image_list[i]) | |||
| batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size) | |||
| for i in range(batch_loop_cnt): | |||
| @@ -88,7 +88,8 @@ class CommunityReportsExtractor: | |||
| ("findings", list), | |||
| ("rating", float), | |||
| ("rating_explanation", str), | |||
| ]): continue | |||
| ]): | |||
| continue | |||
| response["weight"] = weight | |||
| response["entities"] = ents | |||
| except Exception as e: | |||
| @@ -100,7 +101,8 @@ class CommunityReportsExtractor: | |||
| res_str.append(self._get_text_output(response)) | |||
| res_dict.append(response) | |||
| over += 1 | |||
| if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}") | |||
| if callback: | |||
| callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}") | |||
| return CommunityReportsResult( | |||
| structured_output=res_dict, | |||
| @@ -8,6 +8,7 @@ Reference: | |||
| from typing import Any | |||
| import numpy as np | |||
| import networkx as nx | |||
| from dataclasses import dataclass | |||
| from graphrag.leiden import stable_largest_connected_component | |||
| @@ -129,9 +129,11 @@ class GraphExtractor: | |||
| source_doc_map[doc_index] = text | |||
| all_records[doc_index] = result | |||
| total_token_count += token_count | |||
| if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}") | |||
| if callback: | |||
| callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}") | |||
| except Exception as e: | |||
| if callback: callback(msg="Knowledge graph extraction error:{}".format(str(e))) | |||
| if callback: | |||
| callback(msg="Knowledge graph extraction error:{}".format(str(e))) | |||
| logging.exception("error extracting graph") | |||
| self._on_error( | |||
| e, | |||
| @@ -164,7 +166,8 @@ class GraphExtractor: | |||
| text = perform_variable_replacements(self._extraction_prompt, variables=variables) | |||
| gen_conf = {"temperature": 0.3} | |||
| response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf) | |||
| if response.find("**ERROR**") >= 0: raise Exception(response) | |||
| if response.find("**ERROR**") >= 0: | |||
| raise Exception(response) | |||
| token_count = num_tokens_from_string(text + response) | |||
| results = response or "" | |||
| @@ -175,7 +178,8 @@ class GraphExtractor: | |||
| text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables) | |||
| history.append({"role": "user", "content": text}) | |||
| response = self._llm.chat("", history, gen_conf) | |||
| if response.find("**ERROR**") >=0: raise Exception(response) | |||
| if response.find("**ERROR**") >=0: | |||
| raise Exception(response) | |||
| results += response or "" | |||
| # if this is the final glean, don't bother updating the continuation flag | |||
| @@ -134,7 +134,8 @@ def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, en | |||
| callback(0.75, "Extracting mind graph.") | |||
| mindmap = MindMapExtractor(llm_bdl) | |||
| mg = mindmap(_chunks).output | |||
| if not len(mg.keys()): return chunks | |||
| if not len(mg.keys()): | |||
| return chunks | |||
| logging.debug(json.dumps(mg, ensure_ascii=False, indent=2)) | |||
| chunks.append( | |||
| @@ -78,7 +78,8 @@ def _compute_leiden_communities( | |||
| ) -> dict[int, dict[str, int]]: | |||
| """Return Leiden root communities.""" | |||
| results: dict[int, dict[str, int]] = {} | |||
| if is_empty(graph): return results | |||
| if is_empty(graph): | |||
| return results | |||
| if use_lcc: | |||
| graph = stable_largest_connected_component(graph) | |||
| @@ -100,7 +101,8 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]: | |||
| logging.debug( | |||
| "Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc | |||
| ) | |||
| if not graph.nodes(): return {} | |||
| if not graph.nodes(): | |||
| return {} | |||
| node_id_to_community_map = _compute_leiden_communities( | |||
| graph=graph, | |||
| @@ -125,9 +127,11 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]: | |||
| result[community_id]["nodes"].append(node_id) | |||
| result[community_id]["weight"] += graph.nodes[node_id].get("rank", 0) * graph.nodes[node_id].get("weight", 1) | |||
| weights = [comm["weight"] for _, comm in result.items()] | |||
| if not weights:continue | |||
| if not weights: | |||
| continue | |||
| max_weight = max(weights) | |||
| for _, comm in result.items(): comm["weight"] /= max_weight | |||
| for _, comm in result.items(): | |||
| comm["weight"] /= max_weight | |||
| return results_by_level | |||
| @@ -1 +1,5 @@ | |||
| from .ragflow_chat import * | |||
| from .ragflow_chat import RAGFlowChat | |||
| __all__ = [ | |||
| "RAGFlowChat" | |||
| ] | |||
| @@ -2,7 +2,6 @@ import logging | |||
| import requests | |||
| from bridge.context import ContextType # Import Context, ContextType | |||
| from bridge.reply import Reply, ReplyType # Import Reply, ReplyType | |||
| from bridge import * | |||
| from plugins import Plugin, register # Import Plugin and register | |||
| from plugins.event import Event, EventContext, EventAction # Import event-related classes | |||
| @@ -94,7 +94,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| callback(0.1, "Start to parse.") | |||
| txt = get_text(filename, binary) | |||
| sections = txt.split("\n") | |||
| sections = [(l, "") for l in sections if l] | |||
| sections = [(line, "") for line in sections if line] | |||
| remove_contents_table(sections, eng=is_english( | |||
| random_choices([t for t, _ in sections], k=200))) | |||
| callback(0.8, "Finish parsing.") | |||
| @@ -102,7 +102,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE): | |||
| callback(0.1, "Start to parse.") | |||
| sections = HtmlParser()(filename, binary) | |||
| sections = [(l, "") for l in sections if l] | |||
| sections = [(line, "") for line in sections if line] | |||
| remove_contents_table(sections, eng=is_english( | |||
| random_choices([t for t, _ in sections], k=200))) | |||
| callback(0.8, "Finish parsing.") | |||
| @@ -112,7 +112,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| binary = BytesIO(binary) | |||
| doc_parsed = parser.from_buffer(binary) | |||
| sections = doc_parsed['content'].split('\n') | |||
| sections = [(l, "") for l in sections if l] | |||
| sections = [(line, "") for line in sections if line] | |||
| remove_contents_table(sections, eng=is_english( | |||
| random_choices([t for t, _ in sections], k=200))) | |||
| callback(0.8, "Finish parsing.") | |||
| @@ -75,7 +75,7 @@ def chunk( | |||
| _add_content(msg, msg.get_content_type()) | |||
| sections = TxtParser.parser_txt("\n".join(text_txt)) + [ | |||
| (l, "") for l in HtmlParser.parser_txt("\n".join(html_txt)) if l | |||
| (line, "") for line in HtmlParser.parser_txt("\n".join(html_txt)) if line | |||
| ] | |||
| st = timer() | |||
| @@ -18,7 +18,8 @@ def chunk(filename, binary, tenant_id, from_page=0, to_page=100000, | |||
| chunks = build_knowledge_graph_chunks(tenant_id, sections, callback, | |||
| parser_config.get("entity_types", ["organization", "person", "location", "event", "time"]) | |||
| ) | |||
| for c in chunks: c["docnm_kwd"] = filename | |||
| for c in chunks: | |||
| c["docnm_kwd"] = filename | |||
| doc = { | |||
| "docnm_kwd": filename, | |||
| @@ -48,7 +48,7 @@ class Docx(DocxParser): | |||
| continue | |||
| if 'w:br' in run._element.xml and 'type="page"' in run._element.xml: | |||
| pn += 1 | |||
| return [l for l in lines if l] | |||
| return [line for line in lines if line] | |||
| def __call__(self, filename, binary=None, from_page=0, to_page=100000): | |||
| self.doc = Document( | |||
| @@ -60,7 +60,8 @@ class Docx(DocxParser): | |||
| if pn > to_page: | |||
| break | |||
| question_level, p_text = docx_question_level(p, bull) | |||
| if not p_text.strip("\n"):continue | |||
| if not p_text.strip("\n"): | |||
| continue | |||
| lines.append((question_level, p_text)) | |||
| for run in p.runs: | |||
| @@ -78,19 +79,21 @@ class Docx(DocxParser): | |||
| if lines[e][0] <= lines[s][0]: | |||
| break | |||
| e += 1 | |||
| if e - s == 1 and visit[s]: continue | |||
| if e - s == 1 and visit[s]: | |||
| continue | |||
| sec = [] | |||
| next_level = lines[s][0] + 1 | |||
| while not sec and next_level < 22: | |||
| for i in range(s+1, e): | |||
| if lines[i][0] != next_level: continue | |||
| if lines[i][0] != next_level: | |||
| continue | |||
| sec.append(lines[i][1]) | |||
| visit[i] = True | |||
| next_level += 1 | |||
| sec.insert(0, lines[s][1]) | |||
| sections.append("\n".join(sec)) | |||
| return [l for l in sections if l] | |||
| return [s for s in sections if s] | |||
| def __str__(self) -> str: | |||
| return f''' | |||
| @@ -168,13 +171,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| callback(0.1, "Start to parse.") | |||
| txt = get_text(filename, binary) | |||
| sections = txt.split("\n") | |||
| sections = [l for l in sections if l] | |||
| sections = [s for s in sections if s] | |||
| callback(0.8, "Finish parsing.") | |||
| elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE): | |||
| callback(0.1, "Start to parse.") | |||
| sections = HtmlParser()(filename, binary) | |||
| sections = [l for l in sections if l] | |||
| sections = [s for s in sections if s] | |||
| callback(0.8, "Finish parsing.") | |||
| elif re.search(r"\.doc$", filename, re.IGNORECASE): | |||
| @@ -182,7 +185,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| binary = BytesIO(binary) | |||
| doc_parsed = parser.from_buffer(binary) | |||
| sections = doc_parsed['content'].split('\n') | |||
| sections = [l for l in sections if l] | |||
| sections = [s for s in sections if s] | |||
| callback(0.8, "Finish parsing.") | |||
| else: | |||
| @@ -190,7 +190,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| sections, tbls = pdf_parser(filename if not binary else binary, | |||
| from_page=from_page, to_page=to_page, callback=callback) | |||
| if sections and len(sections[0]) < 3: | |||
| sections = [(t, l, [[0] * 5]) for t, l in sections] | |||
| sections = [(t, lvl, [[0] * 5]) for t, lvl in sections] | |||
| # set pivot using the most frequent type of title, | |||
| # then merge between 2 pivot | |||
| if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.1: | |||
| @@ -211,7 +211,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| else: | |||
| bull = bullets_category([txt for txt, _, _ in sections]) | |||
| most_level, levels = title_frequency( | |||
| bull, [(txt, l) for txt, l, poss in sections]) | |||
| bull, [(txt, lvl) for txt, lvl, _ in sections]) | |||
| assert len(sections) == len(levels) | |||
| sec_ids = [] | |||
| @@ -225,7 +225,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| sections = [(txt, sec_ids[i], poss) | |||
| for i, (txt, _, poss) in enumerate(sections)] | |||
| for (img, rows), poss in tbls: | |||
| if not rows: continue | |||
| if not rows: | |||
| continue | |||
| sections.append((rows if isinstance(rows, str) else rows[0], -1, | |||
| [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss])) | |||
| @@ -54,7 +54,8 @@ class Pdf(PdfParser): | |||
| sections = [(b["text"], self.get_position(b, zoomin)) | |||
| for i, b in enumerate(self.boxes)] | |||
| for (img, rows), poss in tbls: | |||
| if not rows:continue | |||
| if not rows: | |||
| continue | |||
| sections.append((rows if isinstance(rows, str) else rows[0], | |||
| [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss])) | |||
| return [(txt, "") for txt, _ in sorted(sections, key=lambda x: ( | |||
| @@ -109,7 +110,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| binary = BytesIO(binary) | |||
| doc_parsed = parser.from_buffer(binary) | |||
| sections = doc_parsed['content'].split('\n') | |||
| sections = [l for l in sections if l] | |||
| sections = [s for s in sections if s] | |||
| callback(0.8, "Finish parsing.") | |||
| else: | |||
| @@ -171,7 +171,7 @@ class Pdf(PdfParser): | |||
| tbl_bottom = tbls[tbl_index][1][0][4] | |||
| tbl_tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \ | |||
| .format(tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom) | |||
| tbl_text = ''.join(tbls[tbl_index][0][1]) | |||
| _tbl_text = ''.join(tbls[tbl_index][0][1]) | |||
| return tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, | |||
| @@ -325,9 +325,11 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): | |||
| txt = get_text(filename, binary) | |||
| lines = txt.split("\n") | |||
| comma, tab = 0, 0 | |||
| for l in lines: | |||
| if len(l.split(",")) == 2: comma += 1 | |||
| if len(l.split("\t")) == 2: tab += 1 | |||
| for line in lines: | |||
| if len(line.split(",")) == 2: | |||
| comma += 1 | |||
| if len(line.split("\t")) == 2: | |||
| tab += 1 | |||
| delimiter = "\t" if tab >= comma else "," | |||
| fails = [] | |||
| @@ -336,18 +338,21 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): | |||
| while i < len(lines): | |||
| arr = lines[i].split(delimiter) | |||
| if len(arr) != 2: | |||
| if question: answer += "\n" + lines[i] | |||
| if question: | |||
| answer += "\n" + lines[i] | |||
| else: | |||
| fails.append(str(i+1)) | |||
| elif len(arr) == 2: | |||
| if question and answer: res.append(beAdoc(deepcopy(doc), question, answer, eng)) | |||
| if question and answer: | |||
| res.append(beAdoc(deepcopy(doc), question, answer, eng)) | |||
| question, answer = arr | |||
| i += 1 | |||
| if len(res) % 999 == 0: | |||
| callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + ( | |||
| f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) | |||
| if question: res.append(beAdoc(deepcopy(doc), question, answer, eng)) | |||
| if question: | |||
| res.append(beAdoc(deepcopy(doc), question, answer, eng)) | |||
| callback(0.6, ("Extract Q&A: {}".format(len(res)) + ( | |||
| f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) | |||
| @@ -367,19 +372,18 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): | |||
| callback(0.1, "Start to parse.") | |||
| txt = get_text(filename, binary) | |||
| lines = txt.split("\n") | |||
| last_question, last_answer = "", "" | |||
| _last_question, last_answer = "", "" | |||
| question_stack, level_stack = [], [] | |||
| code_block = False | |||
| level_index = [-1] * 7 | |||
| for index, l in enumerate(lines): | |||
| if l.strip().startswith('```'): | |||
| for index, line in enumerate(lines): | |||
| if line.strip().startswith('```'): | |||
| code_block = not code_block | |||
| question_level, question = 0, '' | |||
| if not code_block: | |||
| question_level, question = mdQuestionLevel(l) | |||
| question_level, question = mdQuestionLevel(line) | |||
| if not question_level or question_level > 6: # not a question | |||
| last_answer = f'{last_answer}\n{l}' | |||
| last_answer = f'{last_answer}\n{line}' | |||
| else: # is a question | |||
| if last_answer.strip(): | |||
| sum_question = '\n'.join(question_stack) | |||
| @@ -41,14 +41,16 @@ class Excel(ExcelParser): | |||
| for sheetname in wb.sheetnames: | |||
| ws = wb[sheetname] | |||
| rows = list(ws.rows) | |||
| if not rows:continue | |||
| if not rows: | |||
| continue | |||
| headers = [cell.value for cell in rows[0]] | |||
| missed = set([i for i, h in enumerate(headers) if h is None]) | |||
| headers = [ | |||
| cell.value for i, | |||
| cell in enumerate( | |||
| rows[0]) if i not in missed] | |||
| if not headers:continue | |||
| if not headers: | |||
| continue | |||
| data = [] | |||
| for i, r in enumerate(rows[1:]): | |||
| rn += 1 | |||
| @@ -88,7 +90,6 @@ def trans_bool(s): | |||
| def column_data_type(arr): | |||
| arr = list(arr) | |||
| uni = len(set([a for a in arr if a is not None])) | |||
| counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0} | |||
| trans = {t: f for f, t in | |||
| [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]} | |||
| @@ -157,7 +158,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, | |||
| continue | |||
| if i >= to_page: | |||
| break | |||
| row = [l for l in line.split(kwargs.get("delimiter", "\t"))] | |||
| row = [field for field in line.split(kwargs.get("delimiter", "\t"))] | |||
| if len(row) != len(headers): | |||
| fails.append(str(i)) | |||
| continue | |||
| @@ -13,12 +13,124 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from .embedding_model import * | |||
| from .chat_model import * | |||
| from .cv_model import * | |||
| from .rerank_model import * | |||
| from .sequence2txt_model import * | |||
| from .tts_model import * | |||
| from .embedding_model import ( | |||
| OllamaEmbed, | |||
| LocalAIEmbed, | |||
| OpenAIEmbed, | |||
| AzureEmbed, | |||
| XinferenceEmbed, | |||
| QWenEmbed, | |||
| ZhipuEmbed, | |||
| FastEmbed, | |||
| YoudaoEmbed, | |||
| BaiChuanEmbed, | |||
| JinaEmbed, | |||
| DefaultEmbedding, | |||
| MistralEmbed, | |||
| BedrockEmbed, | |||
| GeminiEmbed, | |||
| NvidiaEmbed, | |||
| LmStudioEmbed, | |||
| OpenAI_APIEmbed, | |||
| CoHereEmbed, | |||
| TogetherAIEmbed, | |||
| PerfXCloudEmbed, | |||
| UpstageEmbed, | |||
| SILICONFLOWEmbed, | |||
| ReplicateEmbed, | |||
| BaiduYiyanEmbed, | |||
| VoyageEmbed, | |||
| HuggingFaceEmbed, | |||
| VolcEngineEmbed, | |||
| ) | |||
| from .chat_model import ( | |||
| GptTurbo, | |||
| AzureChat, | |||
| ZhipuChat, | |||
| QWenChat, | |||
| OllamaChat, | |||
| LocalAIChat, | |||
| XinferenceChat, | |||
| MoonshotChat, | |||
| DeepSeekChat, | |||
| VolcEngineChat, | |||
| BaiChuanChat, | |||
| MiniMaxChat, | |||
| MistralChat, | |||
| GeminiChat, | |||
| BedrockChat, | |||
| GroqChat, | |||
| OpenRouterChat, | |||
| StepFunChat, | |||
| NvidiaChat, | |||
| LmStudioChat, | |||
| OpenAI_APIChat, | |||
| CoHereChat, | |||
| LeptonAIChat, | |||
| TogetherAIChat, | |||
| PerfXCloudChat, | |||
| UpstageChat, | |||
| NovitaAIChat, | |||
| SILICONFLOWChat, | |||
| YiChat, | |||
| ReplicateChat, | |||
| HunyuanChat, | |||
| SparkChat, | |||
| BaiduYiyanChat, | |||
| AnthropicChat, | |||
| GoogleChat, | |||
| HuggingFaceChat, | |||
| ) | |||
| from .cv_model import ( | |||
| GptV4, | |||
| AzureGptV4, | |||
| OllamaCV, | |||
| XinferenceCV, | |||
| QWenCV, | |||
| Zhipu4V, | |||
| LocalCV, | |||
| GeminiCV, | |||
| OpenRouterCV, | |||
| LocalAICV, | |||
| NvidiaCV, | |||
| LmStudioCV, | |||
| StepFunCV, | |||
| OpenAI_APICV, | |||
| TogetherAICV, | |||
| YiCV, | |||
| HunyuanCV, | |||
| ) | |||
| from .rerank_model import ( | |||
| LocalAIRerank, | |||
| DefaultRerank, | |||
| JinaRerank, | |||
| YoudaoRerank, | |||
| XInferenceRerank, | |||
| NvidiaRerank, | |||
| LmStudioRerank, | |||
| OpenAI_APIRerank, | |||
| CoHereRerank, | |||
| TogetherAIRerank, | |||
| SILICONFLOWRerank, | |||
| BaiduYiyanRerank, | |||
| VoyageRerank, | |||
| QWenRerank, | |||
| ) | |||
| from .sequence2txt_model import ( | |||
| GPTSeq2txt, | |||
| QWenSeq2txt, | |||
| AzureSeq2txt, | |||
| XinferenceSeq2txt, | |||
| TencentCloudSeq2txt, | |||
| ) | |||
| from .tts_model import ( | |||
| FishAudioTTS, | |||
| QwenTTS, | |||
| OpenAITTS, | |||
| SparkTTS, | |||
| XinferenceTTS, | |||
| ) | |||
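Replacing the `from .xxx_model import *` wildcards with explicit import lists removes Ruff's star-import warnings (F403/F405) and makes every class referenced in the registries below traceable to its module. A minimal sketch of the before/after shape for such a package `__init__` (only two names shown, registry keys illustrative):

```python
# Before: star import — Ruff reports F403, and F405 for names it cannot resolve.
#     from .chat_model import *
# After: explicit imports, so the registry values resolve to known symbols.
from .chat_model import GptTurbo, DeepSeekChat

ChatModel = {          # provider label -> implementation class (illustrative keys)
    "OpenAI": GptTurbo,
    "DeepSeek": DeepSeekChat,
}
```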
| EmbeddingModel = { | |||
| "Ollama": OllamaEmbed, | |||
| @@ -48,7 +160,7 @@ EmbeddingModel = { | |||
| "BaiduYiyan": BaiduYiyanEmbed, | |||
| "Voyage AI": VoyageEmbed, | |||
| "HuggingFace": HuggingFaceEmbed, | |||
| "VolcEngine":VolcEngineEmbed, | |||
| "VolcEngine": VolcEngineEmbed, | |||
| } | |||
| CvModel = { | |||
| @@ -68,7 +180,7 @@ CvModel = { | |||
| "OpenAI-API-Compatible": OpenAI_APICV, | |||
| "TogetherAI": TogetherAICV, | |||
| "01.AI": YiCV, | |||
| "Tencent Hunyuan": HunyuanCV | |||
| "Tencent Hunyuan": HunyuanCV, | |||
| } | |||
| ChatModel = { | |||
| @@ -111,7 +223,7 @@ ChatModel = { | |||
| } | |||
| RerankModel = { | |||
| "LocalAI":LocalAIRerank, | |||
| "LocalAI": LocalAIRerank, | |||
| "BAAI": DefaultRerank, | |||
| "Jina": JinaRerank, | |||
| "Youdao": YoudaoRerank, | |||
| @@ -132,7 +244,7 @@ Seq2txtModel = { | |||
| "Tongyi-Qianwen": QWenSeq2txt, | |||
| "Azure-OpenAI": AzureSeq2txt, | |||
| "Xinference": XinferenceSeq2txt, | |||
| "Tencent Cloud": TencentCloudSeq2txt | |||
| "Tencent Cloud": TencentCloudSeq2txt, | |||
| } | |||
| TTSModel = { | |||
| @@ -69,7 +69,8 @@ class Base(ABC): | |||
| stream=True, | |||
| **gen_conf) | |||
| for resp in response: | |||
| if not resp.choices: continue | |||
| if not resp.choices: | |||
| continue | |||
| if not resp.choices[0].delta.content: | |||
| resp.choices[0].delta.content = "" | |||
| ans += resp.choices[0].delta.content | |||
| @@ -81,7 +82,8 @@ class Base(ABC): | |||
| ) | |||
| elif isinstance(resp.usage, dict): | |||
| total_tokens = resp.usage.get("total_tokens", total_tokens) | |||
| else: total_tokens = resp.usage.total_tokens | |||
| else: | |||
| total_tokens = resp.usage.total_tokens | |||
| if resp.choices[0].finish_reason == "length": | |||
| if is_chinese(ans): | |||
| @@ -98,13 +100,15 @@ class Base(ABC): | |||
| class GptTurbo(Base): | |||
| def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"): | |||
| if not base_url: base_url = "https://api.openai.com/v1" | |||
| if not base_url: | |||
| base_url = "https://api.openai.com/v1" | |||
| super().__init__(key, model_name, base_url) | |||
| class MoonshotChat(Base): | |||
| def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"): | |||
| if not base_url: base_url = "https://api.moonshot.cn/v1" | |||
| if not base_url: | |||
| base_url = "https://api.moonshot.cn/v1" | |||
| super().__init__(key, model_name, base_url) | |||
| @@ -128,7 +132,8 @@ class HuggingFaceChat(Base): | |||
| class DeepSeekChat(Base): | |||
| def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"): | |||
| if not base_url: base_url = "https://api.deepseek.com/v1" | |||
| if not base_url: | |||
| base_url = "https://api.deepseek.com/v1" | |||
| super().__init__(key, model_name, base_url) | |||
| @@ -202,7 +207,8 @@ class BaiChuanChat(Base): | |||
| stream=True, | |||
| **self._format_params(gen_conf)) | |||
| for resp in response: | |||
| if not resp.choices: continue | |||
| if not resp.choices: | |||
| continue | |||
| if not resp.choices[0].delta.content: | |||
| resp.choices[0].delta.content = "" | |||
| ans += resp.choices[0].delta.content | |||
| @@ -313,8 +319,10 @@ class ZhipuChat(Base): | |||
| if system: | |||
| history.insert(0, {"role": "system", "content": system}) | |||
| try: | |||
| if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"] | |||
| if "presence_penalty" in gen_conf: | |||
| del gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: | |||
| del gen_conf["frequency_penalty"] | |||
| response = self.client.chat.completions.create( | |||
| model=self.model_name, | |||
| messages=history, | |||
| @@ -333,8 +341,10 @@ class ZhipuChat(Base): | |||
| def chat_streamly(self, system, history, gen_conf): | |||
| if system: | |||
| history.insert(0, {"role": "system", "content": system}) | |||
| if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"] | |||
| if "presence_penalty" in gen_conf: | |||
| del gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: | |||
| del gen_conf["frequency_penalty"] | |||
| ans = "" | |||
| tk_count = 0 | |||
| try: | |||
| @@ -345,7 +355,8 @@ class ZhipuChat(Base): | |||
| **gen_conf | |||
| ) | |||
| for resp in response: | |||
| if not resp.choices[0].delta.content: continue | |||
| if not resp.choices[0].delta.content: | |||
| continue | |||
| delta = resp.choices[0].delta.content | |||
| ans += delta | |||
| if resp.choices[0].finish_reason == "length": | |||
| @@ -354,7 +365,8 @@ class ZhipuChat(Base): | |||
| else: | |||
| ans += LENGTH_NOTIFICATION_EN | |||
| tk_count = resp.usage.total_tokens | |||
| if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens | |||
| if resp.choices[0].finish_reason == "stop": | |||
| tk_count = resp.usage.total_tokens | |||
| yield ans | |||
| except Exception as e: | |||
| yield ans + "\n**ERROR**: " + str(e) | |||
| @@ -372,11 +384,16 @@ class OllamaChat(Base): | |||
| history.insert(0, {"role": "system", "content": system}) | |||
| try: | |||
| options = {} | |||
| if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"] | |||
| if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"] | |||
| if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"] | |||
| if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"] | |||
| if "temperature" in gen_conf: | |||
| options["temperature"] = gen_conf["temperature"] | |||
| if "max_tokens" in gen_conf: | |||
| options["num_predict"] = gen_conf["max_tokens"] | |||
| if "top_p" in gen_conf: | |||
| options["top_p"] = gen_conf["top_p"] | |||
| if "presence_penalty" in gen_conf: | |||
| options["presence_penalty"] = gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: | |||
| options["frequency_penalty"] = gen_conf["frequency_penalty"] | |||
| response = self.client.chat( | |||
| model=self.model_name, | |||
| messages=history, | |||
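The OllamaChat hunks expand an `if`-ladder that translates OpenAI-style settings (`max_tokens`, `top_p`, …) into Ollama's `options` names (`num_predict`, …). A key-mapping dict keeps that translation table in one place; again a sketch for illustration only, with the mapping mirroring the lines shown in the diff:

```python
# Sketch only: build the Ollama "options" dict from OpenAI-style generation
# settings via a key map instead of repeated membership tests.
OPENAI_TO_OLLAMA = {
    "temperature": "temperature",
    "max_tokens": "num_predict",
    "top_p": "top_p",
    "presence_penalty": "presence_penalty",
    "frequency_penalty": "frequency_penalty",
}


def to_ollama_options(gen_conf: dict) -> dict:
    return {dst: gen_conf[src] for src, dst in OPENAI_TO_OLLAMA.items() if src in gen_conf}


print(to_ollama_options({"temperature": 0.2, "max_tokens": 512}))
# -> {'temperature': 0.2, 'num_predict': 512}
```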
| @@ -392,11 +409,16 @@ class OllamaChat(Base): | |||
| if system: | |||
| history.insert(0, {"role": "system", "content": system}) | |||
| options = {} | |||
| if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"] | |||
| if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"] | |||
| if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"] | |||
| if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"] | |||
| if "temperature" in gen_conf: | |||
| options["temperature"] = gen_conf["temperature"] | |||
| if "max_tokens" in gen_conf: | |||
| options["num_predict"] = gen_conf["max_tokens"] | |||
| if "top_p" in gen_conf: | |||
| options["top_p"] = gen_conf["top_p"] | |||
| if "presence_penalty" in gen_conf: | |||
| options["presence_penalty"] = gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: | |||
| options["frequency_penalty"] = gen_conf["frequency_penalty"] | |||
| ans = "" | |||
| try: | |||
| response = self.client.chat( | |||
| @@ -636,7 +658,8 @@ class MistralChat(Base): | |||
| messages=history, | |||
| **gen_conf) | |||
| for resp in response: | |||
| if not resp.choices or not resp.choices[0].delta.content: continue | |||
| if not resp.choices or not resp.choices[0].delta.content: | |||
| continue | |||
| ans += resp.choices[0].delta.content | |||
| total_tokens += 1 | |||
| if resp.choices[0].finish_reason == "length": | |||
| @@ -1196,7 +1219,8 @@ class SparkChat(Base): | |||
| assert model_name in model2version or model_name in version2model, f"The given model name is not supported yet. Support: {list(model2version.keys())}" | |||
| if model_name in model2version: | |||
| model_version = model2version[model_name] | |||
| else: model_version = model_name | |||
| else: | |||
| model_version = model_name | |||
| super().__init__(key, model_version, base_url) | |||
| @@ -1281,8 +1305,10 @@ class AnthropicChat(Base): | |||
| self.system = system | |||
| if "max_tokens" not in gen_conf: | |||
| gen_conf["max_tokens"] = 4096 | |||
| if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"] | |||
| if "presence_penalty" in gen_conf: | |||
| del gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: | |||
| del gen_conf["frequency_penalty"] | |||
| ans = "" | |||
| try: | |||
| @@ -1312,8 +1338,10 @@ class AnthropicChat(Base): | |||
| self.system = system | |||
| if "max_tokens" not in gen_conf: | |||
| gen_conf["max_tokens"] = 4096 | |||
| if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"] | |||
| if "presence_penalty" in gen_conf: | |||
| del gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: | |||
| del gen_conf["frequency_penalty"] | |||
| ans = "" | |||
| total_tokens = 0 | |||
| @@ -25,6 +25,7 @@ import base64 | |||
| from io import BytesIO | |||
| import json | |||
| import requests | |||
| from transformers import GenerationConfig | |||
| from rag.nlp import is_english | |||
| from api.utils import get_uuid | |||
| @@ -77,14 +78,16 @@ class Base(ABC): | |||
| stream=True | |||
| ) | |||
| for resp in response: | |||
| if not resp.choices[0].delta.content: continue | |||
| if not resp.choices[0].delta.content: | |||
| continue | |||
| delta = resp.choices[0].delta.content | |||
| ans += delta | |||
| if resp.choices[0].finish_reason == "length": | |||
| ans += "...\nFor the content length reason, it stopped, continue?" if is_english( | |||
| [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" | |||
| tk_count = resp.usage.total_tokens | |||
| if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens | |||
| if resp.choices[0].finish_reason == "stop": | |||
| tk_count = resp.usage.total_tokens | |||
| yield ans | |||
| except Exception as e: | |||
| yield ans + "\n**ERROR**: " + str(e) | |||
| @@ -99,7 +102,7 @@ class Base(ABC): | |||
| buffered = BytesIO() | |||
| try: | |||
| image.save(buffered, format="JPEG") | |||
| except Exception as e: | |||
| except Exception: | |||
| image.save(buffered, format="PNG") | |||
| return base64.b64encode(buffered.getvalue()).decode("utf-8") | |||
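Here the unused binding in `except Exception as e:` is dropped, which silences Ruff's F841 ("local variable is assigned to but never used"). A self-contained sketch of when the binding is worth keeping:

```python
# Sketch: keep "as e" only when the exception object is actually used.
import logging


def risky() -> None:
    raise ValueError("boom")


try:
    risky()
except Exception as e:            # binding is used, so F841 is not raised
    logging.warning("risky() failed: %s", e)

try:
    risky()
except Exception:                 # preferred when the exception object is unused
    pass
```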
| @@ -139,7 +142,8 @@ class Base(ABC): | |||
| class GptV4(Base): | |||
| def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"): | |||
| if not base_url: base_url="https://api.openai.com/v1" | |||
| if not base_url: | |||
| base_url="https://api.openai.com/v1" | |||
| self.client = OpenAI(api_key=key, base_url=base_url) | |||
| self.model_name = model_name | |||
| self.lang = lang | |||
| @@ -149,7 +153,8 @@ class GptV4(Base): | |||
| prompt = self.prompt(b64) | |||
| for i in range(len(prompt)): | |||
| for c in prompt[i]["content"]: | |||
| if "text" in c: c["type"] = "text" | |||
| if "text" in c: | |||
| c["type"] = "text" | |||
| res = self.client.chat.completions.create( | |||
| model=self.model_name, | |||
| @@ -171,7 +176,8 @@ class AzureGptV4(Base): | |||
| prompt = self.prompt(b64) | |||
| for i in range(len(prompt)): | |||
| for c in prompt[i]["content"]: | |||
| if "text" in c: c["type"] = "text" | |||
| if "text" in c: | |||
| c["type"] = "text" | |||
| res = self.client.chat.completions.create( | |||
| model=self.model_name, | |||
| @@ -344,14 +350,16 @@ class Zhipu4V(Base): | |||
| stream=True | |||
| ) | |||
| for resp in response: | |||
| if not resp.choices[0].delta.content: continue | |||
| if not resp.choices[0].delta.content: | |||
| continue | |||
| delta = resp.choices[0].delta.content | |||
| ans += delta | |||
| if resp.choices[0].finish_reason == "length": | |||
| ans += "...\nFor the content length reason, it stopped, continue?" if is_english( | |||
| [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" | |||
| tk_count = resp.usage.total_tokens | |||
| if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens | |||
| if resp.choices[0].finish_reason == "stop": | |||
| tk_count = resp.usage.total_tokens | |||
| yield ans | |||
| except Exception as e: | |||
| yield ans + "\n**ERROR**: " + str(e) | |||
| @@ -389,11 +397,16 @@ class OllamaCV(Base): | |||
| if his["role"] == "user": | |||
| his["images"] = [image] | |||
| options = {} | |||
| if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"] | |||
| if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"] | |||
| if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"] | |||
| if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"] | |||
| if "temperature" in gen_conf: | |||
| options["temperature"] = gen_conf["temperature"] | |||
| if "max_tokens" in gen_conf: | |||
| options["num_predict"] = gen_conf["max_tokens"] | |||
| if "top_p" in gen_conf: | |||
| options["top_k"] = gen_conf["top_p"] | |||
| if "presence_penalty" in gen_conf: | |||
| options["presence_penalty"] = gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: | |||
| options["frequency_penalty"] = gen_conf["frequency_penalty"] | |||
| response = self.client.chat( | |||
| model=self.model_name, | |||
| messages=history, | |||
| @@ -414,11 +427,16 @@ class OllamaCV(Base): | |||
| if his["role"] == "user": | |||
| his["images"] = [image] | |||
| options = {} | |||
| if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"] | |||
| if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"] | |||
| if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"] | |||
| if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"] | |||
| if "temperature" in gen_conf: | |||
| options["temperature"] = gen_conf["temperature"] | |||
| if "max_tokens" in gen_conf: | |||
| options["num_predict"] = gen_conf["max_tokens"] | |||
| if "top_p" in gen_conf: | |||
| options["top_k"] = gen_conf["top_p"] | |||
| if "presence_penalty" in gen_conf: | |||
| options["presence_penalty"] = gen_conf["presence_penalty"] | |||
| if "frequency_penalty" in gen_conf: | |||
| options["frequency_penalty"] = gen_conf["frequency_penalty"] | |||
| ans = "" | |||
| try: | |||
| response = self.client.chat( | |||
| @@ -469,7 +487,7 @@ class XinferenceCV(Base): | |||
| class GeminiCV(Base): | |||
| def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs): | |||
| from google.generativeai import client, GenerativeModel, GenerationConfig | |||
| from google.generativeai import client, GenerativeModel | |||
| client.configure(api_key=key) | |||
| _client = client.get_default_generative_client() | |||
| self.model_name = model_name | |||
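In `GeminiCV.__init__`, `GenerationConfig` is dropped from the local `google.generativeai` import because the name is not referenced inside that scope; imports that are never used are what Ruff reports as F401. A toy module showing the rule, unrelated to the project's code:

```python
# Toy F401 illustration: only import what the module actually references.
import json        # used below, so it is fine
# import base64    # would be flagged by F401: imported but never used


def dumps(obj) -> str:
    return json.dumps(obj, ensure_ascii=False)


print(dumps({"ok": True}))
```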
| @@ -503,7 +521,7 @@ class GeminiCV(Base): | |||
| if his["role"] == "user": | |||
| his["parts"] = [his["content"]] | |||
| his.pop("content") | |||
| history[-1]["parts"].append(f"data:image/jpeg;base64," + image) | |||
| history[-1]["parts"].append("data:image/jpeg;base64," + image) | |||
| response = self.model.generate_content(history, generation_config=GenerationConfig( | |||
| max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3), | |||
| @@ -519,7 +537,6 @@ class GeminiCV(Base): | |||
| history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] | |||
| ans = "" | |||
| tk_count = 0 | |||
| try: | |||
| for his in history: | |||
| if his["role"] == "assistant": | |||
| @@ -529,14 +546,15 @@ class GeminiCV(Base): | |||
| if his["role"] == "user": | |||
| his["parts"] = [his["content"]] | |||
| his.pop("content") | |||
| history[-1]["parts"].append(f"data:image/jpeg;base64," + image) | |||
| history[-1]["parts"].append("data:image/jpeg;base64," + image) | |||
| response = self.model.generate_content(history, generation_config=GenerationConfig( | |||
| max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3), | |||
| top_p=gen_conf.get("top_p", 0.7)), stream=True) | |||
| for resp in response: | |||
| if not resp.text: continue | |||
| if not resp.text: | |||
| continue | |||
| ans += resp.text | |||
| yield ans | |||
| except Exception as e: | |||
| @@ -632,7 +650,8 @@ class NvidiaCV(Base): | |||
| class StepFunCV(GptV4): | |||
| def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"): | |||
| if not base_url: base_url="https://api.stepfun.com/v1" | |||
| if not base_url: | |||
| base_url="https://api.stepfun.com/v1" | |||
| self.client = OpenAI(api_key=key, base_url=base_url) | |||
| self.model_name = model_name | |||
| self.lang = lang | |||
| @@ -15,12 +15,9 @@ | |||
| # | |||
| import requests | |||
| from openai.lib.azure import AzureOpenAI | |||
| from zhipuai import ZhipuAI | |||
| import io | |||
| from abc import ABC | |||
| from ollama import Client | |||
| from openai import OpenAI | |||
| import os | |||
| import json | |||
| from rag.utils import num_tokens_from_string | |||
| import base64 | |||
| @@ -49,7 +46,8 @@ class Base(ABC): | |||
| class GPTSeq2txt(Base): | |||
| def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"): | |||
| if not base_url: base_url = "https://api.openai.com/v1" | |||
| if not base_url: | |||
| base_url = "https://api.openai.com/v1" | |||
| self.client = OpenAI(api_key=key, base_url=base_url) | |||
| self.model_name = model_name | |||
| @@ -16,7 +16,6 @@ | |||
| import _thread as thread | |||
| import base64 | |||
| import datetime | |||
| import hashlib | |||
| import hmac | |||
| import json | |||
| @@ -175,7 +174,8 @@ class QwenTTS(Base): | |||
| class OpenAITTS(Base): | |||
| def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"): | |||
| if not base_url: base_url = "https://api.openai.com/v1" | |||
| if not base_url: | |||
| base_url = "https://api.openai.com/v1" | |||
| self.api_key = key | |||
| self.model_name = model_name | |||
| self.base_url = base_url | |||
| @@ -222,7 +222,8 @@ def bullets_category(sections): | |||
| def is_english(texts): | |||
| eng = 0 | |||
| if not texts: return False | |||
| if not texts: | |||
| return False | |||
| for t in texts: | |||
| if re.match(r"[ `a-zA-Z.,':;/\"?<>!\(\)-]", t.strip()): | |||
| eng += 1 | |||
| @@ -250,7 +251,8 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None): | |||
| res = [] | |||
| # wrap up as es documents | |||
| for ck in chunks: | |||
| if len(ck.strip()) == 0:continue | |||
| if len(ck.strip()) == 0: | |||
| continue | |||
| logging.debug("-- {}".format(ck)) | |||
| d = copy.deepcopy(doc) | |||
| if pdf_parser: | |||
| @@ -269,7 +271,8 @@ def tokenize_chunks_docx(chunks, doc, eng, images): | |||
| res = [] | |||
| # wrap up as es documents | |||
| for ck, image in zip(chunks, images): | |||
| if len(ck.strip()) == 0:continue | |||
| if len(ck.strip()) == 0: | |||
| continue | |||
| logging.debug("-- {}".format(ck)) | |||
| d = copy.deepcopy(doc) | |||
| d["image"] = image | |||
| @@ -288,8 +291,10 @@ def tokenize_table(tbls, doc, eng, batch_size=10): | |||
| d = copy.deepcopy(doc) | |||
| tokenize(d, rows, eng) | |||
| d["content_with_weight"] = rows | |||
| if img: d["image"] = img | |||
| if poss: add_positions(d, poss) | |||
| if img: | |||
| d["image"] = img | |||
| if poss: | |||
| add_positions(d, poss) | |||
| res.append(d) | |||
| continue | |||
| de = "; " if eng else "; " | |||
| @@ -387,9 +392,9 @@ def title_frequency(bull, sections): | |||
| if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]): | |||
| levels[i] = bullets_size | |||
| most_level = bullets_size+1 | |||
| for l, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1): | |||
| if l <= bullets_size: | |||
| most_level = l | |||
| for level, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1): | |||
| if level <= bullets_size: | |||
| most_level = level | |||
| break | |||
| return most_level, levels | |||
| @@ -504,7 +509,8 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"): | |||
| def add_chunk(t, pos): | |||
| nonlocal cks, tk_nums, delimiter | |||
| tnum = num_tokens_from_string(t) | |||
| if not pos: pos = "" | |||
| if not pos: | |||
| pos = "" | |||
| if tnum < 8: | |||
| pos = "" | |||
| # Ensure that the length of the merged chunk does not exceed chunk_token_num | |||
| @@ -121,7 +121,8 @@ class FulltextQueryer: | |||
| keywords.append(tt) | |||
| twts = self.tw.weights([tt]) | |||
| syns = self.syn.lookup(tt) | |||
| if syns and len(keywords) < 32: keywords.extend(syns) | |||
| if syns and len(keywords) < 32: | |||
| keywords.extend(syns) | |||
| logging.debug(json.dumps(twts, ensure_ascii=False)) | |||
| tms = [] | |||
| for tk, w in sorted(twts, key=lambda x: x[1] * -1): | |||
| @@ -147,7 +148,8 @@ class FulltextQueryer: | |||
| tk_syns = self.syn.lookup(tk) | |||
| tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns] | |||
| if len(keywords) < 32: keywords.extend([s for s in tk_syns if s]) | |||
| if len(keywords) < 32: | |||
| keywords.extend([s for s in tk_syns if s]) | |||
| tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] | |||
| tk_syns = [f"\"{s}\"" if s.find(" ")>0 else s for s in tk_syns] | |||
| @@ -104,7 +104,6 @@ class RagTokenizer: | |||
| return HanziConv.toSimplified(line) | |||
| def dfs_(self, chars, s, preTks, tkslist): | |||
| MAX_L = 10 | |||
| res = s | |||
| # if s > MAX_L or s>= len(chars): | |||
| if s >= len(chars): | |||
| @@ -184,12 +183,6 @@ class RagTokenizer: | |||
| return sorted(res, key=lambda x: x[1], reverse=True) | |||
| def merge_(self, tks): | |||
| patts = [ | |||
| (r"[ ]+", " "), | |||
| (r"([0-9\+\.,%\*=-]) ([0-9\+\.,%\*=-])", r"\1\2"), | |||
| ] | |||
| # for p,s in patts: tks = re.sub(p, s, tks) | |||
| # if split chars is part of token | |||
| res = [] | |||
| tks = re.sub(r"[ ]+", " ", tks).split() | |||
| @@ -284,7 +277,8 @@ class RagTokenizer: | |||
| same = 0 | |||
| while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]: | |||
| same += 1 | |||
| if same > 0: res.append(" ".join(tks[j: j + same])) | |||
| if same > 0: | |||
| res.append(" ".join(tks[j: j + same])) | |||
| _i = i + same | |||
| _j = j + same | |||
| j = _j + 1 | |||
| @@ -62,10 +62,10 @@ class Dealer: | |||
| res = {} | |||
| f = open(fnm, "r") | |||
| while True: | |||
| l = f.readline() | |||
| if not l: | |||
| line = f.readline() | |||
| if not line: | |||
| break | |||
| arr = l.replace("\n", "").split("\t") | |||
| arr = line.replace("\n", "").split("\t") | |||
| if len(arr) < 2: | |||
| res[arr[0]] = 0 | |||
| else: | |||
| @@ -47,7 +47,8 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: | |||
| def __call__(self, chunks, random_state, callback=None): | |||
| layers = [(0, len(chunks))] | |||
| start, end = 0, len(chunks) | |||
| if len(chunks) <= 1: return | |||
| if len(chunks) <= 1: | |||
| return | |||
| chunks = [(s, a) for s, a in chunks if len(a) > 0] | |||
| def summarize(ck_idx, lock): | |||
| @@ -66,7 +67,8 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: | |||
| logging.debug(f"SUM: {cnt}") | |||
| embds, _ = self._embd_model.encode([cnt]) | |||
| with lock: | |||
| if not len(embds[0]): return | |||
| if not len(embds[0]): | |||
| return | |||
| chunks.append((cnt, embds[0])) | |||
| except Exception as e: | |||
| logging.exception("summarize got exception") | |||
| @@ -33,14 +33,16 @@ def collect(): | |||
| def main(): | |||
| locations = collect() | |||
| if not locations:return | |||
| if not locations: | |||
| return | |||
| logging.info(f"TASKS: {len(locations)}") | |||
| for kb_id, loc in locations: | |||
| try: | |||
| if REDIS_CONN.is_alive(): | |||
| try: | |||
| key = "{}/{}".format(kb_id, loc) | |||
| if REDIS_CONN.exist(key):continue | |||
| if REDIS_CONN.exist(key): | |||
| continue | |||
| file_bin = STORAGE_IMPL.get(kb_id, loc) | |||
| REDIS_CONN.transaction(key, file_bin, 12 * 60) | |||
| logging.info("CACHE: {}".format(loc)) | |||
| @@ -23,18 +23,12 @@ import os | |||
| from api.utils.log_utils import initRootLogger | |||
| CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1] | |||
| CONSUMER_NAME = "task_executor_" + CONSUMER_NO | |||
| LOG_LEVELS = os.environ.get("LOG_LEVELS", "") | |||
| initRootLogger(CONSUMER_NAME, LOG_LEVELS) | |||
| from datetime import datetime | |||
| import json | |||
| import os | |||
| import hashlib | |||
| import copy | |||
| import re | |||
| import sys | |||
| import time | |||
| import threading | |||
| from functools import partial | |||
| @@ -63,6 +57,11 @@ from rag.utils import rmSpace, num_tokens_from_string | |||
| from rag.utils.redis_conn import REDIS_CONN, Payload | |||
| from rag.utils.storage_factory import STORAGE_IMPL | |||
| CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1] | |||
| CONSUMER_NAME = "task_executor_" + CONSUMER_NO | |||
| LOG_LEVELS = os.environ.get("LOG_LEVELS", "") | |||
| initRootLogger(CONSUMER_NAME, LOG_LEVELS) | |||
| BATCH_SIZE = 64 | |||
| FACTORY = { | |||
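In `task_executor.py`, the `CONSUMER_*` and logger-initialisation statements are moved below the import block, so that no executable code sits between module-level imports (Ruff's E402, "module level import not at top of file"). A minimal sketch of that layout, using assumed helper and environment-variable names rather than the project's:

```python
# Illustrative layout only (the LOG_LEVEL env-var name is an assumption, not the
# project's API): every module-level import first, then configuration statements.
import logging
import os
import sys

CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO

logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))
logging.getLogger(CONSUMER_NAME).info("worker %s starting", CONSUMER_NAME)
```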
| @@ -201,7 +200,8 @@ def build_chunks(task, progress_callback): | |||
| "doc_id": task["doc_id"], | |||
| "kb_id": str(task["kb_id"]) | |||
| } | |||
| if task["pagerank"]: doc["pagerank_fea"] = int(task["pagerank"]) | |||
| if task["pagerank"]: | |||
| doc["pagerank_fea"] = int(task["pagerank"]) | |||
| el = 0 | |||
| for ck in cks: | |||
| d = copy.deepcopy(doc) | |||
| @@ -342,7 +342,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None): | |||
| "docnm_kwd": row["name"], | |||
| "title_tks": rag_tokenizer.tokenize(row["name"]) | |||
| } | |||
| if row["pagerank"]: doc["pagerank_fea"] = int(row["pagerank"]) | |||
| if row["pagerank"]: | |||
| doc["pagerank_fea"] = int(row["pagerank"]) | |||
| res = [] | |||
| tk_count = 0 | |||
| for content, vctr in chunks[original_length:]: | |||
| @@ -41,15 +41,15 @@ def findMaxDt(fnm): | |||
| try: | |||
| with open(fnm, "r") as f: | |||
| while True: | |||
| l = f.readline() | |||
| if not l: | |||
| line = f.readline() | |||
| if not line: | |||
| break | |||
| l = l.strip("\n") | |||
| if l == 'nan': | |||
| line = line.strip("\n") | |||
| if line == 'nan': | |||
| continue | |||
| if l > m: | |||
| m = l | |||
| except Exception as e: | |||
| if line > m: | |||
| m = line | |||
| except Exception: | |||
| pass | |||
| return m | |||
| @@ -59,15 +59,15 @@ def findMaxTm(fnm): | |||
| try: | |||
| with open(fnm, "r") as f: | |||
| while True: | |||
| l = f.readline() | |||
| if not l: | |||
| line = f.readline() | |||
| if not line: | |||
| break | |||
| l = l.strip("\n") | |||
| if l == 'nan': | |||
| line = line.strip("\n") | |||
| if line == 'nan': | |||
| continue | |||
| if int(l) > m: | |||
| m = int(l) | |||
| except Exception as e: | |||
| if int(line) > m: | |||
| m = int(line) | |||
| except Exception: | |||
| pass | |||
| return m | |||
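Several hunks here rename the single-letter variable `l` to `line` or `level`. Ruff's E741 flags `l`, `I`, and `O` as ambiguous names because they are easy to misread as digits; the behaviour is unchanged. A minimal sketch of the same rename, assuming a simple line-scanning helper:

```python
# Minimal E741 sketch: "l" is easy to confuse with "1", so a descriptive name is
# used instead; behaviour is identical.
def last_valid_value(path: str) -> str:
    last = ""
    with open(path, "r") as f:
        for line in f:                  # was: for l in f
            line = line.strip("\n")
            if line and line != "nan":
                last = line
    return last
```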
| @@ -32,7 +32,7 @@ class RAGFlowAzureSasBlob(object): | |||
| self.conn = None | |||
| def health(self): | |||
| bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1" | |||
| _bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1" | |||
| return self.conn.upload_blob(name=fnm, data=BytesIO(binary), length=len(binary)) | |||
| def put(self, bucket, fnm, binary): | |||
| @@ -36,7 +36,7 @@ class RAGFlowAzureSpnBlob(object): | |||
| self.conn = None | |||
| def health(self): | |||
| bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1" | |||
| _bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1" | |||
| f = self.conn.create_file(fnm) | |||
| f.append_data(binary, offset=0, length=len(binary)) | |||
| return f.flush_data(len(binary)) | |||
| @@ -132,7 +132,8 @@ class ESConnection(DocStoreConnection): | |||
| bqry.filter.append( | |||
| Q("bool", must_not=Q("range", available_int={"lt": 1}))) | |||
| continue | |||
| if not v: continue | |||
| if not v: | |||
| continue | |||
| if isinstance(v, list): | |||
| bqry.filter.append(Q("terms", **{k: v})) | |||
| elif isinstance(v, str) or isinstance(v, int): | |||
| @@ -1,14 +1,21 @@ | |||
| from beartype.claw import beartype_this_package | |||
| beartype_this_package() # <-- raise exceptions in your code | |||
| import importlib.metadata | |||
| __version__ = importlib.metadata.version("ragflow_sdk") | |||
| from .ragflow import RAGFlow | |||
| from .modules.dataset import DataSet | |||
| from .modules.chat import Chat | |||
| from .modules.session import Session | |||
| from .modules.document import Document | |||
| from .modules.chunk import Chunk | |||
| from .modules.agent import Agent | |||
| from .modules.agent import Agent | |||
| __version__ = importlib.metadata.version("ragflow_sdk") | |||
| __all__ = [ | |||
| "RAGFlow", | |||
| "DataSet", | |||
| "Chat", | |||
| "Session", | |||
| "Document", | |||
| "Chunk", | |||
| "Agent" | |||
| ] | |||
| @@ -29,7 +29,7 @@ class Session(Base): | |||
| raise Exception(json_data["message"]) | |||
| if line.startswith("data:"): | |||
| json_data = json.loads(line[5:]) | |||
| if json_data["data"] != True: | |||
| if json_data["data"] is not True: | |||
| answer = json_data["data"]["answer"] | |||
| reference = json_data["data"]["reference"] | |||
| temp_dict = { | |||
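Ruff's E712 flags comparisons written as `== True` / `!= True`. Because the streaming payload's `data` field appears to be either the literal `True` (the end-of-stream marker) or a dict carrying the answer, the meaning-preserving rewrite is the identity check `is not True` rather than plain truthiness. A small sketch of the difference:

```python
# Sketch only: why `is not True` and `not value` are not interchangeable when the
# value may be either the literal True or a non-empty dict.
for data in (True, {"answer": "hi", "reference": []}):
    if data is not True:               # E712-compliant rewrite of `data != True`
        print("answer:", data["answer"])
    if not data:                       # only fires for falsy values such as {} or None
        print("empty payload")
```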
| @@ -1,5 +1,3 @@ | |||
| import string | |||
| import random | |||
| import os | |||
| import pytest | |||
| import requests | |||
| @@ -39,7 +39,6 @@ def update_dataset(auth, json_req): | |||
| def upload_file(auth, dataset_id, path): | |||
| authorization = {"Authorization": auth} | |||
| url = f"{HOST_ADDRESS}/v1/document/upload" | |||
| base_name = os.path.basename(path) | |||
| json_req = { | |||
| "kb_id": dataset_id, | |||
| } | |||
| @@ -1,3 +1,3 @@ | |||
| def test_get_email(get_email): | |||
| print(f"\nEmail account:",flush=True) | |||
| print("\nEmail account:",flush=True) | |||
| print(f"{get_email}\n",flush=True) | |||
| @@ -13,14 +13,10 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, upload_file, DATASET_NAME_LIMIT | |||
| from common import create_dataset, list_dataset, rm_dataset, upload_file | |||
| from common import list_document, get_docs_info, parse_docs | |||
| from time import sleep | |||
| from timeit import default_timer as timer | |||
| import re | |||
| import pytest | |||
| import random | |||
| import string | |||
| def test_parse_txt_document(get_auth): | |||
| @@ -1,6 +1,5 @@ | |||
| from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, DATASET_NAME_LIMIT | |||
| from common import create_dataset, list_dataset, rm_dataset, update_dataset, DATASET_NAME_LIMIT | |||
| import re | |||
| import pytest | |||
| import random | |||
| import string | |||
| @@ -33,8 +32,6 @@ def test_dataset(get_auth): | |||
| def test_dataset_1k_dataset(get_auth): | |||
| # create dataset | |||
| authorization = {"Authorization": get_auth} | |||
| url = f"{HOST_ADDRESS}/v1/kb/create" | |||
| for i in range(1000): | |||
| res = create_dataset(get_auth, f"test_create_dataset_{i}") | |||
| assert res.get("code") == 0, f"{res.get('message')}" | |||
| @@ -76,7 +73,7 @@ def test_duplicated_name_dataset(get_auth): | |||
| dataset_id = item.get("id") | |||
| dataset_list.append(dataset_id) | |||
| match = re.match(pattern, dataset_name) | |||
| assert match != None | |||
| assert match is not None | |||
| for dataset_id in dataset_list: | |||
| res = rm_dataset(get_auth, dataset_id) | |||
| @@ -1,3 +1,3 @@ | |||
| def test_get_email(get_email): | |||
| print(f"\nEmail account:",flush=True) | |||
| print("\nEmail account:",flush=True) | |||
| print(f"{get_email}\n",flush=True) | |||
| @@ -1,4 +1,4 @@ | |||
| from ragflow_sdk import RAGFlow,Agent | |||
| from ragflow_sdk import RAGFlow | |||
| from common import HOST_ADDRESS | |||
| import pytest | |||